# variational_smc.py
from __future__ import absolute_import
from __future__ import print_function
import autograd.numpy as np
from autograd import grad
from autograd.extend import notrace_primitive

@notrace_primitive
def resampling(w, rs):
    """
    Stratified resampling, wrapped with notrace_primitive so that autograd
    takes no derivatives through it.
    """
    N = w.shape[0]
    bins = np.cumsum(w)
    ind = np.arange(N)
    u = (ind + rs.rand(N))/N
    return np.digitize(u, bins)
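
# Sanity-check sketch (illustrative only, not part of the module): with the
# weight mass concentrated on one particle, most ancestor indices returned by
# stratified resampling should point at that particle.
#
#   rs = np.random.RandomState(0)
#   resampling(np.array([0.05, 0.05, 0.85, 0.05]), rs)
#   # -> ancestor indices heavily concentrated on 2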

def vsmc_lower_bound(prop_params, model_params, y, smc_obj, rs,
                     verbose=False, adapt_resamp=False):
    r"""
    Estimate the VSMC lower bound. Amenable to (biased) reparameterization
    gradients.

    .. math::
        ELBO(\theta, \lambda) =
        \mathbb{E}_{\phi}\left[\log \hat p(y_{1:T})\right]

    Requires an SMC object with two member functions:
        -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
        -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N, Dx))
    Xp = np.zeros((N, Dx))
    logW = np.zeros(N)
    W = np.exp(logW)
    W /= np.sum(W)
    logZ = 0.
    ESS = 1./np.sum(W**2)/N

    for t in range(T):
        # Resampling
        if adapt_resamp:
            if ESS < 0.5:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
                logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
                logW = np.zeros(N)
            else:
                Xp = X
        else:
            if t > 0:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
            else:
                Xp = X

        # Propagation
        X = smc_obj.sim_prop(t, Xp, y, prop_params, model_params, rs)

        # Weighting
        if adapt_resamp:
            logW = logW + smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        else:
            logW = smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        max_logW = np.max(logW)
        W = np.exp(logW - max_logW)

        # Update logZ: with adaptive resampling the running estimate is only
        # folded in at resampling times and at the final step.
        if adapt_resamp:
            if t == T-1:
                logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        else:
            logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        W /= np.sum(W)
        ESS = 1./np.sum(W**2)/N

        if verbose:
            print('ESS: ' + str(ESS))

    return logZ
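
# Usage sketch (hedged): one way to take reparameterization gradients of the
# bound with autograd's `grad`, imported above. `my_smc_obj`, `model_params`,
# and `y_obs` are hypothetical placeholders, not part of this module.
#
#   import autograd.numpy.random as npr
#   objective = lambda prop_params, i: -vsmc_lower_bound(
#       prop_params, model_params, y_obs, my_smc_obj, npr.RandomState(i))
#   objective_grad = grad(objective)  # gradient w.r.t. prop_params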

def sim_q(prop_params, model_params, y, smc_obj, rs, verbose=False):
    """
    Simulates a single sample from the VSMC approximation.

    Requires an SMC object with two member functions:
        -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
        -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N, T, Dx))
    logW = np.zeros(N)
    W = np.zeros((N, T))
    ESS = np.zeros(T)

    for t in range(T):
        # Resampling
        if t > 0:
            ancestors = resampling(W[:, t-1], rs)
            X[:, :t, :] = X[ancestors, :t, :]

        # Propagation
        X[:, t, :] = smc_obj.sim_prop(t, X[:, t-1, :], y, prop_params, model_params, rs)

        # Weighting
        logW = smc_obj.log_weights(t, X[:, t, :], X[:, t-1, :], y, prop_params, model_params)
        max_logW = np.max(logW)
        W[:, t] = np.exp(logW - max_logW)
        W[:, t] /= np.sum(W[:, t])
        ESS[t] = 1./np.sum(W[:, t]**2)  # unnormalized ESS, in [1, N]

    # Sample a single trajectory from the empirical approximation
    bins = np.cumsum(W[:, -1])
    u = rs.rand()
    B = np.digitize(u, bins)

    if verbose:
        print('Mean ESS (normalized):', np.mean(ESS)/N)
        print('Min ESS:', np.min(ESS))

    return X[B, :, :]
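

if __name__ == '__main__':
    # Minimal smoke test (a hedged sketch, not part of the original module):
    # a toy 1-D linear-Gaussian state-space model with a bootstrap-style
    # proposal. All names below (ToyLG, etc.) are illustrative assumptions.
    import autograd.numpy.random as npr

    class ToyLG(object):
        """x_t ~ N(a*x_{t-1}, q), y_t ~ N(x_t, r); proposal = model prior."""
        def __init__(self, N, Dx):
            self.N, self.Dx = N, Dx

        def sim_prop(self, t, Xp, y, prop_params, model_params, rs):
            a, q, r = model_params
            mean = a*Xp if t > 0 else np.zeros(Xp.shape)
            return mean + np.sqrt(q)*rs.randn(*Xp.shape)

        def log_weights(self, t, X, Xp, y, prop_params, model_params):
            a, q, r = model_params
            # log N(y_t | x_t, r), summed over state dimensions
            return np.sum(-0.5*np.log(2*np.pi*r) - 0.5*(y[t] - X)**2/r, axis=1)

    T, Dx, N = 10, 1, 50
    model_params = (0.9, 0.1, 0.5)
    y = npr.RandomState(0).randn(T, Dx)  # synthetic observations for the test
    smc_obj = ToyLG(N, Dx)

    print('log-marginal estimate:',
          vsmc_lower_bound(None, model_params, y, smc_obj, npr.RandomState(1)))
    print('posterior sample shape:',
          sim_q(None, model_params, y, smc_obj, npr.RandomState(2)).shape)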