From f704dc4348aa8bcc971cd6dc6734c344b0081de6 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 27 Nov 2024 17:37:52 -0500 Subject: [PATCH 1/2] Bump to 0.16.0 --- examples/annotation.py | 2 +- examples/ar2.py | 2 +- examples/baseball.py | 2 +- examples/bnn.py | 2 +- examples/covtype.py | 2 +- examples/funnel.py | 2 +- examples/gaussian_shells.py | 2 +- examples/gp.py | 2 +- examples/hmm.py | 2 +- examples/holt_winters.py | 2 +- examples/horseshoe_regression.py | 2 +- examples/minipyro.py | 2 +- examples/mortality.py | 2 +- examples/neutra.py | 2 +- examples/ode.py | 2 +- examples/prodlda.py | 2 +- examples/proportion_test.py | 2 +- examples/sparse_regression.py | 2 +- examples/stochastic_volatility.py | 2 +- examples/thompson_sampling.py | 2 +- examples/toy_mixture_model_discrete_enumeration.py | 2 +- examples/ucbadmit.py | 2 +- examples/vae.py | 2 +- notebooks/source/bad_posterior_geometry.ipynb | 2 +- notebooks/source/bayesian_hierarchical_linear_regression.ipynb | 2 +- notebooks/source/bayesian_hierarchical_stacking.ipynb | 2 +- notebooks/source/bayesian_imputation.ipynb | 2 +- notebooks/source/bayesian_regression.ipynb | 2 +- notebooks/source/censoring.ipynb | 2 +- notebooks/source/gmm.ipynb | 2 +- notebooks/source/hsgp_example.ipynb | 2 +- notebooks/source/logistic_regression.ipynb | 2 +- notebooks/source/model_rendering.ipynb | 2 +- notebooks/source/ordinal_regression.ipynb | 2 +- notebooks/source/other_samplers.ipynb | 2 +- notebooks/source/time_series_forecasting.ipynb | 2 +- numpyro/version.py | 2 +- 37 files changed, 37 insertions(+), 37 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index bcf393925..76836063e 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -320,7 +320,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ar2.py b/examples/ar2.py index a04bad396..103bc5a62 100644 --- a/examples/ar2.py +++ b/examples/ar2.py @@ -138,7 +138,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="AR2 example") parser.add_argument("--num-data", nargs="?", default=142, type=int) parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index 16149df15..6c6615f6e 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -210,7 +210,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index 806693052..a290683f3 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -160,7 +160,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index 184a8c80a..d7a488e90 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -206,7 +206,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-samples", default=1000, type=int, help="number of samples" diff --git a/examples/funnel.py b/examples/funnel.py index 0aa5d5ebb..707b44614 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -139,7 +139,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser( description="Non-centered reparameterization example" ) diff --git a/examples/gaussian_shells.py b/examples/gaussian_shells.py index 1cf0001cf..30a4dcc71 100644 --- a/examples/gaussian_shells.py +++ b/examples/gaussian_shells.py @@ -120,7 +120,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells") parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int) diff --git a/examples/gp.py b/examples/gp.py index 8c9955a08..3f26ad2ae 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -180,7 +180,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index f6deb3122..f60413271 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -263,7 +263,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model") parser.add_argument("--num-categories", default=3, type=int) parser.add_argument("--num-words", default=10, type=int) diff --git a/examples/holt_winters.py b/examples/holt_winters.py index 8545473a7..8ab1a3aa7 100644 --- a/examples/holt_winters.py +++ b/examples/holt_winters.py @@ -180,7 +180,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Holt-Winters") parser.add_argument("--T", nargs="?", default=6, type=int) parser.add_argument("--future", nargs="?", default=1, type=int) diff --git a/examples/horseshoe_regression.py b/examples/horseshoe_regression.py index b92ef1998..608a6b5ca 100644 --- a/examples/horseshoe_regression.py +++ b/examples/horseshoe_regression.py @@ -162,7 +162,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Horseshoe regression example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index 1c969c86a..613aee406 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/mortality.py b/examples/mortality.py index 37fdff356..3bb460b31 100644 --- a/examples/mortality.py +++ b/examples/mortality.py @@ -220,7 +220,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Mortality regression model") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index 72f5ea602..6a9557b4a 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -197,7 +197,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index 3e6cdadc5..6655c32ed 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -117,7 +117,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Predator-Prey Model") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/prodlda.py b/examples/prodlda.py index 2195d3fed..aaec56197 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -315,7 +315,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser( description="Probabilistic topic modelling with Flax and Haiku" ) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index 8dfd02877..494f0458b 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -158,7 +158,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Testing whether ") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 8fd00a0e3..9cfc40eae 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -384,7 +384,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index 288489484..0338e9fea 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -122,7 +122,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int) parser.add_argument("--num-warmup", nargs="?", default=600, type=int) diff --git a/examples/thompson_sampling.py b/examples/thompson_sampling.py index 65e35f2d8..1d3d8b646 100644 --- a/examples/thompson_sampling.py +++ b/examples/thompson_sampling.py @@ -292,7 +292,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Thompson sampling example") parser.add_argument( "--num-random", nargs="?", default=2, type=int, help="number of random draws" diff --git a/examples/toy_mixture_model_discrete_enumeration.py b/examples/toy_mixture_model_discrete_enumeration.py index e3b0a12d1..7f30662b7 100644 --- a/examples/toy_mixture_model_discrete_enumeration.py +++ b/examples/toy_mixture_model_discrete_enumeration.py @@ -126,7 +126,7 @@ def get_true_pred_CPDs(CPD, posterior_param): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="Toy mixture model") parser.add_argument("-n", "--num-steps", default=4000, type=int) parser.add_argument("-o", "--num-obs", default=10000, type=int) diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index 8846678af..6f5af22ed 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -151,7 +151,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser( description="UCBadmit gender discrimination using HMC" ) diff --git a/examples/vae.py b/examples/vae.py index b73d26ec7..2f762704b 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.15.3") + assert numpyro.__version__.startswith("0.16.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-epochs", default=15, type=int, help="number of training epochs" diff --git a/notebooks/source/bad_posterior_geometry.ipynb b/notebooks/source/bad_posterior_geometry.ipynb index 8f3adf799..fd468e2f9 100644 --- a/notebooks/source/bad_posterior_geometry.ipynb +++ b/notebooks/source/bad_posterior_geometry.ipynb @@ -49,7 +49,7 @@ "import numpyro.distributions as dist\n", "from numpyro.infer import MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")\n", + "assert numpyro.__version__.startswith(\"0.16.0\")\n", "\n", "# NB: replace cpu by gpu to run this notebook on gpu\n", "numpyro.set_platform(\"cpu\")" diff --git a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb index 0acf4ae76..3f0d7cf4b 100644 --- a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb +++ b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb @@ -246,7 +246,7 @@ "import numpyro.distributions as dist\n", "from numpyro.infer import MCMC, NUTS, Predictive\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/bayesian_hierarchical_stacking.ipynb b/notebooks/source/bayesian_hierarchical_stacking.ipynb index 4589ca8d6..5287700b8 100644 --- a/notebooks/source/bayesian_hierarchical_stacking.ipynb +++ b/notebooks/source/bayesian_hierarchical_stacking.ipynb @@ -97,7 +97,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/bayesian_imputation.ipynb b/notebooks/source/bayesian_imputation.ipynb index e0a6044cd..e766a1037 100644 --- a/notebooks/source/bayesian_imputation.ipynb +++ b/notebooks/source/bayesian_imputation.ipynb @@ -52,7 +52,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 2c2f8878c..6e9b9e553 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -91,7 +91,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/censoring.ipynb b/notebooks/source/censoring.ipynb index 5da43483b..41f7c9430 100644 --- a/notebooks/source/censoring.ipynb +++ b/notebooks/source/censoring.ipynb @@ -60,7 +60,7 @@ "\n", "rng_key = random.PRNGKey(seed=0)\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")\n", + "assert numpyro.__version__.startswith(\"0.16.0\")\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", diff --git a/notebooks/source/gmm.ipynb b/notebooks/source/gmm.ipynb index 28c4b347a..6a1ee5c12 100644 --- a/notebooks/source/gmm.ipynb +++ b/notebooks/source/gmm.ipynb @@ -54,7 +54,7 @@ "%matplotlib inline\n", "\n", "smoke_test = \"CI\" in os.environ\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/hsgp_example.ipynb b/notebooks/source/hsgp_example.ipynb index d40804145..557a5354a 100644 --- a/notebooks/source/hsgp_example.ipynb +++ b/notebooks/source/hsgp_example.ipynb @@ -62,7 +62,7 @@ "\n", "rng_key = random.PRNGKey(seed=42)\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")\n", + "assert numpyro.__version__.startswith(\"0.16.0\")\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index 66f1508dd..137e471af 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -41,7 +41,7 @@ "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")\n", + "assert numpyro.__version__.startswith(\"0.16.0\")\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/model_rendering.ipynb b/notebooks/source/model_rendering.ipynb index cfba99f3e..454b432c1 100644 --- a/notebooks/source/model_rendering.ipynb +++ b/notebooks/source/model_rendering.ipynb @@ -38,7 +38,7 @@ "import numpyro.distributions as dist\n", "import numpyro.distributions.constraints as constraints\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index 7ed20baf9..6d48e2db8 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -54,7 +54,7 @@ "from numpyro.infer import MCMC, NUTS\n", "from numpyro.infer.reparam import TransformReparam\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/notebooks/source/other_samplers.ipynb b/notebooks/source/other_samplers.ipynb index 03316ccfb..a3e7c5a12 100644 --- a/notebooks/source/other_samplers.ipynb +++ b/notebooks/source/other_samplers.ipynb @@ -67,7 +67,7 @@ "\n", "rng_key = random.PRNGKey(seed=42)\n", "\n", - "assert numpyro.__version__.startswith(\"0.15.3\")\n", + "assert numpyro.__version__.startswith(\"0.16.0\")\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index c09d8b269..1b12b09da 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -48,7 +48,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.15.3\")" + "assert numpyro.__version__.startswith(\"0.16.0\")" ] }, { diff --git a/numpyro/version.py b/numpyro/version.py index 601b211d8..0799da2cd 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.15.3" +__version__ = "0.16.0" From 48cb2b13b5453dba517ed7bd21bfde13af68d4c8 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 27 Nov 2024 17:44:09 -0500 Subject: [PATCH 2/2] update lint with ruff 0.8.0 --- numpyro/contrib/control_flow/scan.py | 7 ++++--- numpyro/infer/autoguide.py | 15 ++++++++++----- numpyro/infer/elbo.py | 7 +++++-- numpyro/infer/inspect.py | 11 +++++++---- test/infer/test_hmc_util.py | 5 +++-- test/test_distributions.py | 8 ++++++-- test/test_handlers.py | 6 ++++-- 7 files changed, 39 insertions(+), 20 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index b5069e6e3..893814f0f 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -198,9 +198,10 @@ def body_fn(wrapped_carry, x, prefix=None): ) return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) - with handlers.block( - hide_fn=lambda site: not site["name"].startswith("_PREV_") - ), enum(first_available_dim=first_available_dim): + with ( + handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), + enum(first_available_dim=first_available_dim), + ): wrapped_carry = (0, rng_key, init) y0s = [] # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan` diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 72b2df3bf..6df31055a 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs): if self.global_guide is not None: global_latents = self.global_guide(*args, **kwargs) rng_key = numpyro.prng_key() - with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute( - data=global_latents + with ( + handlers.block(), + handlers.seed(rng_seed=rng_key), + handlers.substitute(data=global_latents), ): global_outputs = self.global_guide.model(*args, **kwargs) local_args = (global_outputs,) @@ -1575,9 +1577,12 @@ def fn(x): if self.local_guide is not None: key = numpyro.prng_key() subsample_guide = partial(_subsample_model, self.local_guide) - with handlers.block(), handlers.trace() as tr, handlers.seed( - rng_seed=key - ), handlers.substitute(data=local_guide_params): + with ( + handlers.block(), + handlers.trace() as tr, + handlers.seed(rng_seed=key), + handlers.substitute(data=local_guide_params), + ): with warnings.catch_warnings(): warnings.simplefilter("ignore") subsample_guide(*local_args, **local_kwargs) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index a255fbb7f..f3d96d91d 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -893,8 +893,11 @@ def get_importance_trace_enum( trace as _trace, ) - with plate_to_enum_plate(), enum( - first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None + with ( + plate_to_enum_plate(), + enum( + first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None + ), ): guide = substitute(guide, data=params) with _without_rsample_stop_gradient(): diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 7b999ede5..1e4b52b75 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -58,8 +58,10 @@ def get_trace(): def _get_log_probs(model, model_args, model_kwargs, **sample): # Note: We use seed 0 for parameter initialization. - with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute( - data=sample + with ( + handlers.trace() as tr, + handlers.seed(rng_seed=0), + handlers.substitute(data=sample), ): model(*model_args, **model_kwargs) return { @@ -370,8 +372,9 @@ def process_message(self, msg): # Note: We use seed 0 for parameter initialization. with handlers.trace() as tr, handlers.seed(rng_seed=0): - with handlers.substitute(data=sample), substitute_deterministic( - data=sample + with ( + handlers.substitute(data=sample), + substitute_deterministic(data=sample), ): model(*model_args, **model_kwargs) provenance_arrays = {} diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index f1a81e436..080c6a6f4 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -57,8 +57,9 @@ def optimize(f): @pytest.mark.parametrize("regularize", [True, False]) @pytest.mark.filterwarnings("ignore:numpy.linalg support is experimental:UserWarning") def test_welford_covariance(jitted, diagonal, regularize): - with optional(jitted, disable_jit()), optional( - jitted, control_flow_prims_disabled() + with ( + optional(jitted, disable_jit()), + optional(jitted, control_flow_prims_disabled()), ): np.random.seed(0) loc = np.random.randn(3) diff --git a/test/test_distributions.py b/test/test_distributions.py index ca6cf6fa1..ff7037a92 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2237,8 +2237,12 @@ def g(x): def test_beta_proportion_invalid_mean(): - with dist.distribution.validation_enabled(), pytest.raises( - ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$" + with ( + dist.distribution.validation_enabled(), + pytest.raises( + ValueError, + match=r"^BetaProportion distribution got invalid mean parameter\.$", + ), ): dist.BetaProportion(1.0, 1.0) diff --git a/test/test_handlers.py b/test/test_handlers.py index 4ef449237..15121eb46 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -372,8 +372,10 @@ def test_subsample_substitute(): data = jnp.arange(100.0) subsample_size = 7 subsample = jnp.array([13, 3, 30, 4, 1, 68, 5]) - with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute( - data={"a": subsample} + with ( + handlers.trace() as tr, + handlers.seed(rng_seed=0), + handlers.substitute(data={"a": subsample}), ): with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: assert data[idx].shape == (subsample_size,)