Skip to content

Commit

Permalink
Fix progbar description when warmup / sample are run separately (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fehiepsi committed Dec 4, 2019
1 parent e1a6fa1 commit db872b1
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,23 +754,26 @@ def _get_cached_init_state(self, rng_key, args, kwargs):
return None

def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
num_warmup = self.num_warmup
if init_state is None:
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
model_args=args, model_kwargs=kwargs)
if self.constrain_fn is None:
self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
collect_vals = fori_collect(self._collection_params["lower"],
self._collection_params["upper"],
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]

collect_vals = fori_collect(lower_idx,
upper_idx,
self._get_cached_fn(),
init_val,
transform=_collect_fn(collect_fields),
progbar=self.progress_bar,
return_last_val=True,
collection_size=self._collection_params["collection_size"],
progbar_desc=functools.partial(get_progbar_desc_str, num_warmup),
progbar_desc=functools.partial(get_progbar_desc_str,
num_warmup=lower_idx),
diagnostics_fn=diagnostics)
states, last_val = collect_vals
# Get first argument of type `HMCState`
Expand Down

0 comments on commit db872b1

Please sign in to comment.