diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index b04c51862..4bd2143a0 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -15,6 +15,24 @@ from numpyro.util import not_jax_tracer +def _replay_wrapper(replay_trace, trace, i, length): + def get_ith_value(site): + value_shape = jnp.shape(site["value"]) + site_len = value_shape[0] if value_shape else 0 + if ( + site["name"] not in trace + or site_len != length + or site["type"] not in ("sample", "deterministic") + ): + return site + + site = site.copy() + site["value"] = site["value"][i] + return site + + return {k: get_ith_value(v) for k, v in replay_trace.items()} + + def _subs_wrapper(subs_map, i, length, site): if site["type"] != "sample": return @@ -264,10 +282,10 @@ def scan_wrapper( first_available_dim=None, ): if length is None: - length = tree_flatten(xs)[0][0].shape[0] + length = jnp.shape(tree_flatten(xs)[0][0])[0] if enum and history > 0: - return scan_enum( + return scan_enum( # TODO: replay for enum f, init, xs, @@ -289,7 +307,6 @@ def body_fn(wrapped_carry, x): fn = handlers.infer_config( f, config_fn=lambda msg: {"_scan_current_index": i} ) - seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn for subs_type, subs_map in substitute_stack: subs_fn = partial(_subs_wrapper, subs_map, i, length) @@ -297,6 +314,10 @@ def body_fn(wrapped_carry, x): seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) elif subs_type == "substitute": seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) + elif subs_type == "replay": + trace = handlers.trace(seeded_fn).get_trace(carry, x) + replay_trace_i = _replay_wrapper(subs_map, trace, i, length) + seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: carry, y = seeded_fn(carry, x) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 5745cea32..3e38e204c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -222,6 +222,8 @@ def process_message(self, msg): raise RuntimeError(f"Site {name} must be sampled in trace.") msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"].copy() + if msg["type"] == "control_flow": + msg["kwargs"]["substitute_stack"].append(("replay", self.trace)) class block(Messenger): diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 67273d02c..f75686daf 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -12,7 +12,9 @@ import numpyro.distributions as dist from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO +from numpyro.infer.autoguide import AutoNormal from numpyro.infer.util import log_density, potential_energy +from numpyro.optim import Adam def test_scan(): @@ -241,3 +243,32 @@ def transition(carry, y_curr): assert model_density assert model_trace["x"]["fn"].batch_shape == (12, 10) assert model_trace["x"]["fn"].event_shape == (3,) + + +def test_scan_svi(): + T = 3 + N = 5 + + def gaussian_hmm(y=None, T=T, N=N): + def transition(x_prev, y_curr): + with numpyro.plate("data", N): + x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5)) + y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr) + return x_curr, (x_curr, y_curr) + + with numpyro.plate("data", N): + x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0)) + _, (x, y) = scan(transition, x0, y, length=T) + return (x, y) + + with numpyro.handlers.seed(rng_seed=0): + x, y = gaussian_hmm() + with numpyro.handlers.seed(rng_seed=0): + tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N) + + guide = AutoNormal(gaussian_hmm) + svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N) + results = svi.run(random.PRNGKey(0), 10**3) + + xhat = results.params["x_auto_loc"] + assert_allclose(xhat, tr["x"]["value"], rtol=0.1)