diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index dd3ec0c3a..7dd596bca 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -657,7 +657,11 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): if self._warmup_state is not None: self._set_collection_params(0, self.num_samples, self.num_samples, "sample") - init_state = self._warmup_state._replace(rng_key=rng_key) + + if self.sampler.is_ensemble_kernel: + init_state = self._warmup_state._replace(rng_key=rng_key[0]) + else: + init_state = self._warmup_state._replace(rng_key=rng_key) if init_params is not None and self.num_chains > 1: prototype_init_val = jax.tree.flatten(init_params)[0][0] diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index a49e6a63f..bda377b99 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -96,3 +96,20 @@ def test_multirun(kernel_cls): ) mcmc.run(random.PRNGKey(2), labels) mcmc.run(random.PRNGKey(3), labels) + + +@pytest.mark.parametrize("kernel_cls", [AIES, ESS]) +def test_warmup(kernel_cls): + n_chains = 10 + kernel = kernel_cls(model) + + mcmc = MCMC( + kernel, + num_warmup=10, + num_samples=10, + progress_bar=False, + num_chains=n_chains, + chain_method="vectorized", + ) + mcmc.warmup(random.PRNGKey(2), labels) + mcmc.run(random.PRNGKey(3), labels)