Skip to content

Commit

Permalink
Bump to 0.16.0 (#1917)
Browse files Browse the repository at this point in the history
* Bump to 0.16.0
  • Loading branch information
fehiepsi authored Nov 28, 2024
1 parent f87f40e commit 07e4c9b
Show file tree
Hide file tree
Showing 44 changed files with 76 additions and 57 deletions.
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="AR2 example")
parser.add_argument("--num-data", nargs="?", default=142, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-samples", default=1000, type=int, help="number of samples"
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="Non-centered reparameterization example"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/gaussian_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")

parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model")
parser.add_argument("--num-categories", default=3, type=int)
parser.add_argument("--num-words", default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Holt-Winters")
parser.add_argument("--T", nargs="?", default=6, type=int)
parser.add_argument("--future", nargs="?", default=1, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/horseshoe_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Horseshoe regression example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")

parser = argparse.ArgumentParser(description="Mortality regression model")
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Predator-Prey Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/prodlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="Probabilistic topic modelling with Flax and Haiku"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Testing whether ")
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int)
parser.add_argument("--num-warmup", nargs="?", default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/thompson_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Thompson sampling example")
parser.add_argument(
"--num-random", nargs="?", default=2, type=int, help="number of random draws"
Expand Down
2 changes: 1 addition & 1 deletion examples/toy_mixture_model_discrete_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_true_pred_CPDs(CPD, posterior_param):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Toy mixture model")
parser.add_argument("-n", "--num-steps", default=4000, type=int)
parser.add_argument("-o", "--num-obs", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="UCBadmit gender discrimination using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-epochs", default=15, type=int, help="number of training epochs"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bad_posterior_geometry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"# NB: replace cpu by gpu to run this notebook on gpu\n",
"numpyro.set_platform(\"cpu\")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_hierarchical_stacking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
" set_matplotlib_formats(\"svg\")\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats(\"svg\")\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats(\"svg\")\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/censoring.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"\n",
"rng_key = random.PRNGKey(seed=0)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/gmm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"%matplotlib inline\n",
"\n",
"smoke_test = \"CI\" in os.environ\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/hsgp_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"\n",
"rng_key = random.PRNGKey(seed=42)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/model_rendering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"import numpyro.distributions as dist\n",
"import numpyro.distributions.constraints as constraints\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"from numpyro.infer import MCMC, NUTS\n",
"from numpyro.infer.reparam import TransformReparam\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/other_samplers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"\n",
"rng_key = random.PRNGKey(seed=42)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
" set_matplotlib_formats(\"svg\")\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
7 changes: 4 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ def body_fn(wrapped_carry, x, prefix=None):
)
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

with handlers.block(
hide_fn=lambda site: not site["name"].startswith("_PREV_")
), enum(first_available_dim=first_available_dim):
with (
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
enum(first_available_dim=first_available_dim),
):
wrapped_carry = (0, rng_key, init)
y0s = []
# We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
Expand Down
15 changes: 10 additions & 5 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
with (
handlers.block(),
handlers.seed(rng_seed=rng_key),
handlers.substitute(data=global_latents),
):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_outputs,)
Expand Down Expand Up @@ -1575,9 +1577,12 @@ def fn(x):
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with handlers.block(), handlers.trace() as tr, handlers.seed(
rng_seed=key
), handlers.substitute(data=local_guide_params):
with (
handlers.block(),
handlers.trace() as tr,
handlers.seed(rng_seed=key),
handlers.substitute(data=local_guide_params),
):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
subsample_guide(*local_args, **local_kwargs)
Expand Down
7 changes: 5 additions & 2 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,11 @@ def get_importance_trace_enum(
trace as _trace,
)

with plate_to_enum_plate(), enum(
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
with (
plate_to_enum_plate(),
enum(
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
),
):
guide = substitute(guide, data=params)
with _without_rsample_stop_gradient():
Expand Down
Loading

0 comments on commit 07e4c9b

Please sign in to comment.