# global_config.py: shared constants, time-formatting helpers, and a plotting helper.
# coding: utf-8
import torch
import os
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
# Reserved vocabulary indices; names suggest start-/end-of-sequence markers
# (TODO confirm against the vocabulary builder).
SOS_TOKEN, EOS_TOKEN = 0, 1
# Sequence length cap, presumably measured in tokens; verify against callers.
MAX_LENGTH = 50

# Data files; relative paths, so they resolve against the current working directory.
SOURCE_PATH, TARGET_PATH, PAIRS_PATH = (
    "data/source.txt",
    "data/target.txt",
    "data/pairs.txt",
)

# Presumably used to name saved model files, which go under CHECKPOINT_DIR.
MODEL_PREFIX = 'attn_seq2seq_conversation'
CHECKPOINT_DIR = "./checkpoints/"

# True when PyTorch reports an available CUDA device.
use_cuda = torch.cuda.is_available()
def asMinutes(s):
    """Format a duration given in seconds as ``"<minutes>m <seconds>s"``.

    Args:
        s: Duration in seconds (int or float).

    Returns:
        A string such as ``"2m 5s"``. The seconds part is what remains after
        removing whole minutes; fractional seconds are truncated by ``%d``.
    """
    # divmod yields whole minutes plus the leftover seconds in one step.
    # The leftover must be obtained by subtraction, not multiplication.
    m, s = divmod(s, 60)
    return "%dm %ds" % (m, s)
def timeSince(since, percent):
    """Describe elapsed time and a linear estimate of the time still remaining.

    Args:
        since: Start timestamp, comparable with ``time.time()``.
        percent: Fraction of the work completed so far. The estimate is only
            meaningful for 0 < percent <= 1. A value of 0 raises
            ``ZeroDivisionError``.

    Returns:
        ``"<elapsed> ( - <remaining>)"``, with both durations rendered by
        ``asMinutes``. The remaining time assumes progress continues at the
        average rate observed so far (projected total minus elapsed).
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return "%s ( - %s)" % (asMinutes(elapsed), asMinutes(remaining))
def showPlot(points):
    """Plot a sequence of values against their index on a new figure.

    Major ticks on the y-axis are spaced 0.2 units apart. The figure is
    created but neither shown nor saved here; that is left to the caller.

    Args:
        points: Sequence of y-values to plot.
    """
    # plt.subplots() creates the figure itself; calling plt.figure() first
    # would leave behind an extra, empty figure on every call.
    _, ax = plt.subplots()
    # Fixed 0.2 spacing between major y-ticks, regardless of the data range.
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    ax.plot(points)