From db872b1e93d3626a527cb19f8c00dcf014b77146 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 3 Dec 2019 17:28:12 -0800 Subject: [PATCH] Fix progbar description when warmup / sample are run separately (#486) --- numpyro/infer/mcmc.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8bf38911b..02549a43d 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -754,7 +754,6 @@ def _get_cached_init_state(self, rng_key, args, kwargs): return None def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)): - num_warmup = self.num_warmup if init_state is None: init_state = self.sampler.init(rng_key, self.num_warmup, init_params, model_args=args, model_kwargs=kwargs) @@ -762,15 +761,19 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col self.constrain_fn = self.sampler.constrain_fn(args, kwargs) diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) - collect_vals = fori_collect(self._collection_params["lower"], - self._collection_params["upper"], + lower_idx = self._collection_params["lower"] + upper_idx = self._collection_params["upper"] + + collect_vals = fori_collect(lower_idx, + upper_idx, self._get_cached_fn(), init_val, transform=_collect_fn(collect_fields), progbar=self.progress_bar, return_last_val=True, collection_size=self._collection_params["collection_size"], - progbar_desc=functools.partial(get_progbar_desc_str, num_warmup), + progbar_desc=functools.partial(get_progbar_desc_str, + num_warmup=lower_idx), diagnostics_fn=diagnostics) states, last_val = collect_vals # Get first argument of type `HMCState`