diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst index d1eb42836..88a751816 100644 --- a/docs/source/contrib.rst +++ b/docs/source/contrib.rst @@ -43,11 +43,10 @@ The framework currently supports several kernels, including: - `RandomFeatureKernel` - `MixtureKernel` - `GraphicalKernel` -- `ProbabilityProductKernel` For example, usage see: -- The `Bayesian neural network example `_ +- The `Bayesian neural network example `_. **References** @@ -55,9 +54,9 @@ For example, usage see: Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner, Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton, Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert, -Yvik Swan. https://arxiv.org/abs/2105.03481 +Yvik Swan. -2. *Stein variational gradient descent: A general-purpose Bayesian inference algorithm* (2016) +2. *Stein Variational Gradient Descent: A General-Purpose Bayesian Inference Algorithm* (2016) Qiang Liu, Dilin Wang. NeurIPS 3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019) diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py index 533291a45..952b0be49 100644 --- a/examples/stein_bnn.py +++ b/examples/stein_bnn.py @@ -19,14 +19,10 @@ import numpy as np from sklearn.model_selection import train_test_split -import jax -from jax import random -import jax.numpy as jnp - -import numpyro -from numpyro import deterministic -from numpyro.contrib.einstein import IMQKernel, SteinVI -from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive +from jax import config, nn, numpy as jnp, random + +from numpyro import deterministic, plate, sample, set_platform, subsample +from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI from numpyro.distributions import Gamma, Normal from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset from numpyro.infer import init_to_uniform @@ -54,23 +50,23 @@ def normalize(val, mean=None, std=None): return (val - mean) / std, mean, std -def model(x, y=None, hidden_dim=50, subsample_size=100): +def model(x, y=None, hidden_dim=50, sub_size=100): """BNN described in section 5 of [1]. **References:** - 1. *Stein variational gradient descent: A general purpose bayesian inference algorithm* + 1. *Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm* Qiang Liu and Dilin Wang (2016). 
""" - prec_nn = numpyro.sample( + prec_nn = sample( "prec_nn", Gamma(1.0, 0.1) ) # hyper prior for precision of nn weights and biases n, m = x.shape - with numpyro.plate("l1_hidden", hidden_dim, dim=-1): + with plate("l1_hidden", hidden_dim, dim=-1): # prior l1 bias term - b1 = numpyro.sample( + b1 = sample( "nn_b1", Normal( 0.0, @@ -79,38 +75,33 @@ def model(x, y=None, hidden_dim=50, subsample_size=100): ) assert b1.shape == (hidden_dim,) - with numpyro.plate("l1_feat", m, dim=-2): - w1 = numpyro.sample( + with plate("l1_feat", m, dim=-2): + w1 = sample( "nn_w1", Normal(0.0, 1.0 / jnp.sqrt(prec_nn)) ) # prior on l1 weights assert w1.shape == (m, hidden_dim) - with numpyro.plate("l2_hidden", hidden_dim, dim=-1): - w2 = numpyro.sample( + with plate("l2_hidden", hidden_dim, dim=-1): + w2 = sample( "nn_w2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn)) ) # prior on output weights - b2 = numpyro.sample( + b2 = sample( "nn_b2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn)) ) # prior on output bias term # precision prior on observations - prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1)) - with numpyro.plate( - "data", - x.shape[0], - subsample_size=subsample_size, - dim=-1, - ): - batch_x = numpyro.subsample(x, event_dim=1) + prec_obs = sample("prec_obs", Gamma(1.0, 0.1)) + with plate("data", x.shape[0], subsample_size=sub_size, dim=-1): + batch_x = subsample(x, event_dim=1) if y is not None: - batch_y = numpyro.subsample(y, event_dim=0) + batch_y = subsample(y, event_dim=0) else: batch_y = y - loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2) + loc_y = deterministic("y_pred", nn.relu(batch_x @ w1 + b1) @ w2 + b2) - numpyro.sample( + sample( "y", Normal( loc_y, 1.0 / jnp.sqrt(prec_obs) @@ -123,34 +114,33 @@ def main(args): data = load_data() inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3) - # normalize data and labels to zero mean unit variance! + # Normalize features to zero mean unit variance. x, xtr_mean, xtr_std = normalize(data.xtr) - y, ytr_mean, ytr_std = normalize(data.ytr) rng_key, inf_key = random.split(inf_key) + # We find that SteinVI benefits from a small radius when inferring BNNs. guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1)) stein = SteinVI( model, guide, - Adagrad(0.05), - IMQKernel(), - # ProbabilityProductKernel(guide=guide, scale=1.), + Adagrad(0.5), + RBFKernel(), repulsion_temperature=args.repulsion, num_stein_particles=args.num_stein_particles, num_elbo_particles=args.num_elbo_particles, ) start = time() - # use keyword params for static (shape etc.)! + # Use keyword params for static (shape etc.) result = stein.run( rng_key, args.max_iter, x, - y, + data.ytr, hidden_dim=args.hidden_dim, - subsample_size=args.subsample_size, + sub_size=args.subsample_size, progress_bar=args.progress_bar, ) time_taken = time() - start @@ -164,39 +154,36 @@ def main(args): ) xte, _, _ = normalize( data.xte, xtr_mean, xtr_std - ) # use train data statistics when accessing generalization - preds = pred( - pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim - )["y_pred"] + ) # Use train data statistics when accessing generalization. 
+ n = xte.shape[0] + y_preds = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y_pred"] - y_pred = preds * ytr_std + ytr_mean - rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2)) + mean_pred = y_preds.mean(0) + rmse = jnp.sqrt(jnp.mean((mean_pred - data.yte) ** 2)) print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}") print(rf"RMSE: {rmse:.2f}") # compute mean prediction and confidence interval around median - mean_prediction = y_pred.mean(0) - - ran = np.arange(mean_prediction.shape[0]) - percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0) + percentiles = jnp.percentile(y_preds, jnp.array([5.0, 95.0]), axis=0) # make plots fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True) + ran = np.arange(mean_pred.shape[0]) ax.add_collection( LineCollection( zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors="lightblue" ) ) ax.plot(data.yte, "kx", label="y true") - ax.plot(mean_prediction, "ko", label="y pred") + ax.plot(mean_pred, "ko", label="y pred") ax.set(xlabel="example", ylabel="y", title="Mean predictions with 90% CI") ax.legend() fig.savefig("stein_bnn.pdf") if __name__ == "__main__": - jax.config.update("jax_debug_nans", True) + config.update("jax_debug_nans", True) parser = argparse.ArgumentParser() parser.add_argument("--subsample-size", type=int, default=100) @@ -212,6 +199,6 @@ def main(args): args = parser.parse_args() - numpyro.set_platform(args.device) + set_platform(args.device) main(args) diff --git a/numpyro/contrib/einstein/__init__.py b/numpyro/contrib/einstein/__init__.py index 57990a488..1b0c7b9ca 100644 --- a/numpyro/contrib/einstein/__init__.py +++ b/numpyro/contrib/einstein/__init__.py @@ -12,17 +12,19 @@ RBFKernel, ) from numpyro.contrib.einstein.stein_loss import SteinLoss -from numpyro.contrib.einstein.steinvi import SteinVI +from numpyro.contrib.einstein.steinvi import ASVGD, SVGD, SteinVI __all__ = [ - "SteinVI", - "SteinLoss", - "RBFKernel", + "ASVGD", + "GraphicalKernel", "IMQKernel", "LinearKernel", - "RandomFeatureKernel", - "GraphicalKernel", + "MixtureGuidePredictive", "MixtureKernel", + "RandomFeatureKernel", + "RBFKernel", "ProbabilityProductKernel", - "MixtureGuidePredictive", + "SVGD", + "SteinVI", + "SteinLoss", ] diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 2a2a8ed51..af67a9331 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -4,21 +4,26 @@ from collections.abc import Callable, Sequence from functools import partial from typing import Optional +import warnings -import jax -from jax import numpy as jnp, random, vmap +from jax import numpy as jnp, random, tree, vmap from numpyro.handlers import substitute from numpyro.infer import Predictive +from numpyro.infer.autoguide import AutoGuide from numpyro.infer.util import _predictive +from numpyro.util import find_stack_level class MixtureGuidePredictive: """(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for - :class:`numpyro.contrib.einstein.steinvi.SteinVi` + :class:`numpyro.contrib.einstein.steinvi.SteinVi`. .. Note:: For single mixture component use numpyro.infer.Predictive. + .. Note:: For :class:`numpyro.contrib.einstein.steinvi.SVGD` and :class:`numpyro.contrib.einstein.steinvi.ASVGD` use + :class:`numpyro.infer.util.Predictive`. + .. 
warning:: The `MixtureGuidePredictive` is experimental and will likely be replaced by :class:`numpyro.infer.util.Predictive` in the future. @@ -44,6 +49,14 @@ def __init__( return_sites: Optional[Sequence[str]] = None, mixture_assignment_sitename="mixture_assignments", ): + if isinstance(guide, AutoGuide): + guide_name = guide.__class__.__name__ + if guide_name == "AutoDelta": + warnings.warn( + "Use numpyro.inter.Predictive with `batch_ndims=1` for ASVGD and SVGD.", + stacklevel=find_stack_level(), + ) + self.model_predictive = partial( Predictive, model=model, @@ -63,7 +76,7 @@ def __init__( self.guide = guide self.return_sites = return_sites - self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0] + self.num_mixture_components = jnp.shape(tree.flatten(params)[0][0])[0] self.mixture_assignment_sitename = mixture_assignment_sitename def _call_with_params(self, rng_key, params, args, kwargs): @@ -99,7 +112,7 @@ def __call__(self, rng_key, *args, **kwargs): minval=0, maxval=self.num_mixture_components, ) - predictive_assign = jax.tree.map( + predictive_assign = tree.map( lambda arr: vmap(lambda i, assign: arr[i, assign])( jnp.arange(self._batch_shape[0]), assigns ), diff --git a/numpyro/contrib/einstein/stein_kernels.py b/numpyro/contrib/einstein/stein_kernels.py index d7607e3cb..330d901cd 100644 --- a/numpyro/contrib/einstein/stein_kernels.py +++ b/numpyro/contrib/einstein/stein_kernels.py @@ -29,12 +29,12 @@ def mode(self): @abstractmethod def compute( self, + rng_key, particles: jnp.ndarray, particle_info: dict[str, tuple[int, int]], loss_fn: Callable[[jnp.ndarray], float], ): - """ - Computes the kernel function given the input Stein particles + """Computes the kernel function given the input Stein particles :param particles: The Stein particles to compute the kernel from :param particle_info: A mapping from parameter names to the position in the @@ -56,15 +56,15 @@ def init(self, rng_key, particles_shape): class RBFKernel(SteinKernel): - """ - Calculates the Gaussian RBF kernel function, from [1], - :math:`k(x,y) = \\exp(\\frac{1}{h} \\|x-y\\|^2)`, - where the bandwidth h is computed using the median heuristic - :math:`h = \\frac{1}{\\log(n)} \\text{med}(\\|x-y\\|)`. + """Calculates the Gaussian RBF kernel function used in [1]. The kernel is given by - **References:** + :math:`k(x,y) = \\exp(\\frac{-1}{h} \\|x-y\\|^2)`, + + where the bandwidth :math:`h` is computed using the median heuristic + + :math:`h = \\frac{1}{\\log(m)} \\text{med}(\\|x-y\\|)`. - 1. *Stein Variational Gradient Descent* by Liu and Wang + In the above :math:`m` is the number of particles. :param str mode: Either 'norm' (default) specifying to take the norm of each particle, 'vector' to return a component-wise kernel or 'matrix' to return a @@ -73,6 +73,11 @@ class RBFKernel(SteinKernel): norm kernel or 'vector_diag' for diagonal of vector-valued kernel :param bandwidth_factor: A multiplier to the bandwidth based on data size n (default 1/log(n)) + + **References:** + + 1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm." + Advances in neural information processing systems 29 (2016). 
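    Example (a minimal sketch mirroring the updated kernel tests; the particle
    values are illustrative, and ``particle_info``/``loss_fn`` are unused by this
    kernel but are part of the ``compute`` interface, which now takes an
    ``rng_key`` as its first argument)::

        from jax import numpy as jnp, random
        from numpyro.contrib.einstein import RBFKernel

        particles = jnp.array([[1.0, 2.0], [10.0, 5.0], [7.0, 3.0], [2.0, -1.0]])
        kernel = RBFKernel(mode="norm")
        kernel.init(random.PRNGKey(0), particles.shape)
        kernel_fn = kernel.compute(random.PRNGKey(0), particles, {}, lambda x: x)
        kernel_fn(particles[0], particles[1])  # scalar kernel value k(x, y)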
""" def __init__( @@ -92,7 +97,7 @@ def _normed(self): self.mode == "matrix" and self.matrix_mode == "norm_diag" ) - def compute(self, particles, particle_info, loss_fn): + def compute(self, rng_key, particles, particle_info, loss_fn): bandwidth = median_bandwidth(particles, self.bandwidth_factor) def kernel(x, y): @@ -114,19 +119,22 @@ def mode(self): class IMQKernel(SteinKernel): - """ - Calculates the IMQ kernel - :math:`k(x,y) = (c^2 + \\|x+y\\|^2_2)^{\\beta},` - from [1]. + """Calculates the IMQ kernel from Theorem 8 of [1]. The kernel is given by - **References:** + :math:`k(x,y) = (c^2 + \\|x-y\\|^2_2)^{\\beta},` + + where :math:`c\\in \\mathcal\\{R\\}` and :math:`\\beta \\in (-1,0)`. - 1. *Measuring Sample Quality with Kernels* by Gorham and Mackey :param str mode: Either 'norm' (default) specifying to take the norm of each particle, or 'vector' to return a component-wise kernel :param float const: Positive multi-quadratic constant (c) :param float expon: Inverse exponent (beta) between (-1, 0) + + **References:** + + 1. Gorham, Jackson, and Lester Mackey. "Measuring Sample Quality with Kernels." + International Conference on Machine Learning. PMLR, 2017. """ def __init__(self, mode="norm", const=1.0, expon=-0.5): @@ -144,7 +152,7 @@ def mode(self): def _normed(self): return self._mode == "norm" - def compute(self, particles, particle_info, loss_fn): + def compute(self, rng_key, particles, particle_info, loss_fn): def kernel(x, y): reduce = jnp.sum if self._normed() else lambda x: x return (self.const**2 + reduce((x - y) ** 2)) ** self.expon @@ -154,13 +162,14 @@ def kernel(x, y): class LinearKernel(SteinKernel): """ - Calculates the linear kernel - :math:`k(x,y) = x \\cdot y + 1` - from [1]. + Calculates the linear kernel from Theorem 3.3 in [1]. The kernel is given by + + :math:`k(x,y) = x^T y + 1`. **References:** - 1. *Stein Variational Gradient Descent as Moment Matching* by Liu and Wang + 1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent as Moment Matching." + Advances in Neural Information Processing Systems 31 (2018). """ def __init__(self, mode="norm"): @@ -171,7 +180,7 @@ def __init__(self, mode="norm"): def mode(self): return self._mode - def compute(self, particles: jnp.ndarray, particle_info, loss_fn): + def compute(self, rng_key, particles: jnp.ndarray, particle_info, loss_fn): def kernel(x, y): if x.ndim == 1: return x @ y + 1 @@ -182,20 +191,27 @@ def kernel(x, y): class RandomFeatureKernel(SteinKernel): - """ - Calculates the random kernel - :math:`k(x,y)= 1/m\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l)` - from [1]. + """Calculates the Gaussian variate of random kernel in eq. 5 and 6 of [1]. The kernel is given by - **References:** + :math:`k(x,y)= \\frac{1}{m}\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l)`, + + where :math:`\\phi(\\cdot, w)` are the Gaussian random feature maps in eq. 6. The maps are given by - 1. *Stein Variational Gradient Descent as Moment Matching* by Liu and Wang + :math:`\\phi(z, w) = \\sqrt{2}\\left(h^{-1}w_1^Tz + w_0\\right)`, + + where :math:`h` is the bandwidth computed using the median trick from :class:`~numpyro.constrib.einstein.RBFKernel`, + :math:`w_0\\sim\\text{Uni}([0,2\\pi])` and :math:`w_1\\sim\\mathcal{N}(0,I)`. :param bandwidth_subset: How many particles should be used to calculate the bandwidth? (default None, meaning all particles) :param random_indices: The set of indices which to do random feature expansion on. 
(default None, meaning all indices) :param bandwidth_factor: A multiplier to the bandwidth based on data size n (default 1/log(n)) + + **References:** + + 1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent as Moment Matching." + Advances in Neural Information Processing Systems 31 (2018). """ def __init__( @@ -229,7 +245,7 @@ def init(self, rng_key, particles_shape): rng_key, particles_shape[0], (self.bandwidth_subset,) ) - def compute(self, particles, particle_info, loss_fn): + def compute(self, rng_key, particles, particle_info, loss_fn): if self._random_weights is None: raise RuntimeError( "The `.init` method should be called first to initialize the" @@ -267,16 +283,19 @@ def kernel(x, y): class MixtureKernel(SteinKernel): - """ - Calculates a mixture of multiple kernels - :math:`k(x,y) = \\sum_i w_ik_i(x,y)` + """Calculates a mixture of multiple kernels from eq. 1 of [1]. The kernel is given by - **References:** + :math:`k(x,y) = \\sum_i w_ik_i(x,y)`, + + where :math:`k_i` is a reproducing kernel and :math:`w_i \\in (0,\\infty)`. - 1. *Stein Variational Gradient Descent as Moment Matching* by Liu and Wang + :param ws: Weight of each kernel in the mixture. + :param kernel_fns: Different kernel functions to mix together. - :param ws: Weight of each kernel in the mixture - :param kernel_fns: Different kernel functions to mix together + **References:** + + 1. Ai, Qingzhong, et al. "Stein variational gradient descent with multiple kernels." + Cognitive Computation 15.2 (2023): 672-682. """ def __init__(self, ws: list[float], kernel_fns: list[SteinKernel], mode="norm"): @@ -290,9 +309,10 @@ def __init__(self, ws: list[float], kernel_fns: list[SteinKernel], mode="norm"): def mode(self): return self.kernel_fns[0].mode - def compute(self, particles, particle_info, loss_fn): + def compute(self, rng_key, particles, particle_info, loss_fn): kernels = [ - kf.compute(particles, particle_info, loss_fn) for kf in self.kernel_fns + kf.compute(rng_key, particles, particle_info, loss_fn) + for kf in self.kernel_fns ] def kernel(x, y): @@ -310,19 +330,22 @@ def init(self, rng_key, particles_shape): class GraphicalKernel(SteinKernel): - """ - Calculates graphical kernel :math:`k(x,y) = diag({K_l(x_l,y_l)})` for local kernels - :math:`K_l` from [1][2]. + """Calculates the graphical kernel, also called the coordinate-wise kernel, from Theorem 1 in [1]. + The kernel is given by - **References:** + :math:`k(x,y) = diag({k_l(x_l,y_l)})`, - 1. *Stein Variational Message Passing for Continuous Graphical Models* by Wang, Zheng, and Liu - 2. *Stein Variational Gradient Descent with Matrix-Valued Kernels* by Wang, Tang, Bajaj, and Liu + for coordinate-wise kernels :math:`k_l`. :param local_kernel_fns: A mapping between parameters and a choice of kernel function for that parameter (default to default_kernel_fn for each parameter) :param default_kernel_fn: The default choice of kernel function when none is specified for a particular parameter + + **References:** + + 1. Wang, Dilin, Zhe Zeng, and Qiang Liu. "Stein variational message passing for continuous graphical models." + International Conference on Machine Learning. PMLR, 2018. 
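    Example (an illustrative sketch following the updated tests; the parameter
    name ``p1``, its local kernel choice, and the slice mentioned below are
    assumptions for demonstration)::

        from numpyro.contrib.einstein import GraphicalKernel, RBFKernel

        # Use a norm-mode RBF kernel for the parameter "p1"; any other parameter
        # falls back to the default kernel. The `particle_info` passed to
        # `compute` maps each parameter name to its (start, end) slice in the
        # flattened particles, e.g. {"p1": (0, 2)} for a 2-dimensional "p1".
        kernel = GraphicalKernel(
            mode="matrix", local_kernel_fns={"p1": RBFKernel("norm")}
        )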
""" def __init__( @@ -340,7 +363,7 @@ def __init__( def mode(self): return "matrix" - def compute(self, particles, particle_info, loss_fn): + def compute(self, rng_key, particles, particle_info, loss_fn): def pk_loss_fn(start, end): def fn(ps): return loss_fn( @@ -352,9 +375,11 @@ def fn(ps): return fn local_kernels = [] - for pk, (start_idx, end_idx) in particle_info.items(): + keys = random.split(rng_key, len(particle_info)) + for key, (pk, (start_idx, end_idx)) in zip(keys, particle_info.items()): pk_kernel_fn = self.local_kernel_fns.get(pk, self.default_kernel_fn) pk_kernel = pk_kernel_fn.compute( + key, particles[:, start_idx:end_idx], {pk: (0, end_idx - start_idx)}, pk_loss_fn(start_idx, end_idx), @@ -376,7 +401,16 @@ def kernel(x, y): class ProbabilityProductKernel(SteinKernel): - def __init__(self, guide, scale=1.0): + """**EXPERIMENTAL** Compute the unormalized probability product kernel for Gaussians given by eq. 5 in [1]. + + **References**: + + 1. Jebara, Tony, Risi Kondor, and Andrew Howard. "Probability product kernels." + The Journal of Machine Learning Research 5 (2004): 819-844. + """ + + def __init__(self, guide, scale=1.0, mode="norm"): + assert mode == "norm" self._mode = "norm" self.guide = guide self.scale = scale @@ -384,6 +418,7 @@ def __init__(self, guide, scale=1.0): def compute( self, + rng_key, particles: jnp.ndarray, particle_info: dict[str, tuple[int, int]], loss_fn: Callable[[jnp.ndarray], float], @@ -407,27 +442,19 @@ def kernel(x, y): biject = biject_to(self.guide.scale_constraint) x_loc = x[loc_idx] x_scale = biject(x[scale_idx]) - x_quad = (x_loc / x_scale) ** 2 + x_quad = ((x_loc / x_scale) ** 2).sum() y_loc = y[loc_idx] y_scale = biject(y[scale_idx]) - y_quad = (y_loc / y_scale) ** 2 + y_quad = ((y_loc / y_scale) ** 2).sum() cross_loc = x_loc * x_scale**-2 + y_loc * y_scale**-2 cross_var = 1 / (y_scale**-2 + x_scale**-2) - cross_quad = cross_loc**2 * cross_var + cross_quad = (cross_loc**2 * cross_var).sum() quad = jnp.exp(-self.scale / 2 * (x_quad + y_quad - cross_quad)) - norm = ( - (2 * jnp.pi) ** ((1 - 2 * self.scale) * 1 / 2) - * self.scale ** (-1 / 2) - * cross_var ** (1 / 2) - * x_scale ** (-self.scale) - * y_scale ** (-self.scale) - ) - - return jnp.linalg.norm(norm * quad) + return quad return kernel diff --git a/numpyro/contrib/einstein/stein_loss.py b/numpyro/contrib/einstein/stein_loss.py index 31f69424e..ddc1db26a 100644 --- a/numpyro/contrib/einstein/stein_loss.py +++ b/numpyro/contrib/einstein/stein_loss.py @@ -15,7 +15,7 @@ def __init__(self, elbo_num_particles=1, stein_num_particles=1): self.elbo_num_particles = elbo_num_particles self.stein_num_particles = stein_num_particles - def single_particle_loss( + def particle_loss( self, rng_key, model, @@ -28,46 +28,51 @@ def single_particle_loss( model_kwargs, param_map, ): - guide_key, model_key = random.split(rng_key, 2) + def single_draw_elbo(rng_key): + guide_key, model_key = random.split(rng_key, 2) - # 2. Draw from selected mixture component - guide_keys = random.split(guide_key, self.stein_num_particles) + # 2. Draw from selected mixture component + guide_keys = random.split(guide_key, self.stein_num_particles) + _, tri = log_density( + seed(guide, guide_keys[select_index]), + model_args, + model_kwargs, + {**param_map, **selected_particle}, + ) - seeded_chosen = seed(guide, guide_keys[select_index]) - log_chosen_density, chosen_trace = log_density( - seeded_chosen, model_args, model_kwargs, {**param_map, **selected_particle} - ) + # 3. 
Score mixture guide + def ldj(pj): + ld, trj = log_density( + replay(guide, tri), + model_args, + model_kwargs, + {**param_map, **unravel_pytree(pj)}, + ) + # Validate + check_model_guide_match(trj, tri) + return ld - # 3. Score mixture guide - def log_component_density(i): - log_cdensity, component_trace = log_density( - replay(seed(guide, guide_key[i]), chosen_trace), + ldg = logsumexp(vmap(ldj)(flat_particles)) - jnp.log( + self.stein_num_particles + ) + + # 4. Score model + ldm, mtr = log_density( + replay(seed(model, model_key), tri), model_args, model_kwargs, - {**param_map, **unravel_pytree(flat_particles[i])}, + {**param_map, **selected_particle}, ) - # Validate - check_model_guide_match(component_trace, chosen_trace) - return log_cdensity - - log_guide_density = logsumexp( - vmap(log_component_density)(jnp.arange(self.stein_num_particles)) - ) - # 4. Score model - seeded_model = seed(model, model_key) - log_model_density, model_trace = log_density( - replay(seeded_model, chosen_trace), - model_args, - model_kwargs, - {**param_map, **selected_particle}, - ) + # Validation + check_model_guide_match(mtr, tri) + _validate_model(mtr, plate_warning="loose") + elbo = ldm - ldg + return elbo - # Validation - check_model_guide_match(model_trace, chosen_trace) - _validate_model(model_trace, plate_warning="loose") - elbo = log_model_density - log_guide_density - return elbo + return vmap(single_draw_elbo)( + random.split(rng_key, self.elbo_num_particles) + ).mean() def loss(self, rng_key, param_map, model, guide, particles, *args, **kwargs): if not particles: @@ -84,7 +89,7 @@ def loss(self, rng_key, param_map, model, guide, particles, *args, **kwargs): ) score_keys = random.split(score_key, self.elbo_num_particles) elbos = vmap( - lambda key, assign: self.single_particle_loss( + lambda key, assign: self.particle_loss( rng_key=key, model=model, guide=guide, diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 98c055db1..5683be727 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -2,30 +2,24 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple -from collections.abc import Callable from copy import deepcopy import functools from functools import partial from itertools import chain import operator -import jax -from jax import grad, jacfwd, numpy as jnp, random, vmap +from jax import grad, numpy as jnp, random, tree, vmap from jax.flatten_util import ravel_pytree from numpyro import handlers -from numpyro.contrib.einstein.stein_kernels import SteinKernel from numpyro.contrib.einstein.stein_loss import SteinLoss from numpyro.contrib.einstein.stein_util import ( batch_ravel_pytree, get_parameter_transform, ) -from numpyro.contrib.funsor import config_enumerate, enum from numpyro.distributions import Distribution -from numpyro.distributions.transforms import IdentityTransform -from numpyro.infer.autoguide import AutoGuide -from numpyro.infer.util import _guess_max_plate_nesting, transform_fn -from numpyro.optim import _NumPyroOptim +from numpyro.infer.autoguide import AutoDelta, AutoGuide +from numpyro.infer.util import transform_fn from numpyro.util import fori_collect SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"]) @@ -37,69 +31,89 @@ def _numel(shape): class SteinVI: - """Variational inference with Stein mixtures. + """Variational inference with Stein mixtures inference. **Example:** .. 
doctest:: - >>> from jax import random - >>> import jax.numpy as jnp - >>> import numpyro - >>> import numpyro.distributions as dist - >>> from numpyro.distributions import constraints + >>> from jax import random, numpy as jnp + + >>> from numpyro import sample, param, plate + >>> from numpyro.distributions import Beta, Bernoulli + >>> from numpyro.distributions.constraints import positive + + >>> from numpyro.optim import Adagrad >>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel >>> def model(data): - ... f = numpyro.sample("latent_fairness", dist.Beta(10, 10)) - ... with numpyro.plate("N", data.shape[0] if data is not None else 10): - ... numpyro.sample("obs", dist.Bernoulli(f), obs=data) + ... f = sample("fairness", Beta(10, 10)) + ... n = data.shape[0] if data is not None else 1 + ... with plate("N", n): + ... sample("obs", Bernoulli(f), obs=data) >>> def guide(data): - ... alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive) - ... beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key), - ... constraint=constraints.positive) - ... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) + ... # Initialize all particles in the same point. + ... alpha_q = param("alpha_q", 15., constraint=positive) + ... # Initialize particles by sampling an Exponential distribution. + ... beta_q = param("beta_q", + ... lambda rng_key: random.exponential(rng_key), + ... constraint=positive) + ... sample("fairness", Beta(alpha_q, beta_q)) >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) - >>> optimizer = numpyro.optim.Adam(step_size=0.0005) - >>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel()) - >>> stein_result = stein.run(random.PRNGKey(0), 2000, data) + + >>> opt = Adagrad(step_size=0.05) + >>> k = RBFKernel() + >>> stein = SteinVI(model, guide, opt, k, num_stein_particles=2) + + >>> stein_result = stein.run(random.PRNGKey(0), 200, data) >>> params = stein_result.params - >>> # use guide to make predictive - >>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites) + + >>> # Use guide to make predictions. + >>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=10, guide_sites=stein.guide_sites) >>> samples = predictive(random.PRNGKey(1), data=None) - :param Callable model: Python callable with Pyro primitives for the model. - :param guide: Python callable with Pyro primitives for the guide - (recognition network). + :param Callable model: Python callable with NumPyro primitives for the model. + :param Callable guide: Python callable with NumPyro primitives for the guide. :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumpyroOptim`. - :param SteinKernel kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein mixture - inference. - :param num_stein_particles: Number of particles (i.e., mixture components) in the Stein mixture. + Adagrad should be preferred over Adam [1]. + :param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with Stein mixture + inference. We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`. + This may change as criteria for kernel selection are not well understood yet. + :param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation. + Default is `10`. :param num_elbo_particles: Number of Monte Carlo draws used to approximate the attractive force gradient. 
- (More particles give better gradient approximations) - :param Float loss_temperature: Scaling factor of the attractive force. - :param Float repulsion_temperature: Scaling factor of the repulsive force (Non-linear Stein) - :param Callable non_mixture_guide_param_fn: predicate on names of parameters in guide which should be optimized - classically without Stein (E.g. parameters for large normal networks or other transformation) - :param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments that remain constant + More particles give better gradient approximations. Default is `10`. + :param Float loss_temperature: Scaling factor of the attractive force. Default is `1`. + :param Float repulsion_temperature: Scaling factor of the repulsive force [2]. + We recommend not scaling the repulsion. Default is `1`. + :param Callable non_mixture_guide_param_fn: Predicate on names of parameters in the guide which should be optimized + using one particle. This could be parameters for large normal networks or other transformation. + Default excludes all parameters from this option. + :param static_kwargs: Static keyword arguments for the model and guide. These arguments cannot change during inference. - """ + + **References:** (MLA style) + + 1. Liu, Chang, et al. "Understanding and Accelerating Particle-Based Variational Inference." + International Conference on Machine Learning. PMLR, 2019. + 2. Wang, Dilin, and Qiang Liu. "Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models." + International Conference on Machine Learning. PMLR, 2019. + """ # noqa: E501 def __init__( self, - model: Callable, - guide: Callable, - optim: _NumPyroOptim, - kernel_fn: SteinKernel, - num_stein_particles: int = 10, - num_elbo_particles: int = 10, - loss_temperature: float = 1.0, - repulsion_temperature: float = 1.0, - non_mixture_guide_params_fn: Callable[[str], bool] = lambda name: False, - enum=True, + model, + guide, + optim, + kernel_fn, + num_stein_particles=10, + num_elbo_particles=10, + loss_temperature=1.0, + repulsion_temperature=1.0, + non_mixture_guide_params_fn=lambda name: False, **static_kwargs, ): if isinstance(guide, AutoGuide): @@ -125,7 +139,7 @@ def __init__( init_fn_name = guide.init_loc_fn.func.__name__ if init_fn_name == "init_to_uniform": assert ( - guide.init_loc_fn.keywords.get("radius", None) != 0 + guide.init_loc_fn.keywords.get("radius", None) != 0.0 ), init_loc_error_message else: init_fn_name = guide.init_loc_fn.__name__ @@ -148,7 +162,6 @@ def __init__( self.num_stein_particles = num_stein_particles self.loss_temperature = loss_temperature self.repulsion_temperature = repulsion_temperature - self.enum = enum self.non_mixture_params_fn = non_mixture_guide_params_fn self.guide_sites = None self.constrain_fn = None @@ -237,58 +250,43 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): ) attractive_key, classic_key = random.split(rng_key) + def particle_transform_fn(particle): + params = unravel_pytree(particle) + ctparams = self.constrain_fn(self.particle_transform_fn(params)) + ctparticle, _ = ravel_pytree(ctparams) + return ctparticle + # 2. 
Calculate gradients for each particle - def kernel_particles_loss_fn( - rng_key, particles - ): # TODO: rewrite using def to utilize jax caching + def kernel_particles_loss_fn(rng_key, particles): particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles) grads = vmap( lambda i: grad( - lambda particle: ( - vmap( - lambda elbo_key: self.stein_loss.single_particle_loss( - rng_key=elbo_key, - model=handlers.scale( - self._inference_model, self.loss_temperature - ), - guide=self.guide, - selected_particle=unravel_pytree(particle), - unravel_pytree=unravel_pytree, - flat_particles=particles, - select_index=i, - model_args=args, - model_kwargs=kwargs, - param_map=self.constrain_fn(non_mixture_uparams), - ) - )( - random.split( - particle_keys[i], self.stein_loss.elbo_num_particles - ) - ) - ).mean() + lambda particle: self.stein_loss.particle_loss( + rng_key=particle_keys[i], + model=handlers.scale( + self._inference_model, self.loss_temperature + ), + guide=self.guide, + selected_particle=self.constrain_fn(unravel_pytree(particle)), + unravel_pytree=unravel_pytree, + flat_particles=vmap(particle_transform_fn)(particles), + select_index=i, + model_args=args, + model_kwargs=kwargs, + param_map=self.constrain_fn(non_mixture_uparams), + ) )(particles[i]) )(jnp.arange(self.stein_loss.stein_num_particles)) return grads - def particle_transform_fn(particle): - params = unravel_pytree(particle) - - tparams = self.particle_transform_fn(params) - ctparams = self.constrain_fn(tparams) - tparticle, _ = ravel_pytree(tparams) - ctparticle, _ = ravel_pytree(ctparams) - return tparticle, ctparticle - - # 2.1 Lift particles to constraint space - tstein_particles, ctstein_particles = vmap(particle_transform_fn)( - stein_particles - ) + # 2.1 Compute particle gradients (for attractive force) + particle_ljp_grads = kernel_particles_loss_fn(attractive_key, stein_particles) - # 2.2 Compute particle gradients (for attractive force) - particle_ljp_grads = kernel_particles_loss_fn(attractive_key, ctstein_particles) + # 2.3 Lift particles to constraint space + ctstein_particles = vmap(particle_transform_fn)(stein_particles) - # 2.2 Compute non-mixture parameter gradients + # 2.4 Compute non-mixture parameter gradients non_mixture_param_grads = grad( lambda cps: -self.stein_loss.loss( classic_key, @@ -302,8 +300,22 @@ def particle_transform_fn(particle): )(non_mixture_uparams) # 3. Calculate kernel of particles + def loss_fn(particle, i): + return self.stein_loss.particle_loss( + rng_key=rng_key, + model=handlers.scale(self._inference_model, self.loss_temperature), + guide=self.guide, + selected_particle=self.constrain_fn(unravel_pytree(particle)), + unravel_pytree=unravel_pytree, + flat_particles=ctstein_particles, + select_index=i, + model_args=args, + model_kwargs=kwargs, + param_map=self.constrain_fn(non_mixture_uparams), + ) + kernel = self.kernel_fn.compute( - stein_particles, particle_info, kernel_particles_loss_fn + rng_key, stein_particles, particle_info, loss_fn ) # 4. 
Calculate the attractive force and repulsive force on the particles @@ -311,59 +323,29 @@ def particle_transform_fn(particle): lambda y: jnp.sum( vmap( lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad) - )(tstein_particles, particle_ljp_grads), + )(stein_particles, particle_ljp_grads), axis=0, ) - )(tstein_particles) + )(stein_particles) + repulsive_force = vmap( - lambda y: jnp.sum( + lambda y: jnp.mean( vmap( lambda x: self.repulsion_temperature * self._kernel_grad(kernel, x, y) - )(tstein_particles), + )(stein_particles), axis=0, ) - )(tstein_particles) - - def single_particle_grad(particle, attr_forces, rep_forces): - def _nontrivial_jac(var_name, var): - if isinstance(self.particle_transforms[var_name], IdentityTransform): - return None - return jacfwd(self.particle_transforms[var_name].inv)(var) - - def _update_force(attr_force, rep_force, jac): - force = attr_force.reshape(-1) + rep_force.reshape(-1) - if jac is not None: - force = force @ jac.reshape( - (_numel(jac.shape[: len(jac.shape) // 2]), -1) - ) - return force.reshape(attr_force.shape) + )(stein_particles) - reparam_jac = { - name: jax.tree.map(lambda var: _nontrivial_jac(name, var), variables) - for name, variables in unravel_pytree(particle).items() - } - jac_params = jax.tree.map( - _update_force, - unravel_pytree(attr_forces), - unravel_pytree(rep_forces), - reparam_jac, - ) - jac_particle, _ = ravel_pytree(jac_params) - return jac_particle + # 6. Compute the stein force + particle_grads = attractive_force + repulsive_force - particle_grads = ( - vmap(single_particle_grad)( - stein_particles, attractive_force, repulsive_force - ) - / self.num_stein_particles - ) - - # 5. Decompose the monolithic particle forces back to concrete parameter values + # 7. Decompose the monolithic particle forces back to concrete parameter values stein_param_grads = unravel_pytree_batched(particle_grads) - # 6. Return loss and gradients (based on parameter forces) - res_grads = jax.tree.map( + # 8. Return loss and gradients (based on parameter forces) + res_grads = tree.map( lambda x: -x, {**non_mixture_param_grads, **stein_param_grads} ) return jnp.linalg.norm(particle_grads), res_grads @@ -371,10 +353,10 @@ def _update_force(attr_force, rep_force, jac): def init(self, rng_key, *args, **kwargs): """Register random variable transformations, constraints and determine initialize positions of the particles. - :param rng_key: Random number generator seed. - :param args: Arguments to the model / guide. - :param kwargs: Keyword arguments to the model / guide. - :return: initial :data:`SteinVIState` + :param jax.random.PRNGKey rng_key: Random number generator seed. + :param args: Positional arguments to the model and guide. + :param kwargs: Keyword arguments to the model and guide. + :return: Initial :data:`SteinVIState`. 
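        Example (a minimal manual-stepping sketch; :meth:`run` wraps essentially
        this loop, and ``model``, ``guide``, ``data``, ``Adagrad`` and
        ``RBFKernel`` are assumed to be defined as in the class-level example
        above)::

            stein = SteinVI(
                model, guide, Adagrad(step_size=0.05), RBFKernel(), num_stein_particles=2
            )
            state = stein.init(random.PRNGKey(0), data)
            for _ in range(10):
                state, loss = stein.update(state, data)
            params = stein.get_params(state)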
""" rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split( @@ -400,7 +382,6 @@ def init(self, rng_key, *args, **kwargs): inv_transforms = {} particle_transforms = {} guide_param_names = set() - should_enum = False for site in model_trace.values(): if ( "fn" in site @@ -409,9 +390,7 @@ def init(self, rng_key, *args, **kwargs): and isinstance(site["fn"], Distribution) and site["fn"].is_discrete ): - if site["fn"].has_enumerate_support and self.enum: - should_enum = True - else: + if site["fn"].has_enumerate_support: raise Exception( "Cannot enumerate model with discrete variables without enumerate support" ) @@ -421,22 +400,17 @@ def init(self, rng_key, *args, **kwargs): transform = get_parameter_transform(site) inv_transforms[site["name"]] = transform transforms[site["name"]] = transform.inv - particle_transforms[site["name"]] = site.get( - "particle_transform", IdentityTransform() - ) + particle_transforms[site["name"]] = transform if site["name"] in guide_init_params: pval = guide_init_params[site["name"]] if self.non_mixture_params_fn(site["name"]): - pval = jax.tree.map(lambda x: x[0], pval) + pval = tree.map(lambda x: x[0], pval) else: pval = site["value"] params[site["name"]] = transform.inv(pval) if site["name"] in guide_trace: guide_param_names.add(site["name"]) - if should_enum: - mpn = _guess_max_plate_nesting(model_trace) - self._inference_model = enum(config_enumerate(self.model), -mpn - 1) self.guide_sites = guide_param_names self.constrain_fn = partial(transform_fn, inv_transforms) self.uconstrain_fn = partial(transform_fn, transforms) @@ -455,23 +429,22 @@ def init(self, rng_key, *args, **kwargs): return SteinVIState(self.optim.init(params), rng_key) def get_params(self, state: SteinVIState): - """ - Gets values at `param` sites of the `model` and `guide`. - :param state: current state of the optimizer. + """Gets values at `param` sites of the `model` and `guide`. + + :param SteinVIState state: Current state of optimization. + :return: Constraint parameters (i.e., particles). """ params = self.constrain_fn(self.optim.get_params(state.optim_state)) return params - def update(self, state: SteinVIState, *args, **kwargs): - """ - Take a single step of Stein (possibly on a batch / minibatch of data), - using the optimizer. - :param state: current state of Stein. - :param args: arguments to the model / guide (these can possibly vary during - the course of fitting). - :param kwargs: keyword arguments to the model / guide (these can possibly vary - during the course of fitting). - :return: tuple of `(state, loss)`. + def update(self, state: SteinVIState, *args, **kwargs) -> SteinVIState: + """Take a single step of SteinVI using the optimizer. We recommend using + the run method instead of update. + + :param SteinVIState state: Current state of inference. + :param args: Position arguments to the model and guide. + :param kwargs: Keyword arguments to the model and guide. 
+ :return: next :data:`SteinVIState` """ rng_key, rng_key_mcmc, rng_key_step = random.split(state.rng_key, num=3) params = self.optim.get_params(state.optim_state) @@ -482,6 +455,33 @@ def update(self, state: SteinVIState, *args, **kwargs): optim_state = self.optim.update(grads, optim_state) return SteinVIState(optim_state, rng_key), loss_val + def setup_run(self, rng_key, num_steps, args, init_state, kwargs): + if init_state is None: + state = self.init(rng_key, *args, **kwargs) + else: + state = init_state + loss = self.evaluate(state, *args, **kwargs) + + info_init = (state, loss) + + def step(info): + state, loss = info + return self.update(state, *args, **kwargs) # uses closure! + + def collect(info): + _, loss = info + return loss + + def extract(info): + state, _ = info + return state + + def diagnostic(info): + _, loss = info + return f"Stein force {loss:.2f}." + + return step, diagnostic, collect, extract, info_init + def run( self, rng_key, @@ -489,41 +489,44 @@ def run( *args, progress_bar=True, init_state=None, - collect_fn=lambda val: val[1], # TODO: refactor **kwargs, ): - def bodyfn(_i, info): - body_state = info[0] - return (*self.update(body_state, *info[2:], **kwargs), *info[2:]) + """Run SteinVI inference. + + :param jax.random.PRNGKey rng_key: Random number generator seed. + :param int num_steps: Number of steps to optimize. + :param *args: Positional arguments to the model and guide. + :param bool progress_bar: Use a progress bar. Default is `True`. + Inference is faster with `False`. + :param SteinVIState init_state: Initial state of inference. + Default is ``None``, which will initialize using init before running inference. + :param **kwargs: Keyword arguments to the model and guide. + """ + step, diagnostic, collect, extract, init_info = self.setup_run( + rng_key, num_steps, args, init_state, kwargs + ) - if init_state is None: - state = self.init(rng_key, *args, **kwargs) - else: - state = init_state - loss = self.evaluate(state, *args, **kwargs) auxiliaries, last_res = fori_collect( 0, num_steps, - lambda info: bodyfn(0, info), - (state, loss, *args), + step, + init_info, progbar=progress_bar, - transform=collect_fn, + transform=collect, return_last_val=True, - diagnostics_fn=lambda state: f"norm Stein force: {state[1]:.3f}" - if progress_bar - else None, + diagnostics_fn=diagnostic if progress_bar else None, ) - state = last_res[0] + + state = extract(last_res) return SteinVIRunResult(self.get_params(state), state, auxiliaries) - def evaluate(self, state, *args, **kwargs): - """ - Take a single step of Stein (possibly on a batch / minibatch of data). - :param state: current state of Stein. - :param args: arguments to the model / guide (these can possibly vary during - the course of fitting). - :param kwargs: keyword arguments to the model / guide. - :return: normed stein force given the current parameter values (held within `state.optim_state`). + def evaluate(self, state: SteinVIState, *args, **kwargs): + """Take a single step of Stein (possibly on a batch / minibatch of data). + + :param SteinVIState state: Current state of inference. + :param args: Positional arguments to the model and guide. + :param kwargs: Keyword arguments to the model and guide. + :return: Normed Stein force given by :data:`SteinVIState`. 
""" # we split to have the same seed as `update_fn` given a state _, _, rng_key_eval = random.split(state.rng_key, num=3) @@ -532,3 +535,241 @@ def evaluate(self, state, *args, **kwargs): rng_key_eval, params, *args, **kwargs, **self.static_kwargs ) return normed_stein_force + + +class SVGD(SteinVI): + """Stein variational gradient descent [1]. + + **Example:** + + .. doctest:: + + >>> from jax import random, numpy as jnp + + >>> from numpyro import sample, param, plate + >>> from numpyro.distributions import Beta, Bernoulli + >>> from numpyro.distributions.constraints import positive + + >>> from numpyro.optim import Adagrad + >>> from numpyro.contrib.einstein import SVGD, RBFKernel + >>> from numpyro.infer import Predictive + + >>> def model(data): + ... f = sample("fairness", Beta(10, 10)) + ... n = data.shape[0] if data is not None else 1 + ... with plate("N", n): + ... sample("obs", Bernoulli(f), obs=data) + + >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) + + >>> opt = Adagrad(step_size=0.05) + >>> k = RBFKernel() + >>> svgd = SVGD(model, opt, k, num_stein_particles=2) + + >>> svgd_result = svgd.run(random.PRNGKey(0), 200, data) + + >>> params = svgd_result.params + >>> predictive = Predictive(model, svgd.guide, params, num_samples=10, batch_ndims=1) + >>> samples = predictive(random.PRNGKey(1), data=None) + + :param Callable model: Python callable with NumPyro primitives for the model. + :param Callable guide: Python callable with NumPyro primitives for the guide. + :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumpyroOptim`. + Adagrad should be preferred over Adam [1]. + :param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with SVGD. + We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`. This may change as criteria for + kernel selection are not well understood yet. + :param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation. + Default is 10. + :param Dict guide_kwargs: Keyword arguments for `~numpyro.infer.autoguide.AutoDelta`. + Default behaviour is the same as the default for `~numpyro.infer.autoguide.AutoDelta`. + Usage:: + + opt = Adagrad(step_size=0.05) + k = RBFKernel() + svgd = SVGD(model, opt, k, guide_kwargs={'init_loc_fn': partial(init_to_uniform, radius=0.1)}) + + :param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot + change during inference. + + **References:** (MLA style) + + 1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm." + Advances in neural information processing systems 29 (2016). + """ + + def __init__( + self, + model, + optim, + kernel_fn, + num_stein_particles=10, + guide_kwargs={}, + **static_kwargs, + ): + super().__init__( + model=model, + guide=AutoDelta(model, **guide_kwargs), + optim=optim, + kernel_fn=kernel_fn, + num_stein_particles=num_stein_particles, + # With a Delta guide we only need one draw + # per particle to get its contribution to the expectation. + num_elbo_particles=1, + loss_temperature=1.0 / float(num_stein_particles), + # For SVGD repulsion temperature != 1 changes the + # target posterior so we keep it fixed at 1. + repulsion_temperature=1.0, + non_mixture_guide_params_fn=lambda name: False, + **static_kwargs, + ) + + +class ASVGD(SVGD): + """Annealing Stein variational gradient descent [1]. + + **Example:** + + .. 
doctest:: + + >>> from jax import random, numpy as jnp + + >>> from numpyro import sample, param, plate + >>> from numpyro.distributions import Beta, Bernoulli + >>> from numpyro.distributions.constraints import positive + + >>> from numpyro.optim import Adagrad + >>> from numpyro.contrib.einstein import SVGD, RBFKernel + >>> from numpyro.infer import Predictive + + >>> def model(data): + ... f = sample("fairness", Beta(10, 10)) + ... n = data.shape[0] if data is not None else 1 + ... with plate("N", n): + ... sample("obs", Bernoulli(f), obs=data) + + >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) + + >>> opt = Adagrad(step_size=0.05) + >>> k = RBFKernel() + >>> asvgd = ASVGD(model, opt, k, num_stein_particles=2) + + >>> asvgd_result = asvgd.run(random.PRNGKey(0), 200, data) + + >>> params = asvgd_result.params + >>> predictive = Predictive(model, asvgd.guide, params, num_samples=10, batch_ndims=1) + >>> samples = predictive(random.PRNGKey(1), data=None) + + :param Callable model: Python callable with NumPyro primitives for the model. + :param Callable guide: Python callable with NumPyro primitives for the guide. + :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumpyroOptim`. + Adagrad should be preferred over Adam [1]. + :param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with ASVGD. + We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`. + This may change as criteria for kernel selection are not well understood yet. + :param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation. + Default is `10`. + :param num_cycles: The total number of cycles during inference. This corresponds to $C$ in eq. 4 of [1]. + Default is `10`. + :param trans_speed: Speed of transition between two phases during inference. This corresponds to $p$ in eq. 4 + of [1]. Default is `10`. + :param Dict guide_kwargs: Keyword arguments for `~numpyro.infer.autoguide.AutoDelta`. + Default behaviour is the same as the default for `~numpyro.infer.autoguide.AutoDelta`. + Usage:: + + opt = Adagrad(step_size=0.05) + k = RBFKernel() + asvgd = ASVGD(model, opt, k, guide_kwargs={'init_loc_fn': partial(init_to_uniform, radius=0.1)}) + + :param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot + change during inference. + + **References:** (MLA style) + + 1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent." + Third Symposium on Advances in Approximate Bayesian Inference, 2021. + """ + + def __init__( + self, + model, + optim, + kernel_fn, + num_stein_particles=10, + num_cycles=10, + trans_speed=10, + guide_kwargs={}, + **static_kwargs, + ): + self.num_cycles = num_cycles + self.trans_speed = trans_speed + + super().__init__( + model, + optim, + kernel_fn, + num_stein_particles, + guide_kwargs, + **static_kwargs, + ) + + @staticmethod + def _cyclical_annealing(num_steps: int, num_cycles: int, trans_speed: int): + """Cyclical annealing schedule as in eq. 4 of [1]. + + **References** (MLA) + 1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent." + Third Symposium on Advances in Approximate Bayesian Inference, 2021. + + :param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1]. + :param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1]. + :param trans_speed: Speed of transition between two phases. Corresponds to $p$ in eq. 4 of [1]. 
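        Example (illustrative values only, assuming ``num_steps=100``,
        ``num_cycles=10`` and ``trans_speed=10``: within each of the first nine
        cycles the annealing factor ramps from near zero towards one, and it is
        held at exactly one throughout the final cycle)::

            from numpyro.contrib.einstein import ASVGD

            cycle_fn = ASVGD._cyclical_annealing(num_steps=100, num_cycles=10, trans_speed=10)
            cycle_fn(0)   # ~9e-11, start of the first cycle
            cycle_fn(9)   # ~0.91, end of the first cycle
            cycle_fn(95)  # 1.0, the final cycle runs unannealed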
+ """ + norm = float(num_steps + 1) / float(num_cycles) + cycle_len = num_steps // num_cycles + last_start = (num_cycles - 1) * cycle_len + + def cycle_fn(t): + last_cycle = t // last_start + return (1 - last_cycle) * ( + ((t % cycle_len) + 1) / norm + ) ** trans_speed + last_cycle + + return cycle_fn + + def setup_run(self, rng_key, num_steps, args, init_state, kwargs): + cyc_fn = ASVGD._cyclical_annealing(num_steps, self.num_cycles, self.trans_speed) + + ( + istep, + idiag, + icol, + iext, + iinit, + ) = super().setup_run( + rng_key, + num_steps, + args, + init_state, + kwargs, + ) + + def step(info): + t, iinfo = info[0], info[-1] + self.loss_temperature = cyc_fn(t) / float(self.num_stein_particles) + return (t + 1, istep(iinfo)) + + def diagnostic(info): + _, iinfo = info + return idiag(iinfo) + + def collect(info): + _, iinfo = info + return icol(iinfo) + + def extract_state(info): + _, iinfo = info + return iext(iinfo) + + info_init = (0, iinit) + return step, diagnostic, collect, extract_state, info_init diff --git a/test/contrib/einstein/test_stein_kernels.py b/test/contrib/einstein/test_stein_kernels.py index d87af2e5b..57060eea4 100644 --- a/test/contrib/einstein/test_stein_kernels.py +++ b/test/contrib/einstein/test_stein_kernels.py @@ -10,41 +10,91 @@ from jax import numpy as jnp, random +from numpyro import sample from numpyro.contrib.einstein import SteinVI from numpyro.contrib.einstein.stein_kernels import ( GraphicalKernel, IMQKernel, LinearKernel, MixtureKernel, + ProbabilityProductKernel, RandomFeatureKernel, RBFKernel, ) +from numpyro.distributions import Normal +from numpyro.infer.autoguide import AutoNormal from numpyro.optim import Adam T = namedtuple("TestSteinKernel", ["kernel", "particle_info", "loss_fn", "kval"]) -PARTICLES_2D = np.array([[1.0, 2.0], [-10.0, 10.0], [7.0, 3.0], [2.0, -1]]) +PARTICLES = np.array([[1.0, 2.0], [10.0, 5.0], [7.0, 3.0], [2.0, -1]]) + + +def MOCK_MODEL(): + sample("x", Normal()) + -TPARTICLES_2D = (np.array([1.0, 2.0]), np.array([10.0, 5.0])) # transformed particles TEST_CASES = [ T( RBFKernel, lambda d: {}, lambda x: x, { - "norm": 0.040711474, - "vector": np.array([0.056071877, 0.7260586]), - "matrix": np.array([[0.040711474, 0.0], [0.0, 0.040711474]]), + # let + # m = 4 + # median trick (x_i in PARTICLES) + # h = med( [||x_i-x_j||_2]_{i,j=(0,0)}^{(m,m)} )^2 / log(m) = 16.92703711264772 + # x = (1, 2); y=(10,5) + # in + # k(x,y) = exp(-.5 * ||x-y||_2^2 / h) = 0.00490776 + "norm": 0.00490776, + # let + # h = 16.92703711264772 (from norm case) + # x = (1, 2); y=(10,5) + # in + # k(x,y) = exp(-.5 * (x-y)^2 / h) = (0.00835209, 0.5876088) + "vector": np.array([0.00835209, 0.5876088]), + # I(n) is n by n identity matrix + # let + # k_norm = 0.00490776 (from norm case) + # x = (1, 2); y=(10,5) + # in + # k(x,y) = k_norm * I + "matrix": np.array([[0.00490776, 0.0], [0.0, 0.00490776]]), }, ), - T(RandomFeatureKernel, lambda d: {}, lambda x: x, {"norm": 15.173317}), + T(RandomFeatureKernel, lambda d: {}, lambda x: x, {"norm": 13.805723}), T( IMQKernel, lambda d: {}, lambda x: x, - {"norm": 0.104828484, "vector": np.array([0.11043153, 0.31622776])}, + { + # let + # x = (1,2); y=(10,5) + # b = -.5; c=1 + # in + # k(x,y) = (c**2 + ||x-y||^2)^b = (1 + 90)^(-.5) = 0.10482848367219183 + "norm": 0.104828484, + # let + # x = (1,2); y=(10,5) + # b = -.5; c=1 + # in + # k(x,y) = (c**2 + (x-y)^2)^b = (1 + [81,9])^(-.5) = [0.11043153, 0.31622777] + "vector": np.array([0.11043153, 0.31622776]), + }, + ), + T( + LinearKernel, + lambda d: {}, + lambda x: x, + 
{ + # let + # x = (1,2); y=(10,5) + # in + # k(x,y) = (x^Ty + 1) = 20 + 1 = 21 + "norm": 21.0 + }, ), - T(LinearKernel, lambda d: {}, lambda x: x, {"norm": 21.0}), T( lambda mode: MixtureKernel( mode=mode, @@ -53,7 +103,8 @@ ), lambda d: {}, lambda x: x, - {"matrix": np.array([[0.040711474, 0.0], [0.0, 0.040711474]])}, + # simply .2rbf_matrix + .8 rbf_matrix = rbf_matrix + {"matrix": np.array([[0.00490776, 0.0], [0.0, 0.00490776]])}, ), T( lambda mode: GraphicalKernel( @@ -61,11 +112,35 @@ ), lambda d: {"p1": (0, d)}, lambda x: x, - {"matrix": np.array([[0.040711474, 0.0], [0.0, 0.040711474]])}, + { + # let + # d = 2 => l = [0,1] + # x = (1,2); y=(10,5) + # x_0 = x_1 = x; y_0=y_1=y + # k_0(x_0,y_0) = k_1(x_1,y_1) = RBFKernel(norm)(x,y) = 0.00490776 + # in + # k(x,y) = diag({k_l(x_l,y_l)}) = [[0.00490776, 0.0], [0.0, 0.00490776]] + "matrix": np.array([[0.00490776, 0.0], [0.0, 0.00490776]]) + }, + ), + T( + lambda mode: ProbabilityProductKernel(mode=mode, guide=AutoNormal(MOCK_MODEL)), + lambda d: {"x_auto_loc": (0, 1), "x_auto_scale": (1, 2)}, + lambda x: x, + # eq. 5 Probability Product Kernels + # x := (loc_x, softplus-inv(std_x)); y =: (loc_y, softplus-inv(std_y)) + # let + # s+(z) = softplus(z) = log(exp(z)+1); + # x =(1,2); y=(10,5) + # in + # k(x,y) = exp(-.5((1/s+(2))^2 + + # (10/s+(5))^2 - + # (1/(s+(2)^2 + (10/s+(5))^2)) ** 2 / (1/s+(2)^2 + 1/s+(5)^2))) + # = 0.2544481 + {"norm": 0.2544481}, ), ] -PARTICLES = [(PARTICLES_2D, TPARTICLES_2D)] TEST_IDS = [t[0].__class__.__name__ for t in TEST_CASES] @@ -73,18 +148,16 @@ @pytest.mark.parametrize( "kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS ) -@pytest.mark.parametrize("particles, tparticles", PARTICLES) +@pytest.mark.parametrize("particles", [PARTICLES]) @pytest.mark.parametrize("mode", ["norm", "vector", "matrix"]) -def test_kernel_forward( - kernel, particles, particle_info, loss_fn, tparticles, mode, kval -): +def test_kernel_forward(kernel, particles, particle_info, loss_fn, mode, kval): if mode not in kval: - return - (d,) = tparticles[0].shape + pytest.skip() + (d,) = particles[0].shape kernel = kernel(mode=mode) kernel.init(random.PRNGKey(0), particles.shape) - kernel_fn = kernel.compute(particles, particle_info(d), loss_fn) - value = kernel_fn(*tparticles) + kernel_fn = kernel.compute(random.PRNGKey(0), particles, particle_info(d), loss_fn) + value = kernel_fn(particles[0], particles[1]) assert_allclose(value, jnp.array(kval[mode]), atol=1e-6) @@ -92,19 +165,19 @@ def test_kernel_forward( "kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS ) @pytest.mark.parametrize("mode", ["norm", "vector", "matrix"]) -@pytest.mark.parametrize("particles, tparticles", PARTICLES) -def test_apply_kernel( - kernel, particles, particle_info, loss_fn, tparticles, mode, kval -): +@pytest.mark.parametrize("particles", [PARTICLES]) +def test_apply_kernel(kernel, particles, particle_info, loss_fn, mode, kval): if mode not in kval: pytest.skip() - (d,) = tparticles[0].shape + (d,) = particles[0].shape kernel_fn = kernel(mode=mode) kernel_fn.init(random.PRNGKey(0), particles.shape) - kernel_fn = kernel_fn.compute(particles, particle_info(d), loss_fn) + kernel_fn = kernel_fn.compute( + random.PRNGKey(0), particles, particle_info(d), loss_fn + ) v = np.ones_like(kval[mode]) stein = SteinVI(id, id, Adam(1.0), kernel(mode)) - value = stein._apply_kernel(kernel_fn, *tparticles, v) + value = stein._apply_kernel(kernel_fn, particles[0], particles[1], v) kval_ = copy(kval) if mode == "matrix": kval_[mode] = np.dot(kval_[mode], 
v) diff --git a/test/contrib/einstein/test_stein_loss.py b/test/contrib/einstein/test_stein_loss.py index 2f220c752..c8b21082d 100644 --- a/test/contrib/einstein/test_stein_loss.py +++ b/test/contrib/einstein/test_stein_loss.py @@ -5,6 +5,7 @@ from pytest import fail from jax import numpy as jnp, random, value_and_grad +from jax.scipy.special import logsumexp import numpyro from numpyro.contrib.einstein.stein_loss import SteinLoss @@ -54,17 +55,16 @@ def stein_loss_fn(x, particles): def test_stein_particle_loss(): - def model(x): - numpyro.sample("x", dist.Normal(0, 1)) - numpyro.sample("obs", dist.Normal(0, 1), obs=x) + def model(obs): + z = numpyro.sample("z", dist.Normal(0, 1)) + numpyro.sample("obs", dist.Normal(z, 1), obs=obs) def guide(x): - numpyro.sample("x", dist.Normal(0, 1)) + x = numpyro.param("x", 0.0) + numpyro.sample("z", dist.Normal(x, 1)) - def stein_loss_fn(x, particles, chosen_particle, assign): - return SteinLoss( - elbo_num_particles=1, stein_num_particles=3 - ).single_particle_loss( + def stein_loss_fn(chosen_particle, obs, particles, assign): + return SteinLoss(elbo_num_particles=1, stein_num_particles=3).particle_loss( random.PRNGKey(0), model, guide, @@ -72,7 +72,7 @@ def stein_loss_fn(x, particles, chosen_particle, assign): unravel_pytree, particles, assign, - (x,), + (obs,), {}, {}, ) @@ -80,18 +80,16 @@ def stein_loss_fn(x, particles, chosen_particle, assign): xs = jnp.array([-1, 0.5, 3.0]) num_particles = xs.shape[0] particles = {"x": xs} + zs = jnp.array([-0.1241799, -0.65357316, -0.96147573]) # from inspect flat_particles, unravel_pytree, _ = batch_ravel_pytree(particles, nbatch_dims=1) - losses, grads = [], [] - for i in range(num_particles): - chosen_particle = unravel_pytree(flat_particles[i]) - loss, grad = value_and_grad(stein_loss_fn)( - 2.0, flat_particles, chosen_particle, i - ) - losses.append(loss) - grads.append(grad) - assert jnp.abs(losses[0] - losses[1]) > 0.1 - assert jnp.abs(losses[1] - losses[2]) > 0.1 - assert_allclose(grads[0], grads[1]) - assert_allclose(grads[1], grads[2]) + for i in range(num_particles): + chosen_particle = {"x": jnp.array([-1.0])} + act_loss = stein_loss_fn(chosen_particle, 2.0, flat_particles, i) + + z = zs[i] + lp_m = dist.Normal().log_prob(z) + dist.Normal(z).log_prob(2.0) + lp_g = logsumexp(dist.Normal(xs).log_prob(z)) - jnp.log(3) + exp_loss = lp_m - lp_g + assert_allclose(act_loss, exp_loss) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index fd297c420..1f096f56f 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -15,6 +15,8 @@ import numpyro from numpyro import handlers from numpyro.contrib.einstein import ( + ASVGD, + SVGD, GraphicalKernel, IMQKernel, LinearKernel, @@ -116,14 +118,18 @@ def model(features, labels): @pytest.mark.parametrize("kernel", KERNELS) @pytest.mark.parametrize("problem", (uniform_normal, regression)) -def test_kernel_smoke(kernel, problem): +@pytest.mark.parametrize("method", ("ASVGD", "SVGD", "SteinVI")) +def test_run_smoke(kernel, problem, method): true_coefs, data, model = problem() - stein = SteinVI( - model, - AutoNormal(model), - Adam(1e-1), - kernel, - ) + if method == "ASVGD": + stein = ASVGD(model, Adam(1e-1), kernel, num_stein_particles=1) + if method == "SVGD": + stein = SVGD(model, Adam(1e-1), kernel, num_stein_particles=1) + if method == "SteinVI": + stein = SteinVI( + model, AutoNormal(model), Adam(1e-1), kernel, num_stein_particles=1 + ) + stein.run(random.PRNGKey(0), 1, *data) diff 
--git a/test/contrib/einstein/test_steinvi_util.py b/test/contrib/einstein/test_steinvi_util.py index 38fbd0603..30d00b731 100644 --- a/test/contrib/einstein/test_steinvi_util.py +++ b/test/contrib/einstein/test_steinvi_util.py @@ -8,8 +8,7 @@ import pytest import scipy -import jax -from jax import numpy as jnp +from jax import numpy as jnp, tree from numpyro.contrib.einstein.stein_util import batch_ravel_pytree, posdef, sqrth @@ -82,10 +81,10 @@ def test_sqrth_shape(batch_shape): def test_ravel_pytree_batched(pytree, nbatch_dims): flat, _, unravel_fn = batch_ravel_pytree(pytree, nbatch_dims) unravel = unravel_fn(flat) - jax.tree.flatten(jax.tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree)) + tree.flatten(tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree)) assert all( - jax.tree.flatten( - jax.tree.map( + tree.flatten( + tree.map( lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree ) )[0]