Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump to 0.16.0 #1917

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="AR2 example")
parser.add_argument("--num-data", nargs="?", default=142, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-samples", default=1000, type=int, help="number of samples"
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="Non-centered reparameterization example"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/gaussian_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")

parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model")
parser.add_argument("--num-categories", default=3, type=int)
parser.add_argument("--num-words", default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Holt-Winters")
parser.add_argument("--T", nargs="?", default=6, type=int)
parser.add_argument("--future", nargs="?", default=1, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/horseshoe_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Horseshoe regression example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")

parser = argparse.ArgumentParser(description="Mortality regression model")
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Predator-Prey Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/prodlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="Probabilistic topic modelling with Flax and Haiku"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Testing whether ")
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int)
parser.add_argument("--num-warmup", nargs="?", default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/thompson_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Thompson sampling example")
parser.add_argument(
"--num-random", nargs="?", default=2, type=int, help="number of random draws"
Expand Down
2 changes: 1 addition & 1 deletion examples/toy_mixture_model_discrete_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_true_pred_CPDs(CPD, posterior_param):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="Toy mixture model")
parser.add_argument("-n", "--num-steps", default=4000, type=int)
parser.add_argument("-o", "--num-obs", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(
description="UCBadmit gender discrimination using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
assert numpyro.__version__.startswith("0.16.0")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
"-n", "--num-epochs", default=15, type=int, help="number of training epochs"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bad_posterior_geometry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"# NB: replace cpu by gpu to run this notebook on gpu\n",
"numpyro.set_platform(\"cpu\")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_hierarchical_stacking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
" set_matplotlib_formats(\"svg\")\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats(\"svg\")\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats(\"svg\")\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/censoring.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"\n",
"rng_key = random.PRNGKey(seed=0)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/gmm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"%matplotlib inline\n",
"\n",
"smoke_test = \"CI\" in os.environ\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/hsgp_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"\n",
"rng_key = random.PRNGKey(seed=42)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/model_rendering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"import numpyro.distributions as dist\n",
"import numpyro.distributions.constraints as constraints\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"from numpyro.infer import MCMC, NUTS\n",
"from numpyro.infer.reparam import TransformReparam\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/other_samplers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"\n",
"rng_key = random.PRNGKey(seed=42)\n",
"\n",
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
" set_matplotlib_formats(\"svg\")\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"assert numpyro.__version__.startswith(\"0.15.3\")"
"assert numpyro.__version__.startswith(\"0.16.0\")"
]
},
{
Expand Down
7 changes: 4 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ def body_fn(wrapped_carry, x, prefix=None):
)
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

with handlers.block(
hide_fn=lambda site: not site["name"].startswith("_PREV_")
), enum(first_available_dim=first_available_dim):
with (
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
enum(first_available_dim=first_available_dim),
):
wrapped_carry = (0, rng_key, init)
y0s = []
# We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
Expand Down
15 changes: 10 additions & 5 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
with (
handlers.block(),
handlers.seed(rng_seed=rng_key),
handlers.substitute(data=global_latents),
):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_outputs,)
Expand Down Expand Up @@ -1575,9 +1577,12 @@ def fn(x):
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with handlers.block(), handlers.trace() as tr, handlers.seed(
rng_seed=key
), handlers.substitute(data=local_guide_params):
with (
handlers.block(),
handlers.trace() as tr,
handlers.seed(rng_seed=key),
handlers.substitute(data=local_guide_params),
):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
subsample_guide(*local_args, **local_kwargs)
Expand Down
7 changes: 5 additions & 2 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,11 @@ def get_importance_trace_enum(
trace as _trace,
)

with plate_to_enum_plate(), enum(
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
with (
plate_to_enum_plate(),
enum(
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
),
):
guide = substitute(guide, data=params)
with _without_rsample_stop_gradient():
Expand Down
Loading