diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6cefeb470..b59f3cb9c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.9] steps: - uses: actions/checkout@v2 @@ -51,7 +51,7 @@ jobs: needs: lint strategy: matrix: - python-version: [3.8] + python-version: [3.9] steps: - uses: actions/checkout@v2 @@ -81,7 +81,7 @@ jobs: needs: lint strategy: matrix: - python-version: [3.8] + python-version: [3.9] steps: - uses: actions/checkout@v2 @@ -120,7 +120,7 @@ jobs: needs: lint strategy: matrix: - python-version: [3.8] + python-version: [3.9] steps: - uses: actions/checkout@v2 diff --git a/examples/annotation.py b/examples/annotation.py index a4f064ec2..e2a4fd1ad 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -320,7 +320,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ar2.py b/examples/ar2.py index 38787ea67..93e396f3c 100644 --- a/examples/ar2.py +++ b/examples/ar2.py @@ -114,7 +114,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="AR2 example") parser.add_argument("--num-data", nargs="?", default=142, type=int) parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index 958665acd..7310741db 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -210,7 +210,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index 5330ce315..710bb083e 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -160,7 +160,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index 9ab6a9725..01aba060f 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -206,7 +206,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-samples", default=1000, type=int, help="number of samples" diff --git a/examples/funnel.py b/examples/funnel.py index d66298111..d6d5c98a9 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -139,7 +139,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser( description="Non-centered reparameterization example" ) diff --git a/examples/gaussian_shells.py b/examples/gaussian_shells.py index 37bd3fdfc..ef295b391 100644 --- a/examples/gaussian_shells.py +++ b/examples/gaussian_shells.py @@ -120,7 +120,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells") parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index d070d92ff..4b69ceeb1 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -170,7 +170,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index f068084e5..206504636 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -263,7 +263,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model") parser.add_argument("--num-categories", default=3, type=int) parser.add_argument("--num-words", default=10, type=int) diff --git a/examples/holt_winters.py b/examples/holt_winters.py index a5eff8dfa..334994ab7 100644 --- a/examples/holt_winters.py +++ b/examples/holt_winters.py @@ -180,7 +180,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Holt-Winters") parser.add_argument("--T", nargs="?", default=6, type=int) parser.add_argument("--future", nargs="?", default=1, type=int) diff --git a/examples/horseshoe_regression.py b/examples/horseshoe_regression.py index 60ce31954..e5bf51907 100644 --- a/examples/horseshoe_regression.py +++ b/examples/horseshoe_regression.py @@ -162,7 +162,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Horseshoe regression example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index 73d58b2f8..e6d981d6a 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/mortality.py b/examples/mortality.py index 4c41961f1..02b2e4ae0 100644 --- a/examples/mortality.py +++ b/examples/mortality.py @@ -220,7 +220,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Mortality regression model") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index f9a65ed32..7afda7086 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -197,7 +197,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index 4f79a7afe..a53eaba6a 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -117,7 +117,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Predator-Prey Model") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/prodlda.py b/examples/prodlda.py index 6f62d4af1..dfddcad8e 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -314,7 +314,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser( description="Probabilistic topic modelling with Flax and Haiku" ) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index e93a157ee..b6691db30 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -19,7 +19,6 @@ import argparse import os -from typing import Tuple from jax import random import jax.numpy as jnp @@ -31,7 +30,7 @@ from numpyro.infer import MCMC, NUTS -def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]: +def make_dataset(rng_key) -> tuple[jnp.ndarray, jnp.ndarray]: """ Make simulated dataset where potential customers who get a sales calls have ~2% higher chance of making another purchase. @@ -160,7 +159,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Testing whether ") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 9616a1f41..82c7db5c0 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -384,7 +384,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index 6d3ea405b..269d74c91 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -122,7 +122,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int) parser.add_argument("--num-warmup", nargs="?", default=600, type=int) diff --git a/examples/thompson_sampling.py b/examples/thompson_sampling.py index 8b50e5ac6..ac31392d0 100644 --- a/examples/thompson_sampling.py +++ b/examples/thompson_sampling.py @@ -292,7 +292,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Thompson sampling example") parser.add_argument( "--num-random", nargs="?", default=2, type=int, help="number of random draws" diff --git a/examples/toy_mixture_model_discrete_enumeration.py b/examples/toy_mixture_model_discrete_enumeration.py index feffa4f3c..757711493 100644 --- a/examples/toy_mixture_model_discrete_enumeration.py +++ b/examples/toy_mixture_model_discrete_enumeration.py @@ -126,7 +126,7 @@ def get_true_pred_CPDs(CPD, posterior_param): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="Toy mixture model") parser.add_argument("-n", "--num-steps", default=4000, type=int) parser.add_argument("-o", "--num-obs", default=10000, type=int) diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index 13d47956a..36ae08d20 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -151,7 +151,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser( description="UCBadmit gender discrimination using HMC" ) diff --git a/examples/vae.py b/examples/vae.py index ddf27c764..0be911b48 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.12.1") + assert numpyro.__version__.startswith("0.13.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-epochs", default=15, type=int, help="number of training epochs" diff --git a/notebooks/source/bad_posterior_geometry.ipynb b/notebooks/source/bad_posterior_geometry.ipynb index a9ddc2201..fb689a337 100644 --- a/notebooks/source/bad_posterior_geometry.ipynb +++ b/notebooks/source/bad_posterior_geometry.ipynb @@ -50,7 +50,7 @@ "\n", "from numpyro.infer import MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")\n", + "assert numpyro.__version__.startswith(\"0.13.0\")\n", "\n", "# NB: replace cpu by gpu to run this notebook on gpu\n", "numpyro.set_platform(\"cpu\")" diff --git a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb index 11e2241ff..4960d4cc5 100644 --- a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb +++ b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb @@ -245,7 +245,7 @@ "import numpyro.distributions as dist\n", "from jax import random\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/bayesian_hierarchical_stacking.ipynb b/notebooks/source/bayesian_hierarchical_stacking.ipynb index c38fea40e..165169d11 100644 --- a/notebooks/source/bayesian_hierarchical_stacking.ipynb +++ b/notebooks/source/bayesian_hierarchical_stacking.ipynb @@ -96,7 +96,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/bayesian_imputation.ipynb b/notebooks/source/bayesian_imputation.ipynb index c2740754c..52d9a3316 100644 --- a/notebooks/source/bayesian_imputation.ipynb +++ b/notebooks/source/bayesian_imputation.ipynb @@ -55,7 +55,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 2f0085cc7..8a99c6bb3 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -95,7 +95,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ], "execution_count": 2, "outputs": [] diff --git a/notebooks/source/gmm.ipynb b/notebooks/source/gmm.ipynb index db1fae94c..4d120663f 100644 --- a/notebooks/source/gmm.ipynb +++ b/notebooks/source/gmm.ipynb @@ -54,7 +54,7 @@ "%matplotlib inline\n", "\n", "smoke_test = \"CI\" in os.environ\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index eefbab6f1..3162f46c4 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -41,7 +41,7 @@ "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")\n", + "assert numpyro.__version__.startswith(\"0.13.0\")\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/model_rendering.ipynb b/notebooks/source/model_rendering.ipynb index e08fee41d..18d050937 100644 --- a/notebooks/source/model_rendering.ipynb +++ b/notebooks/source/model_rendering.ipynb @@ -37,7 +37,7 @@ "import numpyro.distributions as dist\n", "import numpyro.distributions.constraints as constraints\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index 817f219cd..9dc73c6c0 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -53,7 +53,7 @@ "import pandas as pd\n", "import seaborn as sns\n", "\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index f876c8b1a..2e214c137 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -48,7 +48,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.12.1\")" + "assert numpyro.__version__.startswith(\"0.13.0\")" ] }, { diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 3eca7513b..3d9849224 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -1,8 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable, Sequence from functools import partial -from typing import Callable, Dict, Optional, Sequence +from typing import Optional from jax import numpy as jnp, random, tree_map, vmap from jax.tree_util import tree_flatten @@ -24,7 +25,7 @@ class MixtureGuidePredictive: :param Callable model: Python callable containing Pyro primitives. :param Callable guide: Python callable containing Pyro primitives to get posterior samples of sites. - :param Dict params: Dictionary of values for param sites of model/guide + :param dict params: Dictionary of values for param sites of model/guide :param Sequence guide_sites: Names of sites that contribute to the Stein mixture. :param Optional[int] num_samples: :param Optional[Sequence[str]] return_sites: Sites to return. By default, only sample sites not present @@ -37,7 +38,7 @@ def __init__( self, model: Callable, guide: Callable, - params: Dict, + params: dict, guide_sites: Sequence, num_samples: Optional[int] = None, return_sites: Optional[Sequence[str]] = None, diff --git a/numpyro/contrib/einstein/stein_kernels.py b/numpyro/contrib/einstein/stein_kernels.py index 983c7d778..d7607e3cb 100644 --- a/numpyro/contrib/einstein/stein_kernels.py +++ b/numpyro/contrib/einstein/stein_kernels.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Tuple +from collections.abc import Callable import numpy as np @@ -30,7 +30,7 @@ def mode(self): def compute( self, particles: jnp.ndarray, - particle_info: Dict[str, Tuple[int, int]], + particle_info: dict[str, tuple[int, int]], loss_fn: Callable[[jnp.ndarray], float], ): """ @@ -279,7 +279,7 @@ class MixtureKernel(SteinKernel): :param kernel_fns: Different kernel functions to mix together """ - def __init__(self, ws: List[float], kernel_fns: List[SteinKernel], mode="norm"): + def __init__(self, ws: list[float], kernel_fns: list[SteinKernel], mode="norm"): assert len(ws) == len(kernel_fns) assert len(kernel_fns) > 1 assert all(kf.mode == mode for kf in kernel_fns) @@ -328,7 +328,7 @@ class GraphicalKernel(SteinKernel): def __init__( self, mode="matrix", - local_kernel_fns: Dict[str, SteinKernel] = None, + local_kernel_fns: dict[str, SteinKernel] = None, default_kernel_fn: SteinKernel = RBFKernel(), ): assert mode == "matrix" @@ -385,7 +385,7 @@ def __init__(self, guide, scale=1.0): def compute( self, particles: jnp.ndarray, - particle_info: Dict[str, Tuple[int, int]], + particle_info: dict[str, tuple[int, int]], loss_fn: Callable[[jnp.ndarray], float], ): loc_idx = jnp.concatenate( diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 6013f2f44..5ff91f03a 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from collections.abc import Callable from copy import deepcopy import functools from functools import partial from itertools import chain import operator -from typing import Callable from jax import grad, jacfwd, numpy as jnp, random, vmap from jax.random import KeyArray diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index aee99efcd..f5117d1f3 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -350,7 +350,7 @@ def component_distributions(self): """The list of component distributions in the mixture :return: The list of component distributions - :rtype: List[Distribution] + :rtype: list[Distribution] """ return self._component_distributions diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 40fb3f003..eb69244b3 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -11,16 +11,9 @@ import jax from jax import grad, hessian, lax, random -from jax.tree_util import tree_map - -from numpyro.util import _versiontuple, find_stack_level - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import stax -else: - from jax.experimental import stax - +from jax.example_libraries import stax import jax.numpy as jnp +from jax.tree_util import tree_map import numpyro from numpyro import handlers @@ -54,7 +47,7 @@ ) from numpyro.nn.auto_reg_nn import AutoregressiveNN from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer __all__ = [ "AutoContinuous", diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index cdccecc39..2b97f4021 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -1,10 +1,11 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable from functools import partial import itertools from pathlib import Path -from typing import Callable, Dict, Optional +from typing import Optional import jax @@ -72,7 +73,7 @@ def get_dependencies( model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, -) -> Dict[str, object]: +) -> dict[str, object]: r""" Infers dependency structure about a conditioned model. diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index e20bfabdb..b79d9da56 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -1,21 +1,14 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from functools import namedtuple, partial +from collections import namedtuple +from functools import partial import warnings import tqdm -import jax - -from numpyro.util import _versiontuple, find_stack_level - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import optimizers -else: - from jax.experimental import optimizers # pytype: disable=import-error - from jax import jit, lax, random +from jax.example_libraries import optimizers import jax.numpy as jnp from jax.tree_util import tree_map @@ -24,6 +17,7 @@ from numpyro.handlers import replay, seed, substitute, trace from numpyro.infer.util import helpful_support_errors, transform_fn from numpyro.optim import _NumPyroOptim, optax_to_numpyro +from numpyro.util import find_stack_level SVIState = namedtuple("SVIState", ["optim_state", "mutable_state", "rng_key"]) """ diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 92c19034b..e3d269a6e 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -4,7 +4,7 @@ from collections import namedtuple from contextlib import contextmanager from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import warnings import numpy as np @@ -896,12 +896,12 @@ def model(X, y=None): def __init__( self, model: Callable, - posterior_samples: Optional[Dict] = None, + posterior_samples: Optional[dict] = None, *, guide: Optional[Callable] = None, - params: Optional[Dict] = None, + params: Optional[dict] = None, num_samples: Optional[int] = None, - return_sites: Optional[List[str]] = None, + return_sites: Optional[list[str]] = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: Optional[int] = None, diff --git a/numpyro/nn/auto_reg_nn.py b/numpyro/nn/auto_reg_nn.py index c00ebbd3a..727a05802 100644 --- a/numpyro/nn/auto_reg_nn.py +++ b/numpyro/nn/auto_reg_nn.py @@ -5,15 +5,7 @@ import numpy as np -import jax - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import stax -else: - from jax.experimental import stax - +from jax.example_libraries import stax import jax.numpy as jnp from numpyro.nn.masked_dense import MaskedDense diff --git a/numpyro/nn/block_neural_arn.py b/numpyro/nn/block_neural_arn.py index e893e5f6f..452ae55f1 100644 --- a/numpyro/nn/block_neural_arn.py +++ b/numpyro/nn/block_neural_arn.py @@ -3,16 +3,8 @@ import numpy as np -import jax from jax import random - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import stax -else: - from jax.experimental import stax - +from jax.example_libraries import stax from jax.nn import sigmoid, softplus from jax.nn.initializers import glorot_uniform, normal, uniform import jax.numpy as jnp diff --git a/numpyro/optim.py b/numpyro/optim.py index 00d08906f..8a3b78149 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -8,18 +8,11 @@ """ from collections import namedtuple -from typing import Any, Callable, Tuple, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar -import jax from jax import lax, value_and_grad - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import optimizers -else: - from jax.experimental import optimizers # pytype: disable=import-error - +from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.optimize import minimize @@ -39,7 +32,7 @@ _Params = TypeVar("_Params") _OptState = TypeVar("_OptState") -_IterOptState = Tuple[int, _OptState] +_IterOptState = tuple[int, _OptState] class _NumPyroOptim(object): @@ -68,7 +61,7 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState: opt_state = self.update_fn(i, g, opt_state) return i + 1, opt_state - def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): + def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState): """ Performs an optimization step for the objective function `fn`. For most optimizers, the update is performed based on the gradient @@ -87,7 +80,7 @@ def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): (out, aux), grads = value_and_grad(fn, has_aux=True)(params) return (out, aux), self.update(grads, state) - def eval_and_stable_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): + def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptState): """ Like :meth:`eval_and_update` but when the value of the objective function or the gradients are not finite, we will not update the input `state` @@ -273,7 +266,7 @@ def __init__(self, method="BFGS", **kwargs): self._method = method self._kwargs = kwargs - def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState): + def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState): i, (flat_params, unravel_fn) = state def loss_fn(x): diff --git a/numpyro/util.py b/numpyro/util.py index 1b402e48b..09a338588 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -717,19 +717,6 @@ def _format_table(rows): ) -def _versiontuple(version): - """ - :param str version: Version, in string format. - Parse version string into tuple of ints. - - Only to be used for the standard 'major.minor.patch' format, - such as ``'0.2.13'``. - - Source: https://stackoverflow.com/a/11887825/4451315 - """ - return tuple([int(number) for number in version.split(".")]) - - def find_stack_level() -> int: """ Find the first place in the stack that is not inside numpyro diff --git a/numpyro/version.py b/numpyro/version.py index b0e0d6069..19e884256 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.12.1" +__version__ = "0.13.0" diff --git a/setup.py b/setup.py index f840a029b..dd9304fcf 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ from setuptools import find_packages, setup PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.4.7" -_jaxlib_version_constraints = ">=0.4.7" +_jax_version_constraints = ">=0.4.14" +_jaxlib_version_constraints = ">=0.4.14" # Find version for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): @@ -97,7 +97,6 @@ "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 3b641f65f..fe5d77170 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -7,18 +7,10 @@ from numpy.testing import assert_allclose import pytest -import jax from jax import jacobian, jit, lax, random -from jax.tree_util import tree_all, tree_map - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries.stax import Dense -else: - from jax.experimental.stax import Dense - +from jax.example_libraries.stax import Dense import jax.numpy as jnp +from jax.tree_util import tree_all, tree_map import optax from optax import piecewise_constant_schedule diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index e834a9758..f6f196ffe 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -9,16 +9,10 @@ import jax from jax import jit, random, value_and_grad +from jax.example_libraries import optimizers import jax.numpy as jnp from jax.tree_util import tree_all, tree_map -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import optimizers -else: - from jax.experimental import optimizers # pytype: disable=import-error - import numpyro from numpyro import optim import numpyro.distributions as dist diff --git a/test/test_distributions.py b/test/test_distributions.py index 0362adf28..390880cac 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2112,7 +2112,7 @@ def test_beta_proportion_invalid_mean(): (constraints.real, -1, True), ( constraints.real, - np.array([np.inf, np.NINF, np.nan, np.pi]), + np.array([np.inf, -np.inf, np.nan, np.pi]), np.array([False, False, False, True]), ), (constraints.simplex, np.array([0.1, 0.3, 0.6]), True), diff --git a/test/test_flows.py b/test/test_flows.py index f169fefb0..9e65cffbc 100644 --- a/test/test_flows.py +++ b/test/test_flows.py @@ -7,15 +7,8 @@ from numpy.testing import assert_allclose import pytest -import jax from jax import jacfwd, random - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries import stax -else: - from jax.experimental import stax +from jax.example_libraries import stax from numpyro.distributions.flows import ( BlockNeuralAutoregressiveTransform, diff --git a/test/test_nn.py b/test/test_nn.py index e60ee0d87..b6b819128 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7,16 +7,8 @@ from numpy.testing import assert_allclose, assert_array_equal import pytest -import jax from jax import jacfwd, random, vmap - -from numpyro.util import _versiontuple - -if _versiontuple(jax.__version__) >= (0, 2, 25): - from jax.example_libraries.stax import serial -else: - from jax.experimental.stax import serial - +from jax.example_libraries.stax import serial import jax.numpy as jnp from numpyro.distributions.util import matrix_to_tril_vec