Skip to content

Commit

Permalink
Clarify the composability of handlers with jax primities (#1926)
Browse files Browse the repository at this point in the history
* Clarify the composability of handlers with jax primities

* disable progress bar for callable chain_method tests
  • Loading branch information
fehiepsi authored Dec 4, 2024
1 parent 2ff50d8 commit d66d27c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
15 changes: 15 additions & 0 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
can be composed together or new ones added to enable implementation of custom
inference utilities and algorithms.
When a handler, such as `handlers.seed`, is applied to a model in NumPyro, e.g.,
`seeded_model = handlers.seed(model, rng_seed=0)`, it creates a callable object
with stateful attributes. These attributes can interfere with JAX primitives,
such as `jax.jit`, `jax.vmap`, and `jax.grad`. To ensure proper composition with
JAX primitives, handlers should be applied locally within the function or context
where the model is used, rather than globally. For example::
# Good: can be used in a jitted function
def seeded_model(data):
return handlers.seed(model, rng_seed=0)(data)
# Bad: might create tracer-leaks when used in a jitted function
seeded_model = handlers.seed(model, rng_seed=0)
**Example**
As an example, we are using :class:`~numpyro.handlers.seed`, :class:`~numpyro.handlers.trace`
Expand Down
6 changes: 6 additions & 0 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,12 @@ def __init__(
stacklevel=find_stack_level(),
)
self.chain_method = chain_method
if callable(chain_method) and (num_chains > 1) and progress_bar:
warnings.warn(
"Disabling progress bar as `chain_method` is a callable and `num_chains > 1`.",
stacklevel=find_stack_level(),
)
progress_bar = False
self.progress_bar = progress_bar
if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ:
self.progress_bar = False
Expand Down
7 changes: 6 additions & 1 deletion test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,12 @@ def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["x"])
mcmc = MCMC(
kernel, num_warmup=100, num_chains=2, num_samples=100, chain_method=vmap
kernel,
num_warmup=100,
num_chains=2,
num_samples=100,
chain_method=vmap,
progress_bar=False,
)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()
Expand Down

0 comments on commit d66d27c

Please sign in to comment.