Updated SteinVI docs #1898

Merged (2 commits, Nov 5, 2024)
32 changes: 19 additions & 13 deletions docs/source/contrib.rst
@@ -16,17 +16,17 @@ Nested Sampling is a non-MCMC approach that works for arbitrary probability mode

Stein Variational Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Stein Variational Inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on
+Stein variational inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on
Stein’s method (see [1] for an overview). It is gaining popularity as it combines
the scalability of traditional VI with the flexibility of non-parametric particle-based methods.

Stein variational gradient descent (SVGD) [2] is a recent SteinVI technique that iteratively moves a set of
particles :math:`\{z_i\}_{i=1}^N` to approximate a distribution :math:`p(z)`.
As a particle-based method, SVGD is well suited for capturing correlations between latent variables.
The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope
-of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3][4].
+of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3].
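
For intuition, the empirical SVGD update from [2] moves each particle along a kernel-smoothed gradient plus a repulsive term (sketched here in the paper's notation, with kernel :math:`k` and step size :math:`\epsilon`):

.. math::

    z_i \leftarrow z_i + \epsilon \, \hat{\phi}^*(z_i),
    \qquad
    \hat{\phi}^*(z) = \frac{1}{N} \sum_{j=1}^{N} \left[ k(z_j, z) \, \nabla_{z_j} \log p(z_j) + \nabla_{z_j} k(z_j, z) \right].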

-``numpyro.contrib.einstein`` is a framework for particle-based inference using the Stein mixture algorithm.
+``numpyro.contrib.einstein`` is a framework for particle-based inference using the Stein mixture inference algorithm [4].
The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles.
Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according
to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a
@@ -44,27 +44,33 @@ The framework currently supports several kernels, including:
- `MixtureKernel`
- `GraphicalKernel`
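
As an illustration, a minimal sketch of constructing kernels (hedged: this assumes the current ``numpyro.contrib.einstein`` exports and that ``MixtureKernel`` takes mixture weights ``ws`` and a list of ``kernel_fns``)::

    from jax import numpy as jnp
    from numpyro.contrib.einstein import LinearKernel, MixtureKernel, RBFKernel

    # A single RBF kernel; the bandwidth defaults to the median heuristic.
    kernel = RBFKernel()

    # An equally weighted mixture of an RBF and a linear kernel.
    mixed = MixtureKernel(
        ws=jnp.array([0.5, 0.5]),
        kernel_fns=[RBFKernel(), LinearKernel()],
    )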

-For example, usage see:
+SteinVI-based examples include:

- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_.
- The `deep Markov example <https://num.pyro.ai/en/latest/examples/stein_dmm.html>`_.
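
A minimal end-to-end sketch of the Stein mixture interface (hedged: the model and step count are illustrative, and the constructor arguments are assumed to match the current ``SteinVI`` API)::

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.einstein import RBFKernel, SteinVI
    from numpyro.infer.autoguide import AutoNormal
    from numpyro.optim import Adagrad

    def model(obs):
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
        with numpyro.plate("data", obs.shape[0]):
            numpyro.sample("obs", dist.Normal(mu, 1.0), obs=obs)

    data = random.normal(random.PRNGKey(0), (100,)) + 2.0

    stein = SteinVI(
        model,
        AutoNormal(model),  # each Stein particle parameterizes one mixture component
        Adagrad(step_size=0.05),
        RBFKernel(),
        num_stein_particles=5,
    )
    result = stein.run(random.PRNGKey(1), 1000, data)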

**References**

-1. *Stein's Method Meets Statistics: A Review of Some Recent Developments* (2021)
-   Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
-   Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
-   Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert,
-   Yvik Swan.
+1. *Stein's Method Meets Statistics: A Review of Some Recent Developments.* 2021.
+   Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
+   Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
+   Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert,
+   Yvik Swan.

-2. *Stein Variational Gradient Descent: A General-Purpose Bayesian Inference Algorithm* (2016)
-   Qiang Liu, Dilin Wang. NeurIPS
+2. *Stein Variational Gradient Descent: A General-Purpose Bayesian Inference Algorithm.* 2016.
+   Qiang Liu, Dilin Wang. NeurIPS

-3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019)
-   Dilin Wang, Qiang Liu. PMLR
+3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models.* 2019.
+   Dilin Wang, Qiang Liu. PMLR

+4. *ELBOing Stein: Variational Bayes with Stein Mixture Inference.* 2024.
+   Ola Rønning, Eric Nalisnick, Christophe Ley, Padhraic Smyth, and Thomas Hamelryck. arXiv:2410.22948.

SteinVI Interface
-----------------
.. autoclass:: numpyro.contrib.einstein.steinvi.SteinVI
.. autoclass:: numpyro.contrib.einstein.steinvi.SVGD
.. autoclass:: numpyro.contrib.einstein.steinvi.ASVGD

SteinVI Kernels
---------------
20 changes: 11 additions & 9 deletions numpyro/contrib/einstein/steinvi.py
@@ -569,7 +569,7 @@ class SVGD(SteinVI):
>>> svgd_result = svgd.run(random.PRNGKey(0), 200, data)

>>> params = svgd_result.params
->>> predictive = Predictive(model, svgd.guide, params, num_samples=10, batch_ndims=1)
+>>> predictive = Predictive(model, guide=svgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)

:param Callable model: Python callable with NumPyro primitives for the model.
@@ -581,8 +581,9 @@
kernel selection are not well understood yet.
:param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation.
Default is 10.
-:param Dict guide_kwargs: Keyword arguments for `~numpyro.infer.autoguide.AutoDelta`.
-    Default behaviour is the same as the default for `~numpyro.infer.autoguide.AutoDelta`.
+:param Dict guide_kwargs: Keyword arguments for :class:`~numpyro.infer.autoguide.AutoDelta`.
+    Default behaviour is the same as the default for :class:`~numpyro.infer.autoguide.AutoDelta`.
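
For instance, a hedged sketch of forwarding an initialization strategy through ``guide_kwargs`` (assuming :class:`~numpyro.infer.autoguide.AutoDelta` accepts ``init_loc_fn``)::

    from numpyro.infer.initialization import init_to_uniform

    # Reusing ``model``, ``Adagrad`` and ``RBFKernel`` from the example above.
    svgd = SVGD(model, Adagrad(step_size=0.05), RBFKernel(),
                guide_kwargs={"init_loc_fn": init_to_uniform})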

Usage::

opt = Adagrad(step_size=0.05)
@@ -639,7 +640,7 @@ class ASVGD(SVGD):
>>> from numpyro.distributions.constraints import positive

>>> from numpyro.optim import Adagrad
->>> from numpyro.contrib.einstein import SVGD, RBFKernel
+>>> from numpyro.contrib.einstein import ASVGD, RBFKernel
>>> from numpyro.infer import Predictive

>>> def model(data):
@@ -657,7 +658,7 @@
>>> asvgd_result = asvgd.run(random.PRNGKey(0), 200, data)

>>> params = asvgd_result.params
->>> predictive = Predictive(model, asvgd.guide, params, num_samples=10, batch_ndims=1)
+>>> predictive = Predictive(model, guide=asvgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)

:param Callable model: Python callable with NumPyro primitives for the model.
Expand All @@ -669,12 +670,13 @@ class ASVGD(SVGD):
This may change as criteria for kernel selection are not well understood yet.
:param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation.
Default is `10`.
-:param num_cycles: The total number of cycles during inference. This corresponds to $C$ in eq. 4 of [1].
+:param num_cycles: The total number of cycles during inference. This corresponds to :math:`C` in eq. 4 of [1].
Default is `10`.
-:param trans_speed: Speed of transition between two phases during inference. This corresponds to :math:`p` in eq. 4
+:param trans_speed: Speed of transition between two phases during inference. This corresponds to :math:`p` in eq. 4
of [1]. Default is `10`.
-:param Dict guide_kwargs: Keyword arguments for `~numpyro.infer.autoguide.AutoDelta`.
-    Default behaviour is the same as the default for `~numpyro.infer.autoguide.AutoDelta`.
+:param Dict guide_kwargs: Keyword arguments for :class:`~numpyro.infer.autoguide.AutoDelta`.
+    Default behaviour is the same as the default for :class:`~numpyro.infer.autoguide.AutoDelta`.
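
For instance, a sketch of a tighter annealing schedule (hedged: the values are illustrative, not tuned, and the constructor arguments are assumed to match the parameters above)::

    # Reusing ``model``, ``Adagrad`` and ``RBFKernel`` from the example above.
    asvgd = ASVGD(model, Adagrad(step_size=0.05), RBFKernel(),
                  num_cycles=5, trans_speed=20)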

Usage::

opt = Adagrad(step_size=0.05)