From 057cbf2b878a4457762ab602af9abe46100dbd5c Mon Sep 17 00:00:00 2001 From: Hessam Mehr Date: Wed, 20 Nov 2024 23:11:38 +0000 Subject: [PATCH] Post-process when no sample sites present. Current post-processing behaviour skips models with only deterministic variables. Applying this change will return consistent samples regardless of whether `sample` sites are present. --- numpyro/infer/mcmc.py | 4 +--- test/infer/test_mcmc.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 5cbf47015..8d0c403c4 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -195,9 +195,7 @@ def collect_and_postprocess(x): if collect_fields: fields = nested_attrgetter(*collect_fields)(x[0]) fields = [fields] if len(collect_fields) == 1 else list(fields) - site_values = jax.tree.flatten(fields[0])[0] - if len(site_values) > 0: - fields[0] = postprocess_fn(fields[0], *x[1:]) + fields[0] = postprocess_fn(fields[0], *x[1:]) if remove_sites != (): assert isinstance(fields[0], dict) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index c938781e8..c2019cbcd 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -1208,3 +1208,27 @@ def model(): mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0), extra_fields=("z.x",)) assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"])) + +def test_all_deterministic(): + def model1(): + numpyro.deterministic("x", 1.0) + + def model2(): + numpyro.deterministic("x", jnp.array([1.0, 2.0])) + + num_samples = 10 + shapes = {model1: (), model2: (2,)} + + for model, shape in shapes.items(): + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples) + mcmc.run(random.PRNGKey(0)) + assert mcmc.get_samples()["x"].shape == (num_samples,) + shape + +def test_empty_summary(): + def model(): + pass + + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0)) + + mcmc.print_summary()