Skip to content

Commit

Permalink
Make replay consistent with Pyro (#1345)
Browse files Browse the repository at this point in the history
* make replay consistent with numpyro

* Add docs for apply_stack to clarify its functionality
  • Loading branch information
fehiepsi authored Feb 27, 2022
1 parent f9c756c commit 0c0669d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 16 deletions.
4 changes: 2 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,8 @@ def has_enumerate_support(self):
return self.base_dist.has_enumerate_support

@property
def reparameterized_params(self):
return self.base_dist.reparameterized_params
def reparametrized_params(self):
return self.base_dist.reparametrized_params

@property
def mean(self):
Expand Down
17 changes: 15 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,27 @@ class replay(Messenger):
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
"""

def __init__(self, fn=None, trace=None, guide_trace=None):
def __init__(self, fn=None, trace=None):
assert trace is not None
self.trace = trace
super(replay, self).__init__(fn)

def process_message(self, msg):
if msg["type"] in ("sample", "plate") and msg["name"] in self.trace:
msg["value"] = self.trace[msg["name"]]["value"]
name = msg["name"]
if msg["type"] in ("sample", "plate") and name in self.trace:
guide_msg = self.trace[name]
if msg["type"] == "plate":
if guide_msg["type"] != "plate":
raise RuntimeError(f"Site {name} must be a plate in trace.")
msg["value"] = guide_msg["value"]
return None
if msg["is_observed"]:
return None
if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
raise RuntimeError(f"Site {name} must be sampled in trace.")
msg["value"] = guide_msg["value"]
msg["infer"] = guide_msg["infer"]


class block(Messenger):
Expand Down
32 changes: 25 additions & 7 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,39 @@
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])


def default_process_message(msg):
if msg["value"] is None:
if msg["type"] == "sample":
msg["value"], msg["intermediates"] = msg["fn"](
*msg["args"], sample_intermediates=True, **msg["kwargs"]
)
else:
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])


def apply_stack(msg):
"""
Execute the effect stack at a single site according to the following scheme:
1. For each ``Messenger`` in the stack from bottom to top,
execute ``Messenger.process_message`` with the message;
if the message field "stop" is True, stop;
otherwise, continue
2. Apply default behavior (``default_process_message``) to finish remaining
site execution
3. For each ``Messenger`` in the stack from top to bottom,
execute ``Messenger.postprocess_message`` to update the message
and internal messenger state with the site results
"""
pointer = 0
for pointer, handler in enumerate(reversed(_PYRO_STACK)):
handler.process_message(msg)
# When a Messenger sets the "stop" field of a message,
# it prevents any Messengers above it on the stack from being applied.
if msg.get("stop"):
break
if msg["value"] is None:
if msg["type"] == "sample":
msg["value"], msg["intermediates"] = msg["fn"](
*msg["args"], sample_intermediates=True, **msg["kwargs"]
)
else:
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

default_process_message(msg)

# A Messenger that sets msg["stop"] == True also prevents application
# of postprocess_message by Messengers above it on the stack
Expand Down
10 changes: 5 additions & 5 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,23 +894,23 @@ def test_sample_gradient(jax_dist, sp_dist, params):
params_dict = dict(zip(dist_args[: len(params)], params))

jax_class = type(jax_dist(**params_dict))
reparameterized_params = [
reparametrized_params = [
p for p in jax_class.reparametrized_params if p not in gamma_derived_params
]
if not reparameterized_params:
if not reparametrized_params:
pytest.skip("{} not reparametrized.".format(jax_class.__name__))

nonrepara_params_dict = {
k: v for k, v in params_dict.items() if k not in reparameterized_params
k: v for k, v in params_dict.items() if k not in reparametrized_params
}
repara_params = tuple(
v for k, v in params_dict.items() if k in reparameterized_params
v for k, v in params_dict.items() if k in reparametrized_params
)

rng_key = random.PRNGKey(0)

def fn(args):
args_dict = dict(zip(reparameterized_params, args))
args_dict = dict(zip(reparametrized_params, args))
return jnp.sum(
jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key)
)
Expand Down

0 comments on commit 0c0669d

Please sign in to comment.