diff --git a/examples/annotation.py b/examples/annotation.py index f820dd8a4..7c9cd6ef4 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -266,7 +266,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.4.1") + assert numpyro.__version__.startswith("0.5.0") parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index de6a52293..9be4206c6 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -196,7 +196,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index dd2b806dd..08b68407c 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -138,7 +138,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index fdac66a04..1dcff36e7 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -58,7 +58,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples') parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') diff --git a/examples/funnel.py b/examples/funnel.py index 860950f34..16305aa7c 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -87,7 +87,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Non-centered reparameterization example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index f7c897bcb..bccf22a3d 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -142,7 +142,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index b51708f22..c07e15fcf 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -191,7 +191,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model') parser.add_argument('--num-categories', default=3, type=int) parser.add_argument('--num-words', default=10, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index d6652ce06..f59abd2c8 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index fc90c0a82..54d76d644 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -146,7 +146,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int) parser.add_argument('--num-warmup', nargs='?', default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index 8f7c35477..df06a84a7 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -103,7 +103,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description='Predator-Prey Model') parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int) parser.add_argument('--num-warmup', nargs='?', default=1000, type=int) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index 6964bc1b2..3e6e0f5a7 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -128,7 +128,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description='Testing whether ') parser.add_argument('-n', '--num-samples', nargs='?', default=500, type=int) parser.add_argument('--num-warmup', nargs='?', default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index a7bd8563d..a91c27297 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -320,7 +320,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index 95315d273..fb480c325 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -112,7 +112,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int) parser.add_argument('--num-warmup', nargs='?', default=600, type=int) diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index 3daa72da8..a7fdff6ec 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -131,7 +131,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC') parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int) parser.add_argument('--num-warmup', nargs='?', default=500, type=int) diff --git a/examples/vae.py b/examples/vae.py index 37d5d3e29..3935fbf3e 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -131,7 +131,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') + assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs') parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate') diff --git a/notebooks/source/bayesian_imputation.ipynb b/notebooks/source/bayesian_imputation.ipynb index 498af5b80..3498618ec 100644 --- a/notebooks/source/bayesian_imputation.ipynb +++ b/notebooks/source/bayesian_imputation.ipynb @@ -46,7 +46,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.4.1\")" + "assert numpyro.__version__.startswith(\"0.5.0\")" ] }, { diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 2c1af08f8..329fa6801 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -66,7 +66,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats('svg')\n", "\n", - "assert numpyro.__version__.startswith('0.4.1')" + "assert numpyro.__version__.startswith('0.5.0')" ] }, { diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index 042034459..901a0da06 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -31,7 +31,7 @@ "import numpyro.distributions as dist\n", "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", - "assert numpyro.__version__.startswith('0.4.1')\n", + "assert numpyro.__version__.startswith('0.5.0')\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index 9aacc7049..8b1ac5b74 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -30,7 +30,7 @@ "from numpyro.infer import MCMC, NUTS\n", "import pandas as pd\n", "import seaborn as sns\n", - "assert numpyro.__version__.startswith('0.4.1')" + "assert numpyro.__version__.startswith('0.5.0')" ] }, { diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index 589daf4be..5fd18d541 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -39,7 +39,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.4.1\")" + "assert numpyro.__version__.startswith(\"0.5.0\")" ] }, { diff --git a/numpyro/contrib/einstein/kernels.py b/numpyro/contrib/einstein/kernels.py index a0b002aa1..d14261fe2 100644 --- a/numpyro/contrib/einstein/kernels.py +++ b/numpyro/contrib/einstein/kernels.py @@ -11,8 +11,8 @@ import jax.scipy.linalg import jax.scipy.stats -import numpyro.distributions as dist from numpyro.contrib.einstein.utils import posdef, safe_norm, sqrth +import numpyro.distributions as dist class PrecondMatrix(ABC): diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 69e2b917f..89f1c89d3 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -5,7 +5,7 @@ import copy from functools import partial -from jax import device_put, jacfwd, grad, ops, random, value_and_grad +from jax import device_put, grad, jacfwd, ops, random, value_and_grad import jax.numpy as jnp from jax.scipy.special import expit diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index 315cf3f65..a97b33fc6 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -3,7 +3,7 @@ from collections import namedtuple -from jax import jacfwd, grad, random, value_and_grad, vmap +from jax import grad, jacfwd, random, value_and_grad, vmap from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.ops import index_update diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index c5623696b..05c5573d8 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod -import jax.numpy as jnp from jax import lax +import jax.numpy as jnp import numpyro import numpyro.distributions as dist diff --git a/numpyro/version.py b/numpyro/version.py index c8e4064aa..f86a79a48 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = '0.4.1' +__version__ = '0.5.0' diff --git a/setup.py b/setup.py index 4f19b3179..bd20fd1b0 100644 --- a/setup.py +++ b/setup.py @@ -33,9 +33,9 @@ author='Uber AI Labs', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax>=0.2.7', + 'jax==0.2.8', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib>=0.1.56', + 'jaxlib==0.1.59', 'tqdm', ], extras_require={