Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge postprocess_fn into the fori_collect loop #1910

Merged
merged 4 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,30 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs):
return (sampler.sample(state[0], args, kwargs),)


def _collect_fn(collect_fields, remove_sites):
@cached_by(_collect_fn, collect_fields, remove_sites)
def collect(x):
def _collect_and_postprocess(postprocess_fn, collect_fields, remove_sites):
@cached_by(_collect_and_postprocess, postprocess_fn, collect_fields, remove_sites)
def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify what is stored in fields[1:] when len(collect_fields) > 1, given that we only need to run postprocess_fn on the first element?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first one is the sample field. The rest are extra fields.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, then it makes sense that constraints are only applied to the first field.


if remove_sites != ():
fields = [fields] if len(collect_fields) == 1 else list(fields)
assert isinstance(fields[0], dict)

sample_sites = fields[0].copy()
for site in remove_sites:
sample_sites.pop(site)
fields[0] = sample_sites
fields = fields[0] if len(collect_fields) == 1 else fields

fields = fields[0] if len(collect_fields) == 1 else fields
return fields
else:
return x[0]

return collect
return collect_and_postprocess


# XXX: Is there a better hash key that we can use?
Expand Down Expand Up @@ -397,28 +400,28 @@ def _get_cached_fns(self):
fns, key = None, None
if fns is None:

def laxmap_postprocess_fn(states, args, kwargs):
def _postprocess_fn(state, args, kwargs):
if self.postprocess_fn is None:
body_fn = self.sampler.postprocess_fn(args, kwargs)
else:
body_fn = self.postprocess_fn
if self.chain_method == "vectorized" and self.num_chains > 1:
body_fn = vmap(body_fn)

return lax.map(body_fn, states)
return body_fn(state)

if self._jit_model_args:
sample_fn = partial(_sample_fn_jit_args, sampler=self.sampler)
postprocess_fn = jit(laxmap_postprocess_fn)
postprocess_fn = _postprocess_fn
else:
sample_fn = partial(
_sample_fn_nojit_args,
sampler=self.sampler,
args=self._args,
kwargs=self._kwargs,
)
postprocess_fn = jit(
partial(laxmap_postprocess_fn, args=self._args, kwargs=self._kwargs)
postprocess_fn = partial(
_postprocess_fn, args=self._args, kwargs=self._kwargs
)

fns = sample_fn, postprocess_fn
Expand Down Expand Up @@ -470,7 +473,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
upper_idx,
sample_fn,
init_val,
transform=_collect_fn(collect_fields, remove_sites),
transform=_collect_and_postprocess(
postprocess_fn, collect_fields, remove_sites
),
progbar=self.progress_bar,
return_last_val=True,
thinning=self.thinning,
Expand All @@ -487,18 +492,6 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
if len(collect_fields) == 1:
states = (states,)
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = jax.tree.flatten(states[self._sample_field])[0]
# XXX: lax.map still works if some arrays have 0 size
# so we only need to filter out the case site_value.shape[0] == 0
# (which happens when lower_idx==upper_idx)
if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
if self._jit_model_args:
states[self._sample_field] = postprocess_fn(
states[self._sample_field], args, kwargs
)
else:
states[self._sample_field] = postprocess_fn(states[self._sample_field])
return states, last_state

def _set_collection_params(
Expand Down
5 changes: 4 additions & 1 deletion test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def test_mcmc_one_chain(deterministic, find_heuristic_step_size):

num_traces_for_heuristic = 2 if find_heuristic_step_size else 0
if deterministic:
assert GLOBAL["count"] == 4 + num_traces_for_heuristic
# We have two extra calls to the model to get deterministic values:
# 1. transform the init state
# 2. transform state during the loop
assert GLOBAL["count"] == 5 + num_traces_for_heuristic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an assertion that x_copy is present in the samples returned by get_samples? Perhaps test_mcmc.py would be a better place for the suggested test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is already tested at:

assert samples["logits"].shape == (num_samples, N)

else:
assert GLOBAL["count"] == 3 + num_traces_for_heuristic

Expand Down
Loading