-
Notifications
You must be signed in to change notification settings - Fork 1
/
application.py
42 lines (29 loc) · 979 Bytes
/
application.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
from opt import ConstrainedOpt
from model import sgan
import dataset as dataset
# --- Initialize -------------------------------------------------------------
# Start from an empty default graph so that re-running this script in the
# same interpreter does not stack duplicate ops/variables onto the old graph.
tf.reset_default_graph()

# Let the GPU allocator grow on demand instead of reserving all device memory
# when the session is created.
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True

# Checkpoint prefix shared by save and the optional restore below, so the two
# cannot drift apart. tf.train.Saver.save raises ValueError when the parent
# directory is missing, so create it up front rather than discovering the
# problem only after training has finished. MakeDirs succeeds if the
# directory already exists.
CHECKPOINT_DIR = 'params_c'
CHECKPOINT_PATH = CHECKPOINT_DIR + '/sgan'
tf.gfile.MakeDirs(CHECKPOINT_DIR)

# The `with` block closes the session even if training raises.
with tf.Session(config=config_proto) as sess:
    # 32 is forwarded unchanged to sgan.Model; its meaning is defined in
    # model/sgan (NOTE(review): presumably a size/resolution - confirm).
    model_object = sgan.Model(32)
    # Bound to its own name so the imported `dataset` module is not shadowed.
    train_dataset = dataset.Dataset()
    sess.run(tf.global_variables_initializer())

    # Checkpoint the variable lists the model exposes as vars_G, vars_E and
    # vars_D (by naming, presumably generator / encoder / discriminator).
    saver = tf.train.Saver(model_object.vars_G + model_object.vars_E + model_object.vars_D)
    # To resume from a previous run instead of starting fresh, uncomment:
    # saver.restore(sess, CHECKPOINT_PATH)

    # --- Train --------------------------------------------------------------
    # 10 is forwarded to train_model (NOTE(review): presumably the number of
    # epochs - confirm in model/sgan).
    model_object.train_model(sess, train_dataset, 10)

    # --- Test ---------------------------------------------------------------
    # model_object.generate_one_sample(train_dataset, sess)

    # --- Save ---------------------------------------------------------------
    saver.save(sess, CHECKPOINT_PATH)

# --- Interactive application (currently disabled) ---------------------------
# NOTE(review): re-enabling this needs imports this file does not have
# (sys, QApplication, qdarkstyle, MainWindow). Also, `model` is not bound
# here: only `sgan` is imported from the `model` package.
# opt_engine = ConstrainedOpt(model)
# app = QApplication(sys.argv)
# app.setStyleSheet(qdarkstyle.load_stylesheet(pyside=False))
# window = MainWindow(opt_engine)
# window.setWindowTitle("pix2vox")
# window.show()
# window.viewerWidget.interactor.Initialize()
# sys.exit(app.exec_())