From 0c0669d6ed061af103b32fd2d3b19cfef7833137 Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sun, 27 Feb 2022 18:50:25 -0500
Subject: [PATCH] Make replay consistent with Pyro (#1345)

* make replay consistent with numpyro

* Add docs for apply_stack to clarify its functionality
---
 numpyro/distributions/distribution.py |  4 ++--
 numpyro/handlers.py                   | 17 ++++++++++++--
 numpyro/primitives.py                 | 32 +++++++++++++++++++++------
 test/test_distributions.py            | 10 ++++-----
 4 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index 4eb46ec99..2210a1c8f 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -762,8 +762,8 @@ def has_enumerate_support(self):
         return self.base_dist.has_enumerate_support
 
     @property
-    def reparameterized_params(self):
-        return self.base_dist.reparameterized_params
+    def reparametrized_params(self):
+        return self.base_dist.reparametrized_params
 
     @property
     def mean(self):
diff --git a/numpyro/handlers.py b/numpyro/handlers.py
index e38ae30ea..29c78cfc9 100644
--- a/numpyro/handlers.py
+++ b/numpyro/handlers.py
@@ -202,14 +202,27 @@ class replay(Messenger):
        >>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
     """
 
-    def __init__(self, fn=None, trace=None, guide_trace=None):
+    def __init__(self, fn=None, trace=None):
         assert trace is not None
         self.trace = trace
         super(replay, self).__init__(fn)
 
     def process_message(self, msg):
         if msg["type"] in ("sample", "plate") and msg["name"] in self.trace:
-            msg["value"] = self.trace[msg["name"]]["value"]
+            name = msg["name"]
+            if msg["type"] in ("sample", "plate") and name in self.trace:
+                guide_msg = self.trace[name]
+                if msg["type"] == "plate":
+                    if guide_msg["type"] != "plate":
+                        raise RuntimeError(f"Site {name} must be a plate in trace.")
+                    msg["value"] = guide_msg["value"]
+                    return None
+                if msg["is_observed"]:
+                    return None
+                if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
+                    raise RuntimeError(f"Site {name} must be sampled in trace.")
+                msg["value"] = guide_msg["value"]
+                msg["infer"] = guide_msg["infer"]
 
 
 class block(Messenger):
diff --git a/numpyro/primitives.py b/numpyro/primitives.py
index 67d26b322..58c3392bd 100644
--- a/numpyro/primitives.py
+++ b/numpyro/primitives.py
@@ -18,7 +18,30 @@
 CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
 
 
+def default_process_message(msg):
+    if msg["value"] is None:
+        if msg["type"] == "sample":
+            msg["value"], msg["intermediates"] = msg["fn"](
+                *msg["args"], sample_intermediates=True, **msg["kwargs"]
+            )
+        else:
+            msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
+
+
 def apply_stack(msg):
+    """
+    Execute the effect stack at a single site according to the following scheme:
+
+        1. For each ``Messenger`` in the stack from bottom to top,
+           execute ``Messenger.process_message`` with the message;
+           if the message field "stop" is True, stop;
+           otherwise, continue
+        2. Apply default behavior (``default_process_message``) to finish remaining
+           site execution
+        3. For each ``Messenger`` in the stack from top to bottom,
+           execute ``Messenger.postprocess_message`` to update the message
+           and internal messenger state with the site results
+    """
     pointer = 0
     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
         handler.process_message(msg)
@@ -26,13 +49,8 @@ def apply_stack(msg):
         # it prevents any Messengers above it on the stack from being applied.
         if msg.get("stop"):
             break
-    if msg["value"] is None:
-        if msg["type"] == "sample":
-            msg["value"], msg["intermediates"] = msg["fn"](
-                *msg["args"], sample_intermediates=True, **msg["kwargs"]
-            )
-        else:
-            msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
+
+    default_process_message(msg)
 
     # A Messenger that sets msg["stop"] == True also prevents application
     # of postprocess_message by Messengers above it on the stack
diff --git a/test/test_distributions.py b/test/test_distributions.py
index d3bf13878..8cfdc7564 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -894,23 +894,23 @@ def test_sample_gradient(jax_dist, sp_dist, params):
     params_dict = dict(zip(dist_args[: len(params)], params))
 
     jax_class = type(jax_dist(**params_dict))
-    reparameterized_params = [
+    reparametrized_params = [
         p for p in jax_class.reparametrized_params if p not in gamma_derived_params
     ]
-    if not reparameterized_params:
+    if not reparametrized_params:
         pytest.skip("{} not reparametrized.".format(jax_class.__name__))
 
     nonrepara_params_dict = {
-        k: v for k, v in params_dict.items() if k not in reparameterized_params
+        k: v for k, v in params_dict.items() if k not in reparametrized_params
     }
     repara_params = tuple(
-        v for k, v in params_dict.items() if k in reparameterized_params
+        v for k, v in params_dict.items() if k in reparametrized_params
     )
 
     rng_key = random.PRNGKey(0)
 
     def fn(args):
-        args_dict = dict(zip(reparameterized_params, args))
+        args_dict = dict(zip(reparametrized_params, args))
         return jnp.sum(
             jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key)
         )