Skip to content

Commit

Permalink
BF: fix ensemble mcmc run after warmup (#1918)
Browse files Browse the repository at this point in the history
  • Loading branch information
amifalk authored Nov 29, 2024
1 parent 07e4c9b commit 5692d2d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
6 changes: 5 additions & 1 deletion numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,11 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):

if self._warmup_state is not None:
self._set_collection_params(0, self.num_samples, self.num_samples, "sample")
init_state = self._warmup_state._replace(rng_key=rng_key)

if self.sampler.is_ensemble_kernel:
init_state = self._warmup_state._replace(rng_key=rng_key[0])
else:
init_state = self._warmup_state._replace(rng_key=rng_key)

if init_params is not None and self.num_chains > 1:
prototype_init_val = jax.tree.flatten(init_params)[0][0]
Expand Down
17 changes: 17 additions & 0 deletions test/infer/test_ensemble_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,20 @@ def test_multirun(kernel_cls):
)
mcmc.run(random.PRNGKey(2), labels)
mcmc.run(random.PRNGKey(3), labels)


@pytest.mark.parametrize("kernel_cls", [AIES, ESS])
def test_warmup(kernel_cls):
n_chains = 10
kernel = kernel_cls(model)

mcmc = MCMC(
kernel,
num_warmup=10,
num_samples=10,
progress_bar=False,
num_chains=n_chains,
chain_method="vectorized",
)
mcmc.warmup(random.PRNGKey(2), labels)
mcmc.run(random.PRNGKey(3), labels)

0 comments on commit 5692d2d

Please sign in to comment.