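"""Wrapper around a saved TensorFlow 1.x graph: restores the latest checkpoint
on construction and exposes train() and test() methods on the restored graph.

Note: this module uses TF1 APIs (placeholders, collections, Session-based
execution) and will not run unmodified under TensorFlow 2 eager mode.
"""
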
import os
import traceback

import tensorflow as tf

from tfsession import TFSession


class Model:
    def __init__(self, model_path, batch_size=100, display_step=5):
        self.model_path = model_path
        self.batch_size = batch_size
        self.display_step = display_step
        self.sess = TFSession().sess
        # Resolve the checkpoint directory relative to this file so that
        # restoring (here) and saving (in train) use the same location.
        self.ckpt_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), self.model_path)
        try:
            ckpt = tf.train.get_checkpoint_state(self.ckpt_dir)
            print("Reading saved model parameters from %s" % ckpt.model_checkpoint_path)
            self.saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        except Exception:
            raise ValueError("Error loading model: %s" % traceback.format_exc())
        # Recover the graph endpoints stored in collections when the model
        # was first exported.
        self.input = tf.get_collection("image")[0]
        self.keep_prob = tf.get_collection("kp")[0]
        self.predictor = tf.get_collection("predictor")[0]
        self.softmax = tf.get_collection("predictor")[1]
        self.global_step = tf.get_collection("step")[0]

    def train(self, images, labels, n_classes, learning_rate=0.001, dropout=0.75):
        training_ops = tf.get_collection("training_ops")
        if len(training_ops) == 0:
            # No training operations have been defined yet; create them and
            # store them in a collection so later calls can reuse them.
            labels_placeholder = tf.placeholder(tf.uint8, [None])
            labels_one_hot = tf.one_hot(labels_placeholder, n_classes)
            cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.predictor, labels=labels_one_hot))
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost, global_step=self.global_step)
            correct_pred = tf.equal(tf.argmax(self.predictor, 1), tf.argmax(labels_one_hot, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            tf.add_to_collection("training_ops", labels_placeholder)
            tf.add_to_collection("training_ops", labels_one_hot)
            tf.add_to_collection("training_ops", cost)
            tf.add_to_collection("training_ops", optimizer)
            tf.add_to_collection("training_ops", correct_pred)
            tf.add_to_collection("training_ops", accuracy)
        else:
            print('Loading previous training operations.')
            labels_placeholder = training_ops[0]
            labels_one_hot = training_ops[1]
            cost = training_ops[2]
            optimizer = training_ops[3]
            correct_pred = training_ops[4]
            accuracy = training_ops[5]
        # Initialize only the optimizer's variables (Adam slot variables and
        # beta power accumulators), leaving the restored model weights alone.
        self.sess.run(tf.variables_initializer([v for v in tf.global_variables() if 'beta' in v.name or 'Adam' in v.name]))
        step = 1
        # Iterate over full batches; a trailing partial batch is dropped.
        while step * self.batch_size < len(images):
            batch_x = images[(step - 1) * self.batch_size:step * self.batch_size]
            batch_y = labels[(step - 1) * self.batch_size:step * self.batch_size]
            # Run one optimization step with dropout enabled.
            self.sess.run(optimizer, feed_dict={self.input: batch_x, labels_placeholder: batch_y, self.keep_prob: dropout})
            if step % self.display_step == 0:
                # Calculate batch loss and accuracy with dropout disabled.
                loss, acc = self.sess.run([cost, accuracy], feed_dict={self.input: batch_x, labels_placeholder: batch_y, self.keep_prob: 1.})
                # Save the variables to disk.
                save_path = self.saver.save(self.sess, os.path.join(self.ckpt_dir, 'model'), global_step=self.global_step)
                print("Checkpoint saved in file: %s" % save_path)
                print("Iter " + str(step * self.batch_size) + ", Minibatch Loss= " +
                      "{:.6f}".format(loss) + ", Training Accuracy= " +
                      "{:.5f}".format(acc))
            step += 1
        save_path = self.saver.save(self.sess, os.path.join(self.ckpt_dir, 'model'), global_step=self.global_step)
        print("Final checkpoint saved in file: %s" % save_path)

    def test(self, images):
        # Run inference batch by batch; unlike train(), the final partial
        # batch is included. Returns the raw predictor outputs (logits).
        predictions = []
        step = 0
        while step * self.batch_size < len(images):
            batch_x = images[step * self.batch_size:(step + 1) * self.batch_size]
            predictions.extend(self.sess.run(self.predictor, feed_dict={self.input: batch_x, self.keep_prob: 1.}))
            step += 1
        return predictions
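
# A minimal usage sketch, not part of the original file: the checkpoint
# directory 'saved_model', the input shape, and the fake data below are all
# hypothetical and must match the restored graph's "image" input in practice.
if __name__ == '__main__':
    import numpy as np

    model = Model('saved_model', batch_size=100)
    # Illustrative data only: 300 single-channel 28x28 images, 10 classes.
    images = np.random.rand(300, 28, 28, 1).astype(np.float32)
    labels = np.random.randint(0, 10, size=300)
    model.train(images, labels, n_classes=10)
    logits = model.test(images)
    print("Predicted classes (first 10):", np.argmax(logits[:10], axis=1))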