diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index d27310931..5cbf47015 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -189,27 +189,30 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs): return (sampler.sample(state[0], args, kwargs),) -def _collect_fn(collect_fields, remove_sites): - @cached_by(_collect_fn, collect_fields, remove_sites) - def collect(x): +def _collect_and_postprocess(postprocess_fn, collect_fields, remove_sites): + @cached_by(_collect_and_postprocess, postprocess_fn, collect_fields, remove_sites) + def collect_and_postprocess(x): if collect_fields: fields = nested_attrgetter(*collect_fields)(x[0]) + fields = [fields] if len(collect_fields) == 1 else list(fields) + site_values = jax.tree.flatten(fields[0])[0] + if len(site_values) > 0: + fields[0] = postprocess_fn(fields[0], *x[1:]) if remove_sites != (): - fields = [fields] if len(collect_fields) == 1 else list(fields) assert isinstance(fields[0], dict) sample_sites = fields[0].copy() for site in remove_sites: sample_sites.pop(site) fields[0] = sample_sites - fields = fields[0] if len(collect_fields) == 1 else fields + fields = fields[0] if len(collect_fields) == 1 else fields return fields else: return x[0] - return collect + return collect_and_postprocess # XXX: Is there a better hash key that we can use? @@ -397,7 +400,7 @@ def _get_cached_fns(self): fns, key = None, None if fns is None: - def laxmap_postprocess_fn(states, args, kwargs): + def _postprocess_fn(state, args, kwargs): if self.postprocess_fn is None: body_fn = self.sampler.postprocess_fn(args, kwargs) else: @@ -405,11 +408,11 @@ def laxmap_postprocess_fn(states, args, kwargs): if self.chain_method == "vectorized" and self.num_chains > 1: body_fn = vmap(body_fn) - return lax.map(body_fn, states) + return body_fn(state) if self._jit_model_args: sample_fn = partial(_sample_fn_jit_args, sampler=self.sampler) - postprocess_fn = jit(laxmap_postprocess_fn) + postprocess_fn = _postprocess_fn else: sample_fn = partial( _sample_fn_nojit_args, @@ -417,8 +420,8 @@ def laxmap_postprocess_fn(states, args, kwargs): args=self._args, kwargs=self._kwargs, ) - postprocess_fn = jit( - partial(laxmap_postprocess_fn, args=self._args, kwargs=self._kwargs) + postprocess_fn = partial( + _postprocess_fn, args=self._args, kwargs=self._kwargs ) fns = sample_fn, postprocess_fn @@ -470,7 +473,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites): upper_idx, sample_fn, init_val, - transform=_collect_fn(collect_fields, remove_sites), + transform=_collect_and_postprocess( + postprocess_fn, collect_fields, remove_sites + ), progbar=self.progress_bar, return_last_val=True, thinning=self.thinning, @@ -487,18 +492,6 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites): if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) - # Apply constraints if number of samples is non-zero - site_values = jax.tree.flatten(states[self._sample_field])[0] - # XXX: lax.map still works if some arrays have 0 size - # so we only need to filter out the case site_value.shape[0] == 0 - # (which happens when lower_idx==upper_idx) - if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0: - if self._jit_model_args: - states[self._sample_field] = postprocess_fn( - states[self._sample_field], args, kwargs - ) - else: - states[self._sample_field] = postprocess_fn(states[self._sample_field]) return states, last_state def _set_collection_params( diff --git a/test/test_compile.py b/test/test_compile.py index 545ecc5db..0abd233c4 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -36,7 +36,10 @@ def test_mcmc_one_chain(deterministic, find_heuristic_step_size): num_traces_for_heuristic = 2 if find_heuristic_step_size else 0 if deterministic: - assert GLOBAL["count"] == 4 + num_traces_for_heuristic + # We have two extra calls to the model to get deterministic values: + # 1. transform the init state + # 2. transform state during the loop + assert GLOBAL["count"] == 5 + num_traces_for_heuristic else: assert GLOBAL["count"] == 3 + num_traces_for_heuristic