-
Notifications
You must be signed in to change notification settings - Fork 17
/
test.py
57 lines (51 loc) · 2.27 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 18 17:57:59 2018
@author: Peter
"""
import numpy as np
import tensorflow as tf
from dgcnn import DGCNN
from prepro import get_max_length, load_embedding, load_data, next_batch
from config import get_config
def main():
testing_file = "./new_data/test.ann.json"
trained_model = "./checkpoints/model.ckpt"
embedding_file = "D:/DataMining/QASystem/wiki/wiki.zh.text.vector"
# embedding_file = "./wiki.zh.text.vector"
embedding_size = 60 # Word embedding dimension
batch_size = 64 # Batch data size
sequence_length = 150 # Sentence length
learning_rate = 0.01
gpu_mem_usage = 0.75
gpu_device = "/gpu:0"
cpu_device = "/cpu:0"
config = get_config() # Not used yet
embeddings, word2idx = load_embedding(embedding_file)
questions, evidences, y1, y2 = load_data(testing_file, word2idx, sequence_length)
with tf.Graph().as_default(), tf.device(gpu_device):
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_usage)
session_conf = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
with tf.variable_scope('Model'):
model = DGCNN(config, embeddings, sequence_length, embedding_size)
with tf.Session(config=session_conf).as_default() as sess:
saver = tf.train.Saver()
print("Start loading the model")
saver.restore(sess, trained_model)
print("The model is loaded")
acc1, acc2 = [], []
for batch_questions, batch_evidences, batch_y1, batch_y2 in next_batch(questions, evidences, y1, y2, batch_size):
feed_dict = {
model.e: batch_evidences,
model.q: batch_questions,
model.y1: batch_y1,
model.y2: batch_y2,
model.is_train: False
}
acc1_, acc2_ = sess.run([model.acc1, model.acc2], feed_dict)
acc1.append(acc1_)
acc2.append(acc2_)
print('Acc1 %2.3f\tAcc2 %2.3f' % (acc1_, acc2_))
print('Average: Acc1 %2.3f\tAcc2 %2.3f' % (np.mean(acc1), np.mean(acc2)))
if __name__ == '__main__':
main()