From 788b1cccfbf78e6a3a6820c20cffada4fe5ad9ce Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 12 Dec 2024 14:01:38 -0500 Subject: [PATCH] Update link to jax repo (#1936) --- README.md | 14 +++++++------- docker/release/Dockerfile | 2 +- examples/stein_dmm.py | 2 +- notebooks/source/bayesian_regression.ipynb | 8 ++++---- notebooks/source/hsgp_nd_example.ipynb | 2 +- notebooks/source/time_series_forecasting.ipynb | 2 +- numpyro/contrib/nested_sampling.py | 2 +- numpyro/distributions/distribution.py | 2 +- numpyro/util.py | 2 +- test/test_distributions.py | 2 +- 10 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 8673f1556..713636107 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ # NumPyro -Probabilistic programming powered by [JAX](https://github.com/google/jax) for autograd and JIT compilation to GPU/TPU/CPU. +Probabilistic programming powered by [JAX](https://github.com/jax-ml/jax) for autograd and JIT compilation to GPU/TPU/CPU. [Docs and Examples](https://num.pyro.ai) | [Forum](https://forum.pyro.ai/) @@ -14,13 +14,13 @@ Probabilistic programming powered by [JAX](https://github.com/google/jax) for au ## What is NumPyro? -NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro). We rely on [JAX](https://github.com/google/jax) for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves. +NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro). We rely on [JAX](https://github.com/jax-ml/jax) for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves. NumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on: - **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](https://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below. - **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). 
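As a purely illustrative sketch of the `jit`/`grad` composition mentioned in the inference bullet above, the snippet below compiles a single velocity-verlet step for a toy potential into one XLA kernel. The `potential_energy` function and step size are hypothetical; this is not NumPyro's actual integrator, only the pattern it relies on.

```python
# Toy sketch (not NumPyro's integrator): compose jax.grad and jax.jit so that
# one verlet/leapfrog step, including its gradient computations, is compiled
# into a single XLA kernel.
import jax
import jax.numpy as jnp

def potential_energy(q):
    # Hypothetical potential: standard-normal negative log-density (up to a constant).
    return 0.5 * jnp.sum(q ** 2)

@jax.jit
def leapfrog_step(q, p, step_size=0.1):
    # jax.grad supplies the force term; jax.jit compiles the whole update.
    p = p - 0.5 * step_size * jax.grad(potential_energy)(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * jax.grad(potential_energy)(q)
    return q, p

q, p = jnp.ones(3), jnp.zeros(3)
q_new, p_new = leapfrog_step(q, p)
```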
There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)). -- **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. +- **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/jax-ml/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. - **Effect handlers:** Like Pyro, primitives like `sample` and `param` can be provided nonstandard interpretations using effect-handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities. ## A Simple Example - 8 Schools @@ -228,7 +228,7 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe ## Installation -> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. 
See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows. +> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/jax-ml/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows. To install NumPyro with the latest CPU version of JAX, you can use pip: @@ -249,9 +249,9 @@ To use **NumPyro on the GPU**, you need to install CUDA first and then use the f pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` -If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda). +If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/jax-ml/jax#pip-installation-gpu-cuda). -To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/google/jax/tree/master/cloud_tpu_colabs). +To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/jax-ml/jax/tree/master/cloud_tpu_colabs). For Cloud TPU VM, you need to setup the TPU backend as detailed in the [Cloud TPU VM JAX Quickstart Guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). After you have verified that the TPU backend is properly set up, @@ -310,7 +310,7 @@ conda install -c conda-forge numpyro - Any `torch` operation in your model will need to be written in terms of the corresponding `jax.numpy` operation. Additionally, not all `torch` operations have a `numpy` counterpart (and vice-versa), and sometimes there are minor differences in the API. - `pyro.sample` statements outside an inference context will need to be wrapped in a `seed` handler, as mentioned above. - There is no global parameter store, and as such using `numpyro.param` outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the [SVI.get_params](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.svi.SVI.get_params) method. Note that you can still use `param` statements inside a model and NumPyro will use the [substitute](https://num.pyro.ai/en/latest/handlers.html#substitute) effect handler internally to substitute values from the optimizer when running the model in SVI. - - PyTorch neural network modules will need to rewritten as [stax](https://github.com/google/jax#neural-net-building-with-stax), [flax](https://flax.readthedocs.io/en/latest/), or [haiku](https://dm-haiku.readthedocs.io/en/latest/) neural networks. See the [VAE](https://num.pyro.ai/en/latest/examples/vae.html) and [ProdLDA](https://num.pyro.ai/en/stable/examples/prodlda.html) examples for differences in syntax between the two backends. + - PyTorch neural network modules will need to rewritten as [stax](https://github.com/jax-ml/jax#neural-net-building-with-stax), [flax](https://flax.readthedocs.io/en/latest/), or [haiku](https://dm-haiku.readthedocs.io/en/latest/) neural networks. 
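The minimal sketch below illustrates two of the porting notes in the list above: `sample` statements outside an inference context need a `seed` handler, and optimized values are retrieved via `SVI.get_params` / `svi_result.params` rather than a global parameter store. The model, data, and guide here are hypothetical.

```python
# Hypothetical model used only to demonstrate the seed handler and SVI params.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers, optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.2, -0.5, 1.3])

# Outside inference, sample statements need an explicit PRNG key via `seed`.
with handlers.seed(rng_seed=0):
    model(data)

# There is no global parameter store; read parameters from the SVI result.
svi = SVI(model, AutoNormal(model), optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(1), 2000, data)
params = svi_result.params  # or svi.get_params(svi_state) when stepping manually
```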
See the [VAE](https://num.pyro.ai/en/latest/examples/vae.html) and [ProdLDA](https://num.pyro.ai/en/stable/examples/prodlda.html) examples for differences in syntax between the two backends. - JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side-effects that are not visible to the JAX tracer, it may need to rewritten in a more functional style. For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working on [pyro-api](https://github.com/pyro-ppl/pyro-api) which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This will necessarily be more restrictive, but has the advantage of being backend agnostic. See the [documentation](https://pyro-api.readthedocs.io/en/latest/dispatch.html#module-pyroapi.dispatch) for an example, and let us know your feedback. diff --git a/docker/release/Dockerfile b/docker/release/Dockerfile index 63d53bf84..73755e96c 100644 --- a/docker/release/Dockerfile +++ b/docker/release/Dockerfile @@ -18,6 +18,6 @@ ENV PATH=/root/.local/bin:$PATH # install python packages via pip RUN pip3 install --user \ - # we pull wheels from google's api as per https://github.com/google/jax#installation + # we pull wheels from google's api as per https://github.com/jax-ml/jax#installation # the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source) numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html diff --git a/examples/stein_dmm.py b/examples/stein_dmm.py index 2daeb055b..716b677e2 100644 --- a/examples/stein_dmm.py +++ b/examples/stein_dmm.py @@ -96,7 +96,7 @@ def combiner(x, params): def gru(xs, lengths, init_hidden, params): - """RNN with GRU. Based on https://github.com/google/jax/pull/2298""" + """RNN with GRU. Based on https://github.com/jax-ml/jax/pull/2298""" def apply_fun_single(state, inputs): i, x = inputs diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 07e4bb552..4f634a9a9 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -1323,7 +1323,7 @@ "\n", "Note the following:\n", "\n", - " - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/google/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed.\n", + " - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/jax-ml/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed.\n", " - We run inference with the `NUTS` sampler. To run vanilla HMC, we can instead use the [HMC](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.HMC) class." 
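A minimal sketch of the two notes above, assuming a toy model: an explicit `PRNGKey` that is split before use, and a `NUTS` kernel passed to `MCMC` (swap in `HMC` for vanilla HMC).

```python
# Hypothetical one-parameter model illustrating explicit, split PRNG keys and NUTS.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS  # use HMC instead of NUTS for vanilla HMC

def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

rng_key = jax.random.PRNGKey(0)
rng_key, rng_key_mcmc = jax.random.split(rng_key)  # split so no key is reused

y = jnp.array([1.2, 0.7, -0.3, 2.1])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(rng_key_mcmc, y)
mcmc.print_summary()
```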
] }, @@ -1626,7 +1626,7 @@ "source": [ "#### Predictive Utility With Effect Handlers\n", "\n", - "To remove the magic behind `Predictive`, let us see how we can combine [effect handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) with the [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) JAX primitive to implement our own simplified predictive utility function that can do vectorized predictions." + "To remove the magic behind `Predictive`, let us see how we can combine [effect handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) with the [vmap](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) JAX primitive to implement our own simplified predictive utility function that can do vectorized predictions." ] }, { @@ -1663,7 +1663,7 @@ " - The `condition` effect handler conditions the latent sample sites to certain values. In our case, we are conditioning on values from the posterior distribution returned by MCMC.\n", " - The `trace` effect handler runs the model and records the execution trace within an `OrderedDict`. This trace object contains execution metadata that is useful for computing quantities such as the log joint density.\n", " \n", - "It should be clear now that the `predict` function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the `mcmc` function) to generate predictions. Note the use of JAX's auto-vectorization transform called [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) to vectorize predictions. Note that if we didn't use `vmap`, we would have to use a native for loop which for each sample which is much slower. Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using `vmap`, we will get a `predictions_1` array of shape `(num_samples, 50)`. We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the `Predictive` utility class." + "It should be clear now that the `predict` function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the `mcmc` function) to generate predictions. Note the use of JAX's auto-vectorization transform called [vmap](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to vectorize predictions. Note that if we didn't use `vmap`, we would have to use a native for loop which for each sample which is much slower. Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using `vmap`, we will get a `predictions_1` array of shape `(num_samples, 50)`. We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the `Predictive` utility class." ] }, { @@ -2626,7 +2626,7 @@ "4. Pyro Development Team. [Poutine: A Guide to Programming with Effect Handlers in Pyro](http://pyro.ai/examples/effect_handlers.html)\n", "5. Hoffman, M.D., Gelman, A. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.\n", "6. Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo.\n", - "7. JAX Development Team (2018). 
[Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more](https://github.com/google/jax)\n", + "7. JAX Development Team (2018). [Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more](https://github.com/jax-ml/jax)\n", "8. Gelman, A., Hwang, J., and Vehtari A. [Understanding predictive information criteria for Bayesian models](https://arxiv.org/pdf/1307.5928.pdf)" ] } diff --git a/notebooks/source/hsgp_nd_example.ipynb b/notebooks/source/hsgp_nd_example.ipynb index a122ae00d..2464e5f3b 100644 --- a/notebooks/source/hsgp_nd_example.ipynb +++ b/notebooks/source/hsgp_nd_example.ipynb @@ -441,7 +441,7 @@ "metadata": {}, "outputs": [], "source": [ - "@jax.tree_util.register_pytree_node_class # https://github.com/google/jax/discussions/16020\n", + "@jax.tree_util.register_pytree_node_class # https://github.com/jax-ml/jax/discussions/16020\n", "class GPModel:\n", " \"\"\"Exact GP model with a squared exponential kernel.\"\"\"\n", "\n", diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index cacf8112f..90cee1ed7 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -243,7 +243,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/google/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `scan` - which is a wrapper of [lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) with supports for NumPyro primitives and handlers. A detailed explanation for using this utility can be found in [NumPyro documentation](http://num.pyro.ai/en/latest/primitives.html#scan). Here we use it to collect `y` values while the triple `(level, s, moving_sum)` plays the role of carrying state." + "Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/jax-ml/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `scan` - which is a wrapper of [lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) with supports for NumPyro primitives and handlers. A detailed explanation for using this utility can be found in [NumPyro documentation](http://num.pyro.ai/en/latest/primitives.html#scan). Here we use it to collect `y` values while the triple `(level, s, moving_sum)` plays the role of carrying state." ] }, { diff --git a/numpyro/contrib/nested_sampling.py b/numpyro/contrib/nested_sampling.py index be42f75b2..de953aa9a 100644 --- a/numpyro/contrib/nested_sampling.py +++ b/numpyro/contrib/nested_sampling.py @@ -90,7 +90,7 @@ def transform_fn(q): # NB: icdf is not available yet for Gamma distribution # so this will raise an NotImplementedError for now. 
# We will need scipy.special.gammaincinv, which is not available yet in JAX - # see issue: https://github.com/google/jax/issues/5350 + # see issue: https://github.com/jax-ml/jax/issues/5350 # TODO: consider wrap jaxns GammaPrior transform implementation gammas = uniform_reparam_transform(gamma_dist)(q) return gammas / gammas.sum(-1, keepdims=True) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index b0e6ae431..c975f865d 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -138,7 +138,7 @@ class Distribution(metaclass=DistributionMeta): pytree_aux_fields = ("_batch_shape", "_event_shape") # register Distribution as a pytree - # ref: https://github.com/google/jax/issues/2916 + # ref: https://github.com/jax-ml/jax/issues/2916 def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) tree_util.register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) diff --git a/numpyro/util.py b/numpyro/util.py index 3c710f41a..ad09cfc16 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -76,7 +76,7 @@ def set_host_device_count(n: int) -> None: `xla_force_host_platform_device_count` flag in XLA is incomplete. If you observe some strange phenomenon when using this utility, please let us know through our issue or forum page. More information is available in this - `JAX issue `_. + `JAX issue `_. :param int n: number of CPU devices to use. """ diff --git a/test/test_distributions.py b/test/test_distributions.py index 87b81219c..61dd23182 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3436,7 +3436,7 @@ def _assert_not_jax_issue_19885( capfd: pytest.CaptureFixture, func: Callable, *args, **kwargs ) -> None: # jit-ing identity plus matrix multiplication leads to performance degradation as - # discussed in https://github.com/google/jax/issues/19885. This assertion verifies + # discussed in https://github.com/jax-ml/jax/issues/19885. This assertion verifies # that the issue does not affect performance in numpyro. for jit in [True, False]: result = jax.jit(func)(*args, **kwargs)
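To illustrate the pytree registration pattern that the `distribution.py` and `hsgp_nd_example.ipynb` hunks above refer to, here is a small, hypothetical sketch; the `Gaussian` class is an illustration only (NumPyro's `Distribution` performs the registration automatically in `__init_subclass__`).

```python
# Registering a container class as a pytree so JAX transformations
# (jit, vmap, grad) can trace through its instances.
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Gaussian:
    """Toy container; not NumPyro's Distribution implementation."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        # children are array-like leaves; aux_data holds static metadata (none here)
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)

@jax.jit
def evaluate(d, x):
    # `d` can be passed through jit because Gaussian is a registered pytree
    return d.log_prob(x)

print(evaluate(Gaussian(jnp.array(0.0), jnp.array(1.0)), jnp.array(0.5)))
```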