Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update link to jax repo #1936

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@

# NumPyro

Probabilistic programming powered by [JAX](https://github.com/google/jax) for autograd and JIT compilation to GPU/TPU/CPU.
Probabilistic programming powered by [JAX](https://github.com/jax-ml/jax) for autograd and JIT compilation to GPU/TPU/CPU.

[Docs and Examples](https://num.pyro.ai) | [Forum](https://forum.pyro.ai/)

----------------------------------------------------------------------------------------------------

## What is NumPyro?

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro). We rely on [JAX](https://github.com/google/jax) for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.
NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro). We rely on [JAX](https://github.com/jax-ml/jax) for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on:

- **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](https://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below.
- **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)).
- **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models.
- **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/jax-ml/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models.
- **Effect handlers:** Like Pyro, primitives like `sample` and `param` can be provided nonstandard interpretations using effect-handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities.

## A Simple Example - 8 Schools
Expand Down Expand Up @@ -228,7 +228,7 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe

## Installation

> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/jax-ml/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.

To install NumPyro with the latest CPU version of JAX, you can use pip:

Expand All @@ -249,9 +249,9 @@ To use **NumPyro on the GPU**, you need to install CUDA first and then use the f
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda).
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/jax-ml/jax#pip-installation-gpu-cuda).

To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/jax-ml/jax/tree/master/cloud_tpu_colabs).

For Cloud TPU VM, you need to setup the TPU backend as detailed in the [Cloud TPU VM JAX Quickstart Guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
After you have verified that the TPU backend is properly set up,
Expand Down Expand Up @@ -310,7 +310,7 @@ conda install -c conda-forge numpyro
- Any `torch` operation in your model will need to be written in terms of the corresponding `jax.numpy` operation. Additionally, not all `torch` operations have a `numpy` counterpart (and vice-versa), and sometimes there are minor differences in the API.
- `pyro.sample` statements outside an inference context will need to be wrapped in a `seed` handler, as mentioned above.
- There is no global parameter store, and as such using `numpyro.param` outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the [SVI.get_params](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.svi.SVI.get_params) method. Note that you can still use `param` statements inside a model and NumPyro will use the [substitute](https://num.pyro.ai/en/latest/handlers.html#substitute) effect handler internally to substitute values from the optimizer when running the model in SVI.
- PyTorch neural network modules will need to rewritten as [stax](https://github.com/google/jax#neural-net-building-with-stax), [flax](https://flax.readthedocs.io/en/latest/), or [haiku](https://dm-haiku.readthedocs.io/en/latest/) neural networks. See the [VAE](https://num.pyro.ai/en/latest/examples/vae.html) and [ProdLDA](https://num.pyro.ai/en/stable/examples/prodlda.html) examples for differences in syntax between the two backends.
- PyTorch neural network modules will need to rewritten as [stax](https://github.com/jax-ml/jax#neural-net-building-with-stax), [flax](https://flax.readthedocs.io/en/latest/), or [haiku](https://dm-haiku.readthedocs.io/en/latest/) neural networks. See the [VAE](https://num.pyro.ai/en/latest/examples/vae.html) and [ProdLDA](https://num.pyro.ai/en/stable/examples/prodlda.html) examples for differences in syntax between the two backends.
- JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side-effects that are not visible to the JAX tracer, it may need to rewritten in a more functional style.

For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working on [pyro-api](https://github.com/pyro-ppl/pyro-api) which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This will necessarily be more restrictive, but has the advantage of being backend agnostic. See the [documentation](https://pyro-api.readthedocs.io/en/latest/dispatch.html#module-pyroapi.dispatch) for an example, and let us know your feedback.
Expand Down
2 changes: 1 addition & 1 deletion docker/release/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ ENV PATH=/root/.local/bin:$PATH

# install python packages via pip
RUN pip3 install --user \
# we pull wheels from google's api as per https://github.com/google/jax#installation
# we pull wheels from google's api as per https://github.com/jax-ml/jax#installation
# the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source)
numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 changes: 1 addition & 1 deletion examples/stein_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def combiner(x, params):


def gru(xs, lengths, init_hidden, params):
"""RNN with GRU. Based on https://github.com/google/jax/pull/2298"""
"""RNN with GRU. Based on https://github.com/jax-ml/jax/pull/2298"""

def apply_fun_single(state, inputs):
i, x = inputs
Expand Down
8 changes: 4 additions & 4 deletions notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1323,7 +1323,7 @@
"\n",
"Note the following:\n",
"\n",
" - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/google/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed.\n",
" - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/jax-ml/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed.\n",
" - We run inference with the `NUTS` sampler. To run vanilla HMC, we can instead use the [HMC](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.HMC) class."
]
},
Expand Down Expand Up @@ -1626,7 +1626,7 @@
"source": [
"#### Predictive Utility With Effect Handlers\n",
"\n",
"To remove the magic behind `Predictive`, let us see how we can combine [effect handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) with the [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) JAX primitive to implement our own simplified predictive utility function that can do vectorized predictions."
"To remove the magic behind `Predictive`, let us see how we can combine [effect handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) with the [vmap](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) JAX primitive to implement our own simplified predictive utility function that can do vectorized predictions."
]
},
{
Expand Down Expand Up @@ -1663,7 +1663,7 @@
" - The `condition` effect handler conditions the latent sample sites to certain values. In our case, we are conditioning on values from the posterior distribution returned by MCMC.\n",
" - The `trace` effect handler runs the model and records the execution trace within an `OrderedDict`. This trace object contains execution metadata that is useful for computing quantities such as the log joint density.\n",
" \n",
"It should be clear now that the `predict` function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the `mcmc` function) to generate predictions. Note the use of JAX's auto-vectorization transform called [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) to vectorize predictions. Note that if we didn't use `vmap`, we would have to use a native for loop which for each sample which is much slower. Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using `vmap`, we will get a `predictions_1` array of shape `(num_samples, 50)`. We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the `Predictive` utility class."
"It should be clear now that the `predict` function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the `mcmc` function) to generate predictions. Note the use of JAX's auto-vectorization transform called [vmap](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to vectorize predictions. Note that if we didn't use `vmap`, we would have to use a native for loop which for each sample which is much slower. Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using `vmap`, we will get a `predictions_1` array of shape `(num_samples, 50)`. We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the `Predictive` utility class."
]
},
{
Expand Down Expand Up @@ -2626,7 +2626,7 @@
"4. Pyro Development Team. [Poutine: A Guide to Programming with Effect Handlers in Pyro](http://pyro.ai/examples/effect_handlers.html)\n",
"5. Hoffman, M.D., Gelman, A. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.\n",
"6. Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo.\n",
"7. JAX Development Team (2018). [Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more](https://github.com/google/jax)\n",
"7. JAX Development Team (2018). [Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more](https://github.com/jax-ml/jax)\n",
"8. Gelman, A., Hwang, J., and Vehtari A. [Understanding predictive information criteria for Bayesian models](https://arxiv.org/pdf/1307.5928.pdf)"
]
}
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/hsgp_nd_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@
"metadata": {},
"outputs": [],
"source": [
"@jax.tree_util.register_pytree_node_class # https://github.com/google/jax/discussions/16020\n",
"@jax.tree_util.register_pytree_node_class # https://github.com/jax-ml/jax/discussions/16020\n",
"class GPModel:\n",
" \"\"\"Exact GP model with a squared exponential kernel.\"\"\"\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/google/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `scan` - which is a wrapper of [lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) with supports for NumPyro primitives and handlers. A detailed explanation for using this utility can be found in [NumPyro documentation](http://num.pyro.ai/en/latest/primitives.html#scan). Here we use it to collect `y` values while the triple `(level, s, moving_sum)` plays the role of carrying state."
"Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/jax-ml/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `scan` - which is a wrapper of [lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) with supports for NumPyro primitives and handlers. A detailed explanation for using this utility can be found in [NumPyro documentation](http://num.pyro.ai/en/latest/primitives.html#scan). Here we use it to collect `y` values while the triple `(level, s, moving_sum)` plays the role of carrying state."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def transform_fn(q):
# NB: icdf is not available yet for Gamma distribution
# so this will raise an NotImplementedError for now.
# We will need scipy.special.gammaincinv, which is not available yet in JAX
# see issue: https://github.com/google/jax/issues/5350
# see issue: https://github.com/jax-ml/jax/issues/5350
# TODO: consider wrap jaxns GammaPrior transform implementation
gammas = uniform_reparam_transform(gamma_dist)(q)
return gammas / gammas.sum(-1, keepdims=True)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Distribution(metaclass=DistributionMeta):
pytree_aux_fields = ("_batch_shape", "_event_shape")

# register Distribution as a pytree
# ref: https://github.com/google/jax/issues/2916
# ref: https://github.com/jax-ml/jax/issues/2916
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
tree_util.register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def set_host_device_count(n: int) -> None:
`xla_force_host_platform_device_count` flag in XLA is incomplete. If you
observe some strange phenomenon when using this utility, please let us
know through our issue or forum page. More information is available in this
`JAX issue <https://github.com/google/jax/issues/1408>`_.
`JAX issue <https://github.com/jax-ml/jax/issues/1408>`_.

:param int n: number of CPU devices to use.
"""
Expand Down
Loading
Loading