forked from awerries/online-svr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
traffic_prediction.py
112 lines (99 loc) · 3.95 KB
/
traffic_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Animated demo of Online SVR on traffic prediction data. Samples are read from a text file; at each step the learner predicts the next value, the prediction is plotted against the ground truth, and the learner is then trained on the current sample.
Usage:
python3 traffic_prediction.py <C> <epsilon> <kernel parameter> <data basename>
<data basename> is the data file name without its ".txt" extension; predictions are written to "<data basename>_prediction_log.txt".
C is the regularization parameter, essentially defining the limit on how closely the learner must adhere to the dataset (smoothness). Epsilon is the acceptable error, and defines the width of what is sometimes called the "SVR tube". The kernel parameter is the scaling factor for comparing feature distance (this implementation uses a Radial Basis Function).
Author: Adam Werries, [email protected]
"""
import sys
import time
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
import online_svr
def input_data(filename):
    """Load a colon-delimited data file into numpy arrays.

    Each non-blank line must contain exactly four ':'-separated fields::

        <t>:<x1,x2,...>:<y>:<xp1,xp2,...>

    Args:
        filename: path of the text file to read.

    Returns:
        A tuple ``(times, setX, setY, setXp)`` of numpy arrays:
        ``times`` is the first field of each line parsed with ``int``;
        ``setX`` and ``setXp`` hold one row of comma-separated floats per line;
        ``setY`` is the third field parsed with ``float``, shaped as a
        column ``(n, 1)``.

    Raises:
        ValueError: if a non-blank line does not have exactly four fields, or
            a field cannot be parsed as a number.
    """
    times = []
    setX = []
    setY = []
    setXp = []
    # 'with' guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline at end of file); they
            # would otherwise fail to unpack into four fields.
            if not line.strip():
                continue
            t, X, Y, Xp = line.split(':')
            times.append(int(t))
            setY.append(float(Y))
            setX.append([float(val) for val in X.strip().split(',')])
            setXp.append([float(val) for val in Xp.strip().split(',')])
    times = np.array(times)
    setX = np.array(setX)
    setXp = np.array(setXp)
    # Column vector, shape (n, 1).
    setY = np.array(setY).reshape(-1, 1)
    return times, setX, setY, setXp
def init():
    """Blank out both line artists so the animation starts from an empty plot.

    Returns:
        The (predicted, truth) line artists that were reset.
    """
    for artist in (line1, line2):
        artist.set_data([], [])
    return (line1, line2)
def animate(i):
    """Advance the animation by one frame: predict, plot, then learn.

    For frame ``i`` the learner predicts from ``testSetX[i]``. That prediction
    is plotted and scored against ``testSetY[i + 1]``; the last frame has no
    such target, so it is plotted without a truth point and not scored. The
    learner is then updated with ``testSetXp[i]`` and ``testSetY[i]``. The
    frame's wall-clock time is appended to ``iteration_times``.

    Args:
        i: zero-based frame index supplied by ``FuncAnimation``.

    Returns:
        The two line artists (predicted, truth) updated this frame.
    """
    global ydata, iteration_times
    n_frames = testSetX.shape[0]
    print('%%%%%%%%%%%%%%%%%%%%%%%% {0}/{1} %%%%%%%%%%%%%%%%%%%%%%%%'.format(i + 1, n_frames))
    iter_start_time = time.time()
    timesteps = np.arange(i + 1)
    # On frame 0, (0, 0) would be a singular x-range, which matplotlib warns
    # about; keep the axis at least one unit wide.
    ax.set_xlim(0, max(i, 1))
    ydata.append(OSVR.predict(testSetX[i, :]).item(0))
    line1.set_data(timesteps, np.array(ydata))
    has_next = i < n_frames - 1
    if has_next:
        line2.set_data(timesteps, testSetY[1:i + 2])
    else:
        # No testSetY[i + 1] exists for the final frame.
        line2.set_data(timesteps[:-1], testSetY[1:i + 1])
    OSVR.learn(testSetXp[i, :], testSetY[i])
    elapsed = time.time() - iter_start_time
    if has_next:
        print('Iteration prediction error was {0:.4f}'.format(np.abs(testSetY[i + 1] - ydata[i]).item()))
    print('@@@@@@ Iteration elapsed time: {0:.3f} seconds @@@@@@'.format(elapsed))
    iteration_times.append(elapsed)
    return line1, line2,
# Initialization
program_start_time = time.time()
iteration_times = list()
ydata = list()
# Fail with a readable message instead of an IndexError traceback when
# arguments are missing.
if len(sys.argv) < 5:
    sys.exit('Usage: python3 traffic_prediction.py <C> <epsilon> <kernel parameter> <data basename>')
C = float(sys.argv[1])
eps = float(sys.argv[2])
kernelParam = float(sys.argv[3])
# Base name of the data file: '.txt' is appended to read the input, and
# '_prediction_log.txt' is appended when writing the prediction log.
filename = sys.argv[4]
# Load data
times, testSetX, testSetY, testSetXp = input_data(filename + '.txt')
# Set up animated figure: predicted values vs. ground truth. The x-range
# starts at (0, 1) and is widened frame by frame inside animate().
fig = plt.figure(0)
ax = plt.axes(xlim=(0, 1), ylim=(0, 5))
(line1,) = ax.plot([], [], 'dr', lw=2, label='Predicted')
(line2,) = ax.plot([], [], 'ob', lw=2, label='Truth')
plt.title('C={0},eps={1},kernelParam={2}\n{3}'.format(C, eps, kernelParam, filename))
plt.legend(loc='upper left')

# Set up learner; the feature count is taken from the loaded data.
OSVR = online_svr.OnlineSVR(
    numFeatures=testSetX.shape[1],
    C=C,
    eps=eps,
    kernelParam=kernelParam,
    bias=0.5,
    debug=False,
)

# Run learner: one sample per frame via animate(), no repeat.
# NOTE: `anim` stays bound at module level; matplotlib's animation docs
# require a live reference to the animation object while it runs.
anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    interval=1,
    frames=range(testSetX.shape[0]),
    repeat=False,
)
plt.show()
# Display performance (time and RMS error)
print('\n\nC={0},eps={1},kernelParam={2}'.format(C, eps, kernelParam))
iteration_times = np.array(iteration_times)
print('Elapsed time: {0:.3f}'.format(np.sum(iteration_times)))
print('Mean iteration time: {0:.3f}'.format(np.mean(iteration_times)))
# animate() scores prediction ydata[i] against testSetY[i + 1], and the final
# prediction has no target. Score exactly those aligned pairs here.
# testSetY is an (n, 1) column; it is flattened so the subtraction against the
# 1-D predictions does not broadcast into an (n, n) matrix.
# The min() guards the case where fewer than n frames ran (e.g. the window
# was presumably closed early).
n_scored = min(len(ydata), testSetY.shape[0] - 1)
if n_scored > 0:
    errors = testSetY[1:n_scored + 1].ravel() - np.asarray(ydata[:n_scored])
    rmse = np.sqrt(np.mean(errors ** 2))
else:
    rmse = float('nan')
print('RMSE: {0}'.format(rmse))
# Per-iteration timing plot.
fig = plt.figure(1)
plt.plot(iteration_times, 'o')
plt.xlabel('Iteration')
plt.ylabel('Elapsed time')
plt.title('C={0},eps={1},kernelParam={2}\n{3}'.format(C, eps, kernelParam, filename))
plt.show()
# Write one "<time>: <prediction>" line per prediction made. zip() stops at
# the shorter sequence, so this no longer raises IndexError when ydata has
# fewer entries than times (i.e. not every frame ran).
with open('{0}_prediction_log.txt'.format(filename), 'w') as f:
    for t, y in zip(times, ydata):
        f.write('{0}: {1}\n'.format(t, y))