Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hm/all deterministic #1914

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def summary(

summary_dict = {}
for name, value in samples.items():
if len(value) == 0:
continue
value = device_get(value)
value_flat = np.reshape(value, (-1,) + value.shape[2:])
mean = value_flat.mean(axis=0)
Expand Down Expand Up @@ -307,6 +309,8 @@ def print_summary(
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
summary_dict = summary(samples, prob, group_by_chain=True)
if not summary_dict:
return

row_names = {
k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]"
Expand Down
20 changes: 16 additions & 4 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])
fields[0] = postprocess_fn(fields[0], *x[1:])

if remove_sites != ():
assert isinstance(fields[0], dict)
Expand Down Expand Up @@ -400,13 +398,27 @@ def _get_cached_fns(self):
fns, key = None, None
if fns is None:

def ensure_vmap(fn, batch_size=None):
def wrapper(x):
x_arrays = jax.tree.flatten(x)[0]
if len(x_arrays) > 0:
return vmap(fn)(x)
else:
assert batch_size is not None
return jax.tree.map(
lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)),
fn(x),
)

return wrapper

def _postprocess_fn(state, args, kwargs):
if self.postprocess_fn is None:
body_fn = self.sampler.postprocess_fn(args, kwargs)
else:
body_fn = self.postprocess_fn
if self.chain_method == "vectorized" and self.num_chains > 1:
body_fn = vmap(body_fn)
body_fn = ensure_vmap(body_fn, batch_size=self.num_chains)

return body_fn(state)

Expand Down
26 changes: 26 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,29 @@ def model():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))


def test_all_deterministic():
def model1():
numpyro.deterministic("x", 1.0)

def model2():
numpyro.deterministic("x", jnp.array([1.0, 2.0]))

num_samples = 10
shapes = {model1: (), model2: (2,)}

for model, shape in shapes.items():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape


def test_empty_summary():
def model():
pass

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))

mcmc.print_summary()
Loading