diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py index 1ae30f469..527bd9b24 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -257,6 +257,8 @@ def summary( summary_dict = {} for name, value in samples.items(): + if len(value) == 0: + continue value = device_get(value) value_flat = np.reshape(value, (-1,) + value.shape[2:]) mean = value_flat.mean(axis=0) @@ -307,6 +309,8 @@ def print_summary( "Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0]) } summary_dict = summary(samples, prob, group_by_chain=True) + if not summary_dict: + return row_names = { k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]" diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 5cbf47015..dd3ec0c3a 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -195,9 +195,7 @@ def collect_and_postprocess(x): if collect_fields: fields = nested_attrgetter(*collect_fields)(x[0]) fields = [fields] if len(collect_fields) == 1 else list(fields) - site_values = jax.tree.flatten(fields[0])[0] - if len(site_values) > 0: - fields[0] = postprocess_fn(fields[0], *x[1:]) + fields[0] = postprocess_fn(fields[0], *x[1:]) if remove_sites != (): assert isinstance(fields[0], dict) @@ -400,13 +398,27 @@ def _get_cached_fns(self): fns, key = None, None if fns is None: + def ensure_vmap(fn, batch_size=None): + def wrapper(x): + x_arrays = jax.tree.flatten(x)[0] + if len(x_arrays) > 0: + return vmap(fn)(x) + else: + assert batch_size is not None + return jax.tree.map( + lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)), + fn(x), + ) + + return wrapper + def _postprocess_fn(state, args, kwargs): if self.postprocess_fn is None: body_fn = self.sampler.postprocess_fn(args, kwargs) else: body_fn = self.postprocess_fn if self.chain_method == "vectorized" and self.num_chains > 1: - body_fn = vmap(body_fn) + body_fn = ensure_vmap(body_fn, batch_size=self.num_chains) return body_fn(state) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index c938781e8..d5cfef4f7 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -1208,3 +1208,29 @@ def model(): mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0), extra_fields=("z.x",)) assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"])) + + +def test_all_deterministic(): + def model1(): + numpyro.deterministic("x", 1.0) + + def model2(): + numpyro.deterministic("x", jnp.array([1.0, 2.0])) + + num_samples = 10 + shapes = {model1: (), model2: (2,)} + + for model, shape in shapes.items(): + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples) + mcmc.run(random.PRNGKey(0)) + assert mcmc.get_samples()["x"].shape == (num_samples,) + shape + + +def test_empty_summary(): + def model(): + pass + + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0)) + + mcmc.print_summary()