diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 11fdde45c..4b97528a7 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -10,6 +10,21 @@ can be composed together or new ones added to enable implementation of custom inference utilities and algorithms. +When a handler, such as `handlers.seed`, is applied to a model in NumPyro, e.g., +`seeded_model = handlers.seed(model, rng_seed=0)`, it creates a callable object +with stateful attributes. These attributes can interfere with JAX primitives, +such as `jax.jit`, `jax.vmap`, and `jax.grad`. To ensure proper composition with +JAX primitives, handlers should be applied locally within the function or context +where the model is used, rather than globally. For example:: + + # Good: can be used in a jitted function + def seeded_model(data): + return handlers.seed(model, rng_seed=0)(data) + + # Bad: might create tracer-leaks when used in a jitted function + seeded_model = handlers.seed(model, rng_seed=0) + + **Example** As an example, we are using :class:`~numpyro.handlers.seed`, :class:`~numpyro.handlers.trace` diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 7dd596bca..af49f0d99 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -365,6 +365,12 @@ def __init__( stacklevel=find_stack_level(), ) self.chain_method = chain_method + if callable(chain_method) and (num_chains > 1) and progress_bar: + warnings.warn( + "Disabling progress bar as `chain_method` is a callable and `num_chains > 1`.", + stacklevel=find_stack_level(), + ) + progress_bar = False self.progress_bar = progress_bar if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ: self.progress_bar = False diff --git a/test/infer/test_hmc_gibbs.py b/test/infer/test_hmc_gibbs.py index 5caef6f38..c4195da68 100644 --- a/test/infer/test_hmc_gibbs.py +++ b/test/infer/test_hmc_gibbs.py @@ -478,7 +478,12 @@ def gibbs_fn(rng_key, gibbs_sites, hmc_sites): hmc_kernel = NUTS(model) kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["x"]) mcmc = MCMC( - kernel, num_warmup=100, num_chains=2, num_samples=100, chain_method=vmap + kernel, + num_warmup=100, + num_chains=2, + num_samples=100, + chain_method=vmap, + progress_bar=False, ) mcmc.run(random.PRNGKey(0)) samples = mcmc.get_samples()