Skip to content

Commit

Permalink
Support scan for Trace_ELBO (#1693)
Browse files Browse the repository at this point in the history
* scan replay handling

* fix off by one error

* cleaned a bit

* lint

* use substack

* handle custom guides

* handle empty value shape

* a

* test

* move copy inside get_ith_value

---------

Co-authored-by: frans <[email protected]>
Co-authored-by: OlaRonning <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2023
1 parent c914c53 commit fb7a029
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
27 changes: 24 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
from numpyro.util import not_jax_tracer


def _replay_wrapper(replay_trace, trace, i, length):
def get_ith_value(site):
value_shape = jnp.shape(site["value"])
site_len = value_shape[0] if value_shape else 0
if (
site["name"] not in trace
or site_len != length
or site["type"] not in ("sample", "deterministic")
):
return site

site = site.copy()
site["value"] = site["value"][i]
return site

return {k: get_ith_value(v) for k, v in replay_trace.items()}


def _subs_wrapper(subs_map, i, length, site):
if site["type"] != "sample":
return
Expand Down Expand Up @@ -264,10 +282,10 @@ def scan_wrapper(
first_available_dim=None,
):
if length is None:
length = tree_flatten(xs)[0][0].shape[0]
length = jnp.shape(tree_flatten(xs)[0][0])[0]

if enum and history > 0:
return scan_enum(
return scan_enum( # TODO: replay for enum
f,
init,
xs,
Expand All @@ -289,14 +307,17 @@ def body_fn(wrapped_carry, x):
fn = handlers.infer_config(
f, config_fn=lambda msg: {"_scan_current_index": i}
)

seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
for subs_type, subs_map in substitute_stack:
subs_fn = partial(_subs_wrapper, subs_map, i, length)
if subs_type == "condition":
seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
elif subs_type == "substitute":
seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)
elif subs_type == "replay":
trace = handlers.trace(seeded_fn).get_trace(carry, x)
replay_trace_i = _replay_wrapper(subs_map, trace, i, length)
seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i)

with handlers.trace() as trace:
carry, y = seeded_fn(carry, x)
Expand Down
2 changes: 2 additions & 0 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def process_message(self, msg):
raise RuntimeError(f"Site {name} must be sampled in trace.")
msg["value"] = guide_msg["value"]
msg["infer"] = guide_msg["infer"].copy()
if msg["type"] == "control_flow":
msg["kwargs"]["substitute_stack"].append(("replay", self.trace))


class block(Messenger):
Expand Down
31 changes: 31 additions & 0 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import numpyro.distributions as dist
from numpyro.handlers import mask, seed, substitute, trace
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import log_density, potential_energy
from numpyro.optim import Adam


def test_scan():
Expand Down Expand Up @@ -241,3 +243,32 @@ def transition(carry, y_curr):
assert model_density
assert model_trace["x"]["fn"].batch_shape == (12, 10)
assert model_trace["x"]["fn"].event_shape == (3,)


def test_scan_svi():
T = 3
N = 5

def gaussian_hmm(y=None, T=T, N=N):
def transition(x_prev, y_curr):
with numpyro.plate("data", N):
x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5))
y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr)
return x_curr, (x_curr, y_curr)

with numpyro.plate("data", N):
x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0))
_, (x, y) = scan(transition, x0, y, length=T)
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
x, y = gaussian_hmm()
with numpyro.handlers.seed(rng_seed=0):
tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N)

guide = AutoNormal(gaussian_hmm)
svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N)
results = svi.run(random.PRNGKey(0), 10**3)

xhat = results.params["x_auto_loc"]
assert_allclose(xhat, tr["x"]["value"], rtol=0.1)

0 comments on commit fb7a029

Please sign in to comment.