Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify the composability of handlers with jax primities #1926

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
can be composed together or new ones added to enable implementation of custom
inference utilities and algorithms.

When a handler, such as `handlers.seed`, is applied to a model in NumPyro, e.g.,
`seeded_model = handlers.seed(model, rng_seed=0)`, it creates a callable object
with stateful attributes. These attributes can interfere with JAX primitives,
such as `jax.jit`, `jax.vmap`, and `jax.grad`. To ensure proper composition with
JAX primitives, handlers should be applied locally within the function or context
where the model is used, rather than globally. For example::

# Good: can be used in a jitted function
def seeded_model(data):
return handlers.seed(model, rng_seed=0)(data)

# Bad: might create tracer-leaks when used in a jitted function
seeded_model = handlers.seed(model, rng_seed=0)


**Example**

As an example, we are using :class:`~numpyro.handlers.seed`, :class:`~numpyro.handlers.trace`
Expand Down
6 changes: 6 additions & 0 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,12 @@ def __init__(
stacklevel=find_stack_level(),
)
self.chain_method = chain_method
if callable(chain_method) and (num_chains > 1) and progress_bar:
warnings.warn(
"Disabling progress bar as `chain_method` is a callable and `num_chains > 1`.",
stacklevel=find_stack_level(),
)
progress_bar = False
self.progress_bar = progress_bar
if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ:
self.progress_bar = False
Expand Down
7 changes: 6 additions & 1 deletion test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,12 @@ def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["x"])
mcmc = MCMC(
kernel, num_warmup=100, num_chains=2, num_samples=100, chain_method=vmap
kernel,
num_warmup=100,
num_chains=2,
num_samples=100,
chain_method=vmap,
progress_bar=False,
)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()
Expand Down
Loading