-
Notifications
You must be signed in to change notification settings - Fork 270
/
seq2seq_wrapper.py
182 lines (150 loc) · 7.26 KB
/
seq2seq_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import tensorflow as tf
import numpy as np
import sys
class Seq2Seq(object):
    """Encoder-decoder (seq2seq) model on TensorFlow's legacy seq2seq API.

    The constructor builds two decoders that share parameters: a training
    decoder that is fed the ground-truth previous token, and a test decoder
    that is fed its own previous output (``feed_previous=True``).

    All sequence batches handled by this class are indexed by timestep first:
    ``X[t]`` is fed to the placeholder for encoder step ``t`` and ``Y[t]`` to
    the placeholder for label step ``t``; each placeholder is a 1-D int64
    tensor (one entry per example in the batch).
    """

    def __init__(self, xseq_len, yseq_len,
                 xvocab_size, yvocab_size,
                 emb_dim, num_layers, ckpt_path,
                 lr=0.0001,
                 epochs=100000, model_name='seq2seq_model'):
        """Store the hyper-parameters and build the computation graph.

        Note that building the graph calls ``tf.reset_default_graph()``, so
        anything previously added to the default graph is discarded.

        Args:
            xseq_len: number of encoder timesteps.
            yseq_len: number of decoder timesteps.
            xvocab_size: number of encoder symbols.
            yvocab_size: number of decoder symbols.
            emb_dim: embedding size, also used as the LSTM unit count.
            num_layers: number of stacked LSTM layers.
            ckpt_path: checkpoint location. It is concatenated directly with
                ``model_name`` when saving and passed as the directory when
                restoring, so it should end with a path separator.
            lr: learning rate for the Adam optimizer.
            epochs: number of training iterations (one batch per iteration).
            model_name: file-name prefix used for checkpoints.
        """
        # attach these arguments to self
        self.xseq_len = xseq_len
        self.yseq_len = yseq_len
        self.ckpt_path = ckpt_path
        self.epochs = epochs
        self.model_name = model_name

        # build the graph; every part that must be reachable later
        # (placeholders, outputs, loss, train op) is attached to self
        def _build_graph():
            tf.reset_default_graph()

            # encoder inputs : list of indices of length xseq_len
            self.enc_ip = [tf.placeholder(shape=[None, ],
                                          dtype=tf.int64,
                                          name='ei_{}'.format(t))
                           for t in range(xseq_len)]

            # labels that represent the real outputs
            # NOTE(review): this reuses the encoder's 'ei_{}' name pattern;
            # a distinct prefix was probably intended. Left unchanged so the
            # graph's node names stay exactly as before.
            self.labels = [tf.placeholder(shape=[None, ],
                                          dtype=tf.int64,
                                          name='ei_{}'.format(t))
                           for t in range(yseq_len)]

            # decoder inputs : 'GO' + [ y1, y2, ... y_t-1 ]
            go_token = tf.zeros_like(self.enc_ip[0], dtype=tf.int64, name='GO')
            self.dec_ip = [go_token] + self.labels[:-1]

            # keep probability applied to each LSTM layer's output
            self.keep_prob = tf.placeholder(tf.float32)

            def _make_cell():
                # Basic LSTM cell wrapped in a dropout wrapper
                return tf.contrib.rnn.core_rnn_cell.DropoutWrapper(
                    tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(
                        emb_dim, state_is_tuple=True),
                    output_keep_prob=self.keep_prob)

            # stack cells together : n layered model.
            # Each layer gets its own cell object: putting the same instance
            # in every layer ([cell] * n) is rejected by TensorFlow >= 1.1,
            # while distinct instances work on every 1.x release.
            stacked_lstm = tf.contrib.rnn.core_rnn_cell.MultiRNNCell(
                [_make_cell() for _ in range(num_layers)],
                state_is_tuple=True)

            # one variable scope so the training model and the testing model
            # share parameters
            with tf.variable_scope('decoder') as scope:
                # training model: decoder is fed the ground-truth inputs
                self.decode_outputs, self.decode_states = \
                    tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                        self.enc_ip, self.dec_ip, stacked_lstm,
                        xvocab_size, yvocab_size, emb_dim)
                # share parameters
                scope.reuse_variables()
                # testing model, where output of previous timestep is fed as
                # input to the next timestep
                self.decode_outputs_test, self.decode_states_test = \
                    tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                        self.enc_ip, self.dec_ip, stacked_lstm,
                        xvocab_size, yvocab_size, emb_dim,
                        feed_previous=True)

            # training loss: every label position gets weight 1.0
            loss_weights = [tf.ones_like(label, dtype=tf.float32)
                            for label in self.labels]
            # The vocabulary size used to be passed as a 4th positional
            # argument; in legacy_seq2seq that slot is
            # `average_across_timesteps`, so it only worked because the value
            # was truthy. The default (True) gives the same result.
            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(
                self.decode_outputs, self.labels, loss_weights)

            # train op to minimize the loss
            self.train_op = tf.train.AdamOptimizer(
                learning_rate=lr).minimize(self.loss)

        sys.stdout.write('<log> Building Graph ')
        _build_graph()
        sys.stdout.write('</log>')

    #
    # Training and Evaluation
    #

    def get_feed(self, X, Y, keep_prob):
        """Build the feed dictionary for one batch.

        Args:
            X: encoder batch; ``X[t]`` is fed to encoder step ``t``.
            Y: label batch; ``Y[t]`` is fed to label step ``t``.
            keep_prob: value for the dropout keep-probability placeholder.

        Returns:
            dict mapping every encoder/label placeholder and the keep-prob
            placeholder to its value.
        """
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict.update({self.labels[t]: Y[t] for t in range(self.yseq_len)})
        feed_dict[self.keep_prob] = keep_prob  # dropout prob
        return feed_dict

    def train_batch(self, sess, train_batch_gen):
        """Run one optimisation step on the next batch from the generator.

        Args:
            sess: session to run the train op in.
            train_batch_gen: iterator yielding ``(X, Y)`` batches.

        Returns:
            The loss value for this batch.
        """
        batchX, batchY = next(train_batch_gen)
        # dropout is active during training
        feed_dict = self.get_feed(batchX, batchY, keep_prob=0.5)
        _, loss_v = sess.run([self.train_op, self.loss], feed_dict)
        return loss_v

    def eval_step(self, sess, eval_batch_gen):
        """Evaluate the next batch from the generator (dropout disabled).

        Args:
            sess: session to evaluate in.
            eval_batch_gen: iterator yielding ``(X, Y)`` batches.

        Returns:
            Tuple ``(loss, decoder_outputs, X, Y)``. ``loss`` comes from the
            training decoder; ``decoder_outputs`` comes from the test decoder
            and is transposed so that its first axis is the batch and its
            second axis the timestep.
        """
        batchX, batchY = next(eval_batch_gen)
        feed_dict = self.get_feed(batchX, batchY, keep_prob=1.)
        loss_v, dec_op_v = sess.run(
            [self.loss, self.decode_outputs_test], feed_dict)
        # dec_op_v is a list with one entry per timestep; stack it and swap
        # axes 0 and 1 (timesteps <-> batch)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        return loss_v, dec_op_v, batchX, batchY

    def eval_batches(self, sess, eval_batch_gen, num_batches):
        """Return the mean loss over ``num_batches`` evaluation batches."""
        losses = [self.eval_step(sess, eval_batch_gen)[0]
                  for _ in range(num_batches)]
        return np.mean(losses)

    def _checkpoint_interval(self, save_every=None):
        """Return the number of iterations between checkpoints (>= 1).

        Defaults to one hundredth of ``self.epochs``; the result is clamped
        to 1 so that runs shorter than 100 iterations do not end up with a
        zero interval (which previously caused a ZeroDivisionError).
        """
        if save_every is None:
            save_every = self.epochs // 100
        return max(1, int(save_every))

    def train(self, train_set, valid_set, sess=None,
              save_every=None, num_eval_batches=16):
        """Train for ``self.epochs`` iterations, checkpointing periodically.

        At every checkpoint the model is saved, the validation loss is
        computed and both facts are printed. A ``KeyboardInterrupt`` stops
        training early instead of propagating.

        Args:
            train_set: iterator yielding training ``(X, Y)`` batches.
            valid_set: iterator yielding validation ``(X, Y)`` batches.
            sess: session to train in. When omitted, a new session is
                created and all variables are initialised.
            save_every: iterations between checkpoints; defaults to
                ``self.epochs // 100`` (at least 1).
            num_eval_batches: validation batches averaged per checkpoint.

        Returns:
            The session used for training (also stored as ``self.session``),
            whether training finished normally or was interrupted.
        """
        # we need to save the model periodically
        saver = tf.train.Saver()

        if sess is None:
            sess = tf.Session()
            sess.run(tf.global_variables_initializer())

        interval = self._checkpoint_interval(save_every)

        sys.stdout.write('\n<log> Training started </log>\n')
        step = 0
        try:
            for step in range(self.epochs):
                self.train_batch(sess, train_set)
                if step and step % interval == 0:
                    # save model to disk
                    saver.save(sess,
                               self.ckpt_path + self.model_name + '.ckpt',
                               global_step=step)
                    # evaluate to get validation loss
                    val_loss = self.eval_batches(
                        sess, valid_set, num_eval_batches)
                    # print stats
                    print('\nModel saved to disk at iteration #{}'.format(step))
                    print('val loss : {0:.6f}'.format(val_loss))
                    sys.stdout.flush()
        except KeyboardInterrupt:
            # this will most definitely happen, so handle it
            print('Interrupted by user at iteration {}'.format(step))

        # hand the session back on every exit path, not only on interrupt
        self.session = sess
        return sess

    def restore_last_session(self):
        """Create a session and restore the latest checkpoint, if any.

        Returns:
            A new session. When no checkpoint is found under
            ``self.ckpt_path`` a message is logged and the session is
            returned without any variables restored.
        """
        saver = tf.train.Saver()
        sess = tf.Session()
        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(self.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sys.stdout.write(
                '\n<log> No checkpoint found in {} </log>\n'.format(
                    self.ckpt_path))
        return sess

    def predict(self, sess, X):
        """Decode ``X`` with the test decoder.

        Args:
            sess: session holding the model's variables.
            X: encoder batch; ``X[t]`` is fed to encoder step ``t``.

        Returns:
            Integer array whose first axis is the batch and second axis the
            timestep, holding the highest-scoring symbol index per step.
        """
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict[self.keep_prob] = 1.  # no dropout at inference time
        dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
        # dec_op_v is a list with one entry per timestep; stack it and swap
        # axes 0 and 1 (timesteps <-> batch)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        # index of the highest-scoring symbol at every position
        return np.argmax(dec_op_v, axis=2)