diff --git a/README.md b/README.md index 62d5e591a..475afde89 100644 --- a/README.md +++ b/README.md @@ -237,22 +237,22 @@ conda install -c conda-forge numpyro - Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))`. - Wrap the code in a `seed` handler, used either as a context manager or as a function that wraps over the original callable. e.g. - ```python - with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used - x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0) - y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one - ``` + ```python + with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used + x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0) + y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one + ``` , or as a higher order function: - ```python - def fn(): - x = numpyro.sample('x', dist.Beta(1, 1)) - y = numpyro.sample('y', dist.Bernoulli(x)) - return y + ```python + def fn(): + x = numpyro.sample('x', dist.Beta(1, 1)) + y = numpyro.sample('y', dist.Bernoulli(x)) + return y - print(handlers.seed(fn, rng_seed=0)()) - ``` + print(handlers.seed(fn, rng_seed=0)()) + ``` 2. Can I use the same Pyro model for doing inference in NumPyro? diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index 2cf3c62a1..f15611d74 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 # note that this image uses Python 3.8 ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \ # declare the cuda version for pulling appropriate jaxlib wheel - JAXLIB_CUDA=112 + JAXLIB_CUDA=111 # install python3 and pip on top of the base Ubuntu image # unlike for release, we need to install git and setuptools too diff --git a/docker/release/Dockerfile b/docker/release/Dockerfile index 14fa65023..2ad632b10 100644 --- a/docker/release/Dockerfile +++ b/docker/release/Dockerfile @@ -8,14 +8,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 # declare the image name # note that this image uses Python 3.8 ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \ - # declare what jaxlib, jax, and numpyro versions to use - # right now this is a manual process - in the future it should be automated - # if a CI/CD system is expected to pass in these arguments - # the dockerfile should be modified accordingly - JAXLIB_CUDA=112 \ - JAXLIB_VERSION=0.1.62 \ - JAX_VERSION=0.2.10 \ - NUMPYRO_VERSION=0.6.0 + JAXLIB_CUDA=111 # install python3 and pip on top of the base Ubuntu image RUN apt update && \ @@ -26,8 +19,6 @@ ENV PATH=/root/.local/bin:$PATH # install python packages via pip RUN pip3 install --user \ - numpyro==${NUMPYRO_VERSION} \ - jax==${JAX_VERSION} \ # we pull wheels from google's api as per https://github.com/google/jax#installation # the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source) - jaxlib==${JAXLIB_VERSION}+cuda${JAXLIB_CUDA} -f https://storage.googleapis.com/jax-releases/jax_releases.html \ No newline at end of file + numpyro[cuda${JAXLIB_CUDA}] -f https://storage.googleapis.com/jax-releases/jax_releases.html diff --git a/docs/requirements.txt b/docs/requirements.txt index 56c361b0f..567c3b922 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,11 @@ dm-haiku flax -funsor -jax>=0.1.65 -jaxlib>=0.1.45 -jaxns==0.0.7 -optax==0.0.6 +funsor>=0.4.1 +jax>=0.2.11 +jaxlib>=0.1.62 +jaxns>=0.0.7 +optax>=0.0.6 nbsphinx>=0.8.5 sphinx-gallery -tfp-nightly<=0.14.0.dev20210608 # TODO: change this to tensorflow-probability when it is stable +tensorflow_probability>=0.13 tqdm diff --git a/docs/source/api.rst b/docs/source/api.rst deleted file mode 100644 index 9478f6663..000000000 --- a/docs/source/api.rst +++ /dev/null @@ -1,48 +0,0 @@ -.. currentmodule:: api - -API Reference -============= - -Modeling --------- - -.. toctree:: - :glob: - :maxdepth: 1 - - primitives - handlers - -Distributions -------------- - -.. toctree:: - :glob: - :maxdepth: 1 - - distributions - -Inference ---------- - -.. toctree:: - :glob: - :maxdepth: 1 - - mcmc - svi - autoguide - reparam - funsor - optimizers - diagnostics - utilities - -Contributed Code ----------------- - -.. toctree:: - :glob: - :maxdepth: 1 - - contrib diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst index acb050a94..741417163 100644 --- a/docs/source/contrib.rst +++ b/docs/source/contrib.rst @@ -1,3 +1,6 @@ +Contributed Code +================ + Nested Sampling ~~~~~~~~~~~~~~~ diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 845d22282..3174115f8 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -1,8 +1,11 @@ +Distributions +============= + Base Distribution -================= +----------------- Distribution ------------- +^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.Distribution :members: :undoc-members: @@ -10,7 +13,7 @@ Distribution :member-order: bysource ExpandedDistribution --------------------- +^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.ExpandedDistribution :members: :undoc-members: @@ -18,7 +21,7 @@ ExpandedDistribution :member-order: bysource FoldedDistribution ------------------- +^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.FoldedDistribution :members: :undoc-members: @@ -26,7 +29,7 @@ FoldedDistribution :member-order: bysource ImproperUniform ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.ImproperUniform :members: :undoc-members: @@ -34,7 +37,7 @@ ImproperUniform :member-order: bysource Independent ------------ +^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.Independent :members: :undoc-members: @@ -42,7 +45,7 @@ Independent :member-order: bysource MaskedDistribution ------------------- +^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.MaskedDistribution :members: :undoc-members: @@ -50,7 +53,7 @@ MaskedDistribution :member-order: bysource TransformedDistribution ------------------------ +^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.distribution.TransformedDistribution :members: :undoc-members: @@ -58,7 +61,7 @@ TransformedDistribution :member-order: bysource Delta ------ +^^^^^ .. autoclass:: numpyro.distributions.distribution.Delta :members: :undoc-members: @@ -66,7 +69,7 @@ Delta :member-order: bysource Unit ----- +^^^^ .. autoclass:: numpyro.distributions.distribution.Unit :members: :undoc-members: @@ -75,10 +78,10 @@ Unit Continuous Distributions -======================== +------------------------ Beta ----- +^^^^ .. autoclass:: numpyro.distributions.continuous.Beta :members: :undoc-members: @@ -86,7 +89,7 @@ Beta :member-order: bysource BetaProportion --------------- +^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.BetaProportion :members: :undoc-members: @@ -95,7 +98,7 @@ BetaProportion Cauchy ------- +^^^^^^ .. autoclass:: numpyro.distributions.continuous.Cauchy :members: :undoc-members: @@ -103,7 +106,7 @@ Cauchy :member-order: bysource Chi2 ----- +^^^^ .. autoclass:: numpyro.distributions.continuous.Chi2 :members: :undoc-members: @@ -111,7 +114,7 @@ Chi2 :member-order: bysource Dirichlet ---------- +^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Dirichlet :members: :undoc-members: @@ -119,7 +122,7 @@ Dirichlet :member-order: bysource Exponential ------------ +^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Exponential :members: :undoc-members: @@ -127,7 +130,7 @@ Exponential :member-order: bysource Gamma ------ +^^^^^ .. autoclass:: numpyro.distributions.continuous.Gamma :members: :undoc-members: @@ -135,7 +138,7 @@ Gamma :member-order: bysource Gumbel ------- +^^^^^^ .. autoclass:: numpyro.distributions.continuous.Gumbel :members: :undoc-members: @@ -143,7 +146,7 @@ Gumbel :member-order: bysource GaussianRandomWalk ------------------- +^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.GaussianRandomWalk :members: :undoc-members: @@ -151,7 +154,7 @@ GaussianRandomWalk :member-order: bysource HalfCauchy ----------- +^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.HalfCauchy :members: :undoc-members: @@ -159,7 +162,7 @@ HalfCauchy :member-order: bysource HalfNormal ----------- +^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.HalfNormal :members: :undoc-members: @@ -167,7 +170,7 @@ HalfNormal :member-order: bysource InverseGamma ------------- +^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.InverseGamma :members: :undoc-members: @@ -175,7 +178,7 @@ InverseGamma :member-order: bysource Laplace -------- +^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Laplace :members: :undoc-members: @@ -183,7 +186,7 @@ Laplace :member-order: bysource LKJ ---- +^^^ .. autoclass:: numpyro.distributions.continuous.LKJ :members: :undoc-members: @@ -191,7 +194,7 @@ LKJ :member-order: bysource LKJCholesky ------------ +^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.LKJCholesky :members: :undoc-members: @@ -199,7 +202,7 @@ LKJCholesky :member-order: bysource LogNormal ---------- +^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.LogNormal :members: :undoc-members: @@ -207,7 +210,7 @@ LogNormal :member-order: bysource Logistic --------- +^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Logistic :members: :undoc-members: @@ -215,7 +218,7 @@ Logistic :member-order: bysource MultivariateNormal ------------------- +^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.MultivariateNormal :members: :undoc-members: @@ -223,7 +226,7 @@ MultivariateNormal :member-order: bysource LowRankMultivariateNormal -------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.LowRankMultivariateNormal :members: :undoc-members: @@ -231,7 +234,7 @@ LowRankMultivariateNormal :member-order: bysource Normal ------- +^^^^^^ .. autoclass:: numpyro.distributions.continuous.Normal :members: :undoc-members: @@ -239,7 +242,7 @@ Normal :member-order: bysource Pareto ------- +^^^^^^ .. autoclass:: numpyro.distributions.continuous.Pareto :members: :undoc-members: @@ -247,7 +250,7 @@ Pareto :member-order: bysource SoftLaplace ------------ +^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.SoftLaplace :members: :undoc-members: @@ -255,7 +258,7 @@ SoftLaplace :member-order: bysource StudentT --------- +^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.StudentT :members: :undoc-members: @@ -263,7 +266,7 @@ StudentT :member-order: bysource Uniform -------- +^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Uniform :members: :undoc-members: @@ -271,7 +274,7 @@ Uniform :member-order: bysource Weibull -------- +^^^^^^^ .. autoclass:: numpyro.distributions.continuous.Weibull :members: :undoc-members: @@ -280,14 +283,14 @@ Weibull Discrete Distributions -====================== +---------------------- Bernoulli ---------- +^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.Bernoulli BernoulliLogits ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.BernoulliLogits :members: :undoc-members: @@ -295,7 +298,7 @@ BernoulliLogits :member-order: bysource BernoulliProbs --------------- +^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.BernoulliProbs :members: :undoc-members: @@ -303,7 +306,7 @@ BernoulliProbs :member-order: bysource BetaBinomial ------------- +^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.BetaBinomial :members: :undoc-members: @@ -311,11 +314,11 @@ BetaBinomial :member-order: bysource Binomial ---------- +^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.Binomial BinomialLogits --------------- +^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.BinomialLogits :members: :undoc-members: @@ -323,7 +326,7 @@ BinomialLogits :member-order: bysource BinomialProbs -------------- +^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.BinomialProbs :members: :undoc-members: @@ -331,11 +334,11 @@ BinomialProbs :member-order: bysource Categorical ------------ +^^^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.Categorical CategoricalLogits ------------------ +^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.CategoricalLogits :members: :undoc-members: @@ -343,7 +346,7 @@ CategoricalLogits :member-order: bysource CategoricalProbs ----------------- +^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.CategoricalProbs :members: :undoc-members: @@ -351,7 +354,7 @@ CategoricalProbs :member-order: bysource DirichletMultinomial --------------------- +^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.DirichletMultinomial :members: :undoc-members: @@ -359,7 +362,7 @@ DirichletMultinomial :member-order: bysource GammaPoisson ------------- +^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.GammaPoisson :members: :undoc-members: @@ -367,11 +370,11 @@ GammaPoisson :member-order: bysource Geometric ---------- +^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.Geometric GeometricLogits ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.GeometricLogits :members: :undoc-members: @@ -379,7 +382,7 @@ GeometricLogits :member-order: bysource GeometricProbs --------------- +^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.GeometricProbs :members: :undoc-members: @@ -387,11 +390,11 @@ GeometricProbs :member-order: bysource Multinomial ------------ +^^^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.Multinomial MultinomialLogits ------------------ +^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.MultinomialLogits :members: :undoc-members: @@ -399,7 +402,7 @@ MultinomialLogits :member-order: bysource MultinomialProbs ----------------- +^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.MultinomialProbs :members: :undoc-members: @@ -407,7 +410,7 @@ MultinomialProbs :member-order: bysource OrderedLogistic ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.OrderedLogistic :members: :undoc-members: @@ -415,11 +418,11 @@ OrderedLogistic :member-order: bysource NegativeBinomial ----------------- +^^^^^^^^^^^^^^^^ .. autofunction:: numpyro.distributions.conjugate.NegativeBinomial NegativeBinomialLogits ----------------------- +^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.NegativeBinomialLogits :members: :undoc-members: @@ -427,7 +430,7 @@ NegativeBinomialLogits :member-order: bysource NegativeBinomialProbs ---------------------- +^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.NegativeBinomialProbs :members: :undoc-members: @@ -435,7 +438,7 @@ NegativeBinomialProbs :member-order: bysource NegativeBinomial2 ------------------ +^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.conjugate.NegativeBinomial2 :members: :undoc-members: @@ -443,7 +446,7 @@ NegativeBinomial2 :member-order: bysource Poisson -------- +^^^^^^^ .. autoclass:: numpyro.distributions.discrete.Poisson :members: :undoc-members: @@ -451,7 +454,7 @@ Poisson :member-order: bysource PRNGIdentity ------------- +^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.PRNGIdentity :members: :undoc-members: @@ -459,11 +462,11 @@ PRNGIdentity :member-order: bysource ZeroInflatedDistribution ------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: numpyro.distributions.discrete.ZeroInflatedDistribution ZeroInflatedPoisson -------------------- +^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.discrete.ZeroInflatedPoisson :members: :undoc-members: @@ -471,15 +474,15 @@ ZeroInflatedPoisson :member-order: bysource ZeroInflatedNegativeBinomial2 ------------------------------ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: numpyro.distributions.conjugate.ZeroInflatedNegativeBinomial2 Directional Distributions -========================= +------------------------- ProjectedNormal ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.directional.ProjectedNormal :members: :undoc-members: @@ -487,7 +490,7 @@ ProjectedNormal :member-order: bysource VonMises --------- +^^^^^^^^ .. autoclass:: numpyro.distributions.directional.VonMises :members: :undoc-members: @@ -496,10 +499,10 @@ VonMises Truncated Distributions -======================= +----------------------- LeftTruncatedDistribution -------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.LeftTruncatedDistribution :members: :undoc-members: @@ -507,7 +510,7 @@ LeftTruncatedDistribution :member-order: bysource RightTruncatedDistribution --------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.RightTruncatedDistribution :members: :undoc-members: @@ -515,7 +518,7 @@ RightTruncatedDistribution :member-order: bysource TruncatedCauchy ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.TruncatedCauchy :members: :undoc-members: @@ -523,11 +526,11 @@ TruncatedCauchy :member-order: bysource TruncatedDistribution ---------------------- +^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: numpyro.distributions.truncated.TruncatedDistribution TruncatedNormal ---------------- +^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.TruncatedNormal :members: :undoc-members: @@ -535,7 +538,7 @@ TruncatedNormal :member-order: bysource TruncatedPolyaGamma -------------------- +^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.TruncatedPolyaGamma :members: :undoc-members: @@ -543,7 +546,7 @@ TruncatedPolyaGamma :member-order: bysource TwoSidedTruncatedDistribution ------------------------------ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.truncated.TwoSidedTruncatedDistribution :members: :undoc-members: @@ -552,7 +555,7 @@ TwoSidedTruncatedDistribution TensorFlow Distributions -======================== +------------------------ Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see `its Distribution docs `_. @@ -561,10 +564,10 @@ see `its Distribution docs .. nbgallery:: @@ -39,6 +43,7 @@ NumPyro documentation examples/capture_recapture examples/gaussian_shells tutorials/discrete_imputation + examples/prodlda .. nbgallery:: :maxdepth: 1 @@ -59,7 +64,6 @@ NumPyro documentation examples/neutra examples/covtype examples/thompson_sampling - examples/prodlda Indices and tables diff --git a/docs/source/infer.rst b/docs/source/infer.rst new file mode 100644 index 000000000..4d44d68de --- /dev/null +++ b/docs/source/infer.rst @@ -0,0 +1,15 @@ +Inference +========= + +.. toctree:: + :glob: + :maxdepth: 1 + + mcmc + svi + autoguide + reparam + funsor + optimizers + diagnostics + utilities diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index 416e0a3a6..10b477931 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -10,54 +10,72 @@ Markov Chain Monte Carlo (MCMC) MCMC Kernels ------------ +MCMCKernel +^^^^^^^^^^ .. autoclass:: numpyro.infer.mcmc.MCMCKernel :members: :undoc-members: :show-inheritance: :member-order: bysource +BarkerMH +^^^^^^^^ .. autoclass:: numpyro.infer.barker.BarkerMH :members: :undoc-members: :show-inheritance: :member-order: bysource +HMC +^^^ .. autoclass:: numpyro.infer.hmc.HMC :members: :undoc-members: :show-inheritance: :member-order: bysource +NUTS +^^^^ .. autoclass:: numpyro.infer.hmc.NUTS :members: :undoc-members: :show-inheritance: :member-order: bysource +HMCGibbs +^^^^^^^^ .. autoclass:: numpyro.infer.hmc_gibbs.HMCGibbs :members: :undoc-members: :show-inheritance: :member-order: bysource +DiscreteHMCGibbs +^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.infer.hmc_gibbs.DiscreteHMCGibbs :members: :undoc-members: :show-inheritance: :member-order: bysource +MixedHMC +^^^^^^^^ .. autoclass:: numpyro.infer.mixed_hmc.MixedHMC :members: :undoc-members: :show-inheritance: :member-order: bysource +HMCECS +^^^^^^ .. autoclass:: numpyro.infer.hmc_gibbs.HMCECS :members: :undoc-members: :show-inheritance: :member-order: bysource +SA +^^ .. autoclass:: numpyro.infer.sa.SA :members: :undoc-members: diff --git a/examples/annotation.py b/examples/annotation.py index 264d14ac7..a84f896e5 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -327,7 +327,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index 463c54203..549ddae72 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -210,7 +210,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index 333ef1c6c..393908de5 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -156,7 +156,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index e62867ea1..f9dd88322 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -206,7 +206,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-samples", default=1000, type=int, help="number of samples" diff --git a/examples/funnel.py b/examples/funnel.py index 16897f06f..16cefdf5d 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -108,7 +108,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser( description="Non-centered reparameterization example" ) diff --git a/examples/gaussian_shells.py b/examples/gaussian_shells.py index 262c6dd7d..b3851143b 100644 --- a/examples/gaussian_shells.py +++ b/examples/gaussian_shells.py @@ -120,7 +120,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells") parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index 11b7a4d3a..c7952c984 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -170,7 +170,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index f57e6dffd..47a3c4025 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -263,7 +263,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model") parser.add_argument("--num-categories", default=3, type=int) parser.add_argument("--num-words", default=10, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index b54c5f08e..4d9e57800 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index 4a30c87ab..e8348ac8c 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -197,7 +197,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index 6b9f241df..d00dca624 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -116,7 +116,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Predator-Prey Model") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/prodlda.py b/examples/prodlda.py index d7b2ae5c3..ac6b802d3 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """ -Example: ProdLDA -================ +Example: ProdLDA with Flax and Haiku +==================================== + In this example, we will follow [1] to implement the ProdLDA topic model from Autoencoding Variational Inference For Topic Models by Akash Srivastava and Charles Sutton [2]. This model returns consistently better topics than vanilla LDA and trains @@ -313,7 +314,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser( description="Probabilistic topic modelling with Flax and Haiku" ) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index b02e9aeb5..164e9147d 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -160,7 +160,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Testing whether ") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index f94cfac48..338adbc4f 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -401,7 +401,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index ad4b1c64c..ec094b306 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -122,7 +122,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int) parser.add_argument("--num-warmup", nargs="?", default=600, type=int) diff --git a/examples/thompson_sampling.py b/examples/thompson_sampling.py index 0c15f6bf4..d78d58a4a 100644 --- a/examples/thompson_sampling.py +++ b/examples/thompson_sampling.py @@ -294,7 +294,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="Thompson sampling example") parser.add_argument( "--num-random", nargs="?", default=2, type=int, help="number of random draws" diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index a7523a787..9bf99b9b5 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -151,7 +151,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser( description="UCBadmit gender discrimination using HMC" ) diff --git a/examples/vae.py b/examples/vae.py index c82a7cfdc..20ab6881e 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -159,7 +159,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.6.0") + assert numpyro.__version__.startswith("0.7.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-epochs", default=15, type=int, help="number of training epochs" diff --git a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb index 49670aef3..3ab89d034 100644 --- a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb +++ b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb @@ -242,7 +242,7 @@ "import numpyro.distributions as dist\n", "from jax import random\n", "\n", - "assert numpyro.__version__.startswith('0.6.0')" + "assert numpyro.__version__.startswith('0.7.0')" ] }, { diff --git a/notebooks/source/bayesian_imputation.ipynb b/notebooks/source/bayesian_imputation.ipynb index e7a03197e..da83bcc74 100644 --- a/notebooks/source/bayesian_imputation.ipynb +++ b/notebooks/source/bayesian_imputation.ipynb @@ -55,7 +55,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith('0.6.0')" + "assert numpyro.__version__.startswith('0.7.0')" ] }, { diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 9a72853c3..40b008d27 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -95,7 +95,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats('svg')\n", "\n", - "assert numpyro.__version__.startswith('0.6.0')" + "assert numpyro.__version__.startswith('0.7.0')" ], "execution_count": 2, "outputs": [] diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index bea5eecbd..22cfa561c 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -40,7 +40,7 @@ "import numpyro.distributions as dist\n", "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", - "assert numpyro.__version__.startswith('0.6.0')\n", + "assert numpyro.__version__.startswith('0.7.0')\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/model_rendering.ipynb b/notebooks/source/model_rendering.ipynb index 17d46747a..abf29442b 100644 --- a/notebooks/source/model_rendering.ipynb +++ b/notebooks/source/model_rendering.ipynb @@ -33,7 +33,7 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "\n", - "assert numpyro.__version__.startswith('0.6.0')" + "assert numpyro.__version__.startswith('0.7.0')" ] }, { diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index 1787786fe..d0f3a76d2 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -39,7 +39,7 @@ "from numpyro.infer import MCMC, NUTS\n", "import pandas as pd\n", "import seaborn as sns\n", - "assert numpyro.__version__.startswith('0.6.0')" + "assert numpyro.__version__.startswith('0.7.0')" ] }, { diff --git a/numpyro/contrib/nested_sampling.py b/numpyro/contrib/nested_sampling.py index 6ca9ce7cc..c90590534 100644 --- a/numpyro/contrib/nested_sampling.py +++ b/numpyro/contrib/nested_sampling.py @@ -124,7 +124,8 @@ def __call__(self, name, fn, obs): class NestedSampler: """ - (EXPERIMENTAL) A wrapper for `jaxns`, a nested sampling package based on JAX. + (EXPERIMENTAL) A wrapper for `jaxns `_ , + a nested sampling package based on JAX. See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research. diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index aff4493f5..f424e0b63 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -136,7 +136,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo="NUTS"): .. warning:: Instead of using this interface directly, we would highly recommend you - to use the higher level :class:`numpyro.infer.MCMC` API instead. + to use the higher level :class:`~numpyro.infer.mcmc.MCMC` API instead. **Example** diff --git a/numpyro/version.py b/numpyro/version.py index 3720decc6..0eba22b6c 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/scripts/update_version.py b/scripts/update_version.py index ed9e38fd9..e3e0bbcbf 100644 --- a/scripts/update_version.py +++ b/scripts/update_version.py @@ -23,12 +23,13 @@ # Update version string. pattern1 = re.compile('assert numpyro.__version__.startswith\\("[^"]*"\\)') pattern2 = re.compile("assert numpyro.__version__.startswith\\('[^']*'\\)") -text = f"assert numpyro.__version__.startswith({new_version})" +text1 = f"assert numpyro.__version__.startswith({new_version})" +text2 = text1.replace('"', "'") for filename in filenames: with open(filename) as f: old_text = f.read() - new_text = pattern1.sub(text, old_text) - new_text = pattern2.sub(text, new_text) + new_text = pattern1.sub(text1, old_text) + new_text = pattern2.sub(text2, new_text) if new_text != old_text: print("updating {}".format(filename)) with open(filename, "w") as f: diff --git a/setup.py b/setup.py index e8ccd3a95..5c63a7c02 100644 --- a/setup.py +++ b/setup.py @@ -62,15 +62,11 @@ "dev": [ "dm-haiku", "flax", - # TODO: bump funsor version before the release - "funsor @ git+https://github.com/pyro-ppl/funsor.git@d5574988665dd822ec64e41f2b54b9dc929959dc", + "funsor==0.4.1", "graphviz", "jaxns==0.0.7", - "optax==0.0.6", - # TODO: change this to tensorflow_probability>0.12.1 when the next version - # of tfp is released. The current release is not compatible with jax>=0.2.12. - # TODO: relax this restriction when we revise tfp wrapper - "tfp-nightly<=0.14.0.dev20210608", + "optax>=0.0.6", + "tensorflow_probability>=0.13", ], "examples": [ "arviz",