-
Notifications
You must be signed in to change notification settings - Fork 9
/
Step_6_2_DBN_test.py
150 lines (123 loc) · 6.32 KB
/
Step_6_2_DBN_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
This is the file for DBN
Di Wu [email protected]
2015-06-12
"""
from numpy import log
from glob import glob
import os
import sys
import cPickle
from classes import GestureSample
# customized imports
# modular imports
# the hyperparameter set the data dir, use etc classes, it's important to modify it according to your need
from classes.hyperparameters import use, lr, batch, reg, mom, tr, drop,\
net, DataLoader_with_skeleton_normalisation
from functions.test_functions import *
from functions.train_functions import _shared, _avg, write, ndtensor, print_params, lin,\
training_report, epoch_report, _batch,\
save_results, move_results, save_params, test_lio_skel
from classes.hyperparameters import batch
from dbn.utils import normalize
# customized imports
from dbn.GRBM_DBN import GRBM_DBN
from convnet3d_grbm_early_fusion import convnet3d_grbm_early_fusion
import scipy.io as sio
from time import localtime, time
# number of hidden states for each gesture class
STATE_NO = 5
#data path and store path definition
pc="linux"
pc="windows"
if pc=="linux":
data = "/idiap/user/dwu/chalearn/Test_video_skel"
save_dst = "/idiap/user/dwu/chalearn/Test_DBN_state_matrix"
res_dir_ = "/idiap/user/dwu/chalearn/result/"
elif pc=="windows":
data = "/idiap/user/dwu/chalearn/Test_video_skel"
save_dst = "/idiap/user/dwu/chalearn/Test_DBN_state_matrix"
res_dir_ = "/idiap/user/dwu/chalearn/result/"
os.chdir(data)
samples=glob("*.zip")
print len(samples), "samples found"
used_joints = ['ElbowLeft', 'WristLeft', 'ShoulderLeft', 'HandLeft',
'ElbowRight', 'WristRight', 'ShoulderRight', 'HandRight',
'Head', 'Spine', 'HipCenter']
lt = localtime()
res_dir = res_dir_+"/try/"+str(lt.tm_year)+"."+str(lt.tm_mon).zfill(2)+"." \
+str(lt.tm_mday).zfill(2)+"."+str(lt.tm_hour).zfill(2)+"."\
+str(lt.tm_min).zfill(2)+"."+str(lt.tm_sec).zfill(2)
os.makedirs(res_dir)
# we need to parse an absolute path for HPC to load
load_path = '/idiap/home/dwu/chalearn2014_wudi_lio/chalearn2014_wudi_lio'
######################################################################
# Build the early-fusion 3D-ConvNet + GRBM network and restore its trained
# parameters from paramsbest.zip under load_path.
net_convnet3d_grbm_early_fusion = convnet3d_grbm_early_fusion(res_dir, load_path)
load_path = '/idiap/home/dwu/chalearn2014_wudi_lio/chalearn2014_wudi_lio'  # NOTE(review): re-assigns the same value set above — presumably redundant
net_convnet3d_grbm_early_fusion.load_params(os.path.join(load_path,'paramsbest.zip'))
# Shared buffers for the video and skeleton inputs.
# NOTE(review): `empty` and `_shared` come from the star/module imports above
# (train_functions) — verify their provenance before refactoring.
x_ = _shared(empty(tr.in_shape))
x_skeleton_ = _shared(empty(tr._skeleon_in_shape))
# Compiled prediction of the fused network; appears unused by the loop below,
# which relies on the skeleton-only `test_model` instead.
p_y_given_x = net_convnet3d_grbm_early_fusion.prediction_function(x_, x_skeleton_)
#############################
# load normalisation constant given load_path
Mean_skel, Std_skel, Mean_CNN, Std_CNN = net_convnet3d_grbm_early_fusion.load_normalisation_constant(load_path)
####################################################################
# DBN for skeleton modules
####################################################################
# ------------------------------------------------------------------------------
# symbolic variables
x_skeleton = ndtensor(len(tr._skeleon_in_shape))(name = 'x_skeleton') # skeleton feature input (symbolic)
x_skeleton_ = _shared(empty(tr._skeleon_in_shape))  # shared buffer refreshed per micro-batch below
# Skeleton DBN: 891 input features -> three hidden layers -> 101 outputs
# (20 classes * 5 states + 1 ergodic state).
# NOTE(review): `random` and `y` are not defined in this file — presumably
# provided by the star imports (test_functions/train_functions); confirm.
dbn = GRBM_DBN(numpy_rng=random.RandomState(123), n_ins=891, \
            hidden_layers_sizes=[2000, 2000, 1000], n_outs=101, input_x=x_skeleton, label=y )
# we load the pretrained DBN skeleton parameteres here
load_path = '/idiap/user/dwu/chalearn/result/try/36.7% 2015.07.09.17.53.10'
dbn.load_params_DBN(os.path.join(load_path,'paramsbest.zip'))
# Theano function: emission probabilities p(y|x) for the current contents of
# the shared skeleton buffer (no explicit arguments; wired via `givens`).
test_model = function([], dbn.logLayer.p_y_given_x,
            givens={x_skeleton: x_skeleton_},
            on_unused_input='ignore')
for file_count, file in enumerate(samples):
condition = (file_count > -1)
if condition: #wudi only used first 650 for validation !!! Lio be careful!
save_path= os.path.join(data, file)
print file
time_start = time()
# we load precomputed feature set or recompute the whole feature set
if os.path.isfile(save_path):
print "loading exiting file"
data_dic = cPickle.load(open(save_path,'rb'))
video = data_dic["video"]
Feature_gesture = data_dic["Feature_gesture"]
assert video.shape[0] == Feature_gesture.shape[0]
else:
print("\t Processing file " + file)
# Create the object to access the sample
sample = GestureSample(os.path.join(data,file))
print "finish loading samples"
video, Feature_gesture = sample.get_test_data_wudi_lio(used_joints)
assert video.shape[0] == Feature_gesture.shape[0]# -*- coding: utf-8 -*-
print "finish preprocessing"
out_file = open(save_path, 'wb')
cPickle.dump({"video":video, "Feature_gesture":Feature_gesture}, out_file, protocol=cPickle.HIGHEST_PROTOCOL)
out_file.close()
print "start computing likelihood"
observ_likelihood = numpy.empty(shape=(video.shape[0],20*STATE_NO+1)) # 20 classed * 5 states + 1 ergodic state
for batchnumber in xrange(video.shape[0]/batch.micro):
skel_temp = Feature_gesture[batch.micro*batchnumber:batch.micro*(batchnumber+1),:]
x_skeleton_.set_value(normalize(skel_temp,Mean_skel, Std_skel).astype("float32"), borrow=True)
observ_likelihood[batch.micro*batchnumber:batch.micro*(batchnumber+1),:] = test_model()
# because input batch number should be 64, so here it is a bit of hack:
skel_temp_1 = Feature_gesture[batch.micro* (batchnumber+1):,:]
skel_temp_2 = numpy.zeros(shape=(64-skel_temp_1.shape[0],891))
skel_temp = numpy.concatenate((skel_temp_1, skel_temp_2), axis=0)
x_skeleton_.set_value(normalize(skel_temp,Mean_skel, Std_skel).astype("float32"), borrow=True)
ob_temp = test_model()
observ_likelihood[batch.micro* (batchnumber+1):,:] = ob_temp[:video_temp_1.shape[0], :]
##########################
# save state matrix
#####################
save_path= os.path.join(save_dst, file)
out_file = open(save_path, 'wb')
cPickle.dump(observ_likelihood, out_file, protocol=cPickle.HIGHEST_PROTOCOL)
out_file.close()
print "use %f second"% (time()-time_start)