
Refactoring SteinVI #1883

Merged: 11 commits, Oct 19, 2024
7 changes: 3 additions & 4 deletions docs/source/contrib.rst
@@ -43,21 +43,20 @@ The framework currently supports several kernels, including:
- `RandomFeatureKernel`
- `MixtureKernel`
- `GraphicalKernel`
- `ProbabilityProductKernel`

For example usage, see:

- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_
- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_.

**References**

1. *Stein's Method Meets Statistics: A Review of Some Recent Developments* (2021)
Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert,
Yvik Swan. https://arxiv.org/abs/2105.03481
Yvik Swan.

2. *Stein variational gradient descent: A general-purpose Bayesian inference algorithm* (2016)
2. *Stein Variational Gradient Descent: A General-Purpose Bayesian Inference Algorithm* (2016)
Qiang Liu, Dilin Wang. NeurIPS

3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019)
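As a quick orientation for the docs change above, here is a minimal, self-contained sketch of the SteinVI setup the updated page points to. The toy model and data are placeholders; the constructor arguments (AutoNormal guide with a small init radius, Adagrad(0.5), RBFKernel) mirror the stein_bnn.py changes later in this PR.

```python
from functools import partial

from jax import random

import numpyro
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.distributions import Normal
from numpyro.infer import init_to_uniform
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad


def model(x, y=None):
    # Toy linear model standing in for the BNN in examples/stein_bnn.py.
    w = numpyro.sample("w", Normal(0.0, 1.0))
    numpyro.sample("y", Normal(w * x, 1.0), obs=y)


# Small init radius, as recommended for BNNs in the updated example.
guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
stein = SteinVI(model, guide, Adagrad(0.5), RBFKernel(), num_stein_particles=5)

x = random.normal(random.PRNGKey(0), (32,))
result = stein.run(random.PRNGKey(1), 200, x, 2.0 * x)  # key, num_steps, *model args
```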
87 changes: 37 additions & 50 deletions examples/stein_bnn.py
@@ -19,14 +19,10 @@
import numpy as np
from sklearn.model_selection import train_test_split

import jax
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro import deterministic
from numpyro.contrib.einstein import IMQKernel, SteinVI
from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from jax import config, nn, numpy as jnp, random

from numpyro import deterministic, plate, sample, set_platform, subsample
from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
from numpyro.infer import init_to_uniform
@@ -54,23 +50,23 @@ def normalize(val, mean=None, std=None):
return (val - mean) / std, mean, std


def model(x, y=None, hidden_dim=50, subsample_size=100):
def model(x, y=None, hidden_dim=50, sub_size=100):
"""BNN described in section 5 of [1].

**References:**
1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*
1. *Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm*
Qiang Liu and Dilin Wang (2016).
"""

prec_nn = numpyro.sample(
prec_nn = sample(
"prec_nn", Gamma(1.0, 0.1)
) # hyper prior for precision of nn weights and biases

n, m = x.shape

with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
with plate("l1_hidden", hidden_dim, dim=-1):
# prior l1 bias term
b1 = numpyro.sample(
b1 = sample(
"nn_b1",
Normal(
0.0,
@@ -79,38 +75,33 @@ def model(x, y=None, hidden_dim=50, subsample_size=100):
)
assert b1.shape == (hidden_dim,)

with numpyro.plate("l1_feat", m, dim=-2):
w1 = numpyro.sample(
with plate("l1_feat", m, dim=-2):
w1 = sample(
"nn_w1", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on l1 weights
assert w1.shape == (m, hidden_dim)

with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
w2 = numpyro.sample(
with plate("l2_hidden", hidden_dim, dim=-1):
w2 = sample(
"nn_w2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on output weights

b2 = numpyro.sample(
b2 = sample(
"nn_b2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on output bias term

# precision prior on observations
prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))
with numpyro.plate(
"data",
x.shape[0],
subsample_size=subsample_size,
dim=-1,
):
batch_x = numpyro.subsample(x, event_dim=1)
prec_obs = sample("prec_obs", Gamma(1.0, 0.1))
with plate("data", x.shape[0], subsample_size=sub_size, dim=-1):
batch_x = subsample(x, event_dim=1)
if y is not None:
batch_y = numpyro.subsample(y, event_dim=0)
batch_y = subsample(y, event_dim=0)
else:
batch_y = y

loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2)
loc_y = deterministic("y_pred", nn.relu(batch_x @ w1 + b1) @ w2 + b2)

numpyro.sample(
sample(
"y",
Normal(
loc_y, 1.0 / jnp.sqrt(prec_obs)
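The data plate above combines `subsample_size` with `numpyro.subsample`, which slices the arrays along the plate dimension and rescales the minibatch log-likelihood by N / sub_size so the stochastic ELBO stays unbiased. A standalone sketch of just that pattern (toy model, hypothetical site names):

```python
import numpyro
from numpyro.distributions import Normal


def subsampled_model(x, y=None, sub_size=100):
    # x: (n, m) features, y: (n,) targets, as in the BNN model above.
    with numpyro.plate("data", x.shape[0], subsample_size=sub_size, dim=-1):
        # event_dim counts trailing dims excluded from subsampling:
        # 1 for the feature axis of x, 0 for the scalar targets y.
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
        loc = batch_x.mean(-1)  # placeholder predictor
        numpyro.sample("y", Normal(loc, 1.0), obs=batch_y)
```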
@@ -123,34 +114,33 @@ def main(args):
data = load_data()

inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)
# normalize data and labels to zero mean unit variance!
# Normalize features to zero mean unit variance.
x, xtr_mean, xtr_std = normalize(data.xtr)
y, ytr_mean, ytr_std = normalize(data.ytr)

rng_key, inf_key = random.split(inf_key)

# We find that SteinVI benefits from a small radius when inferring BNNs.
guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))

stein = SteinVI(
model,
guide,
Adagrad(0.05),
IMQKernel(),
# ProbabilityProductKernel(guide=guide, scale=1.),
Adagrad(0.5),
RBFKernel(),
repulsion_temperature=args.repulsion,
num_stein_particles=args.num_stein_particles,
num_elbo_particles=args.num_elbo_particles,
)
start = time()

# use keyword params for static (shape etc.)!
# Use keyword params for static (shape etc.)
result = stein.run(
rng_key,
args.max_iter,
x,
y,
data.ytr,
hidden_dim=args.hidden_dim,
subsample_size=args.subsample_size,
sub_size=args.subsample_size,
progress_bar=args.progress_bar,
)
time_taken = time() - start
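The kernel switch above (IMQKernel to RBFKernel) lands on the kernel from Liu and Wang (2016), reference 2 in the updated contrib.rst. For intuition, the paper's RBF kernel with the median-distance bandwidth can be sketched as below; this is the textbook rule and may differ in detail from the bandwidth NumPyro's RBFKernel computes:

```python
import jax.numpy as jnp


def median_heuristic_rbf(particles):
    """k(x_i, x_j) = exp(-||x_i - x_j||^2 / h) with h = med^2 / log(n),
    where med is the median pairwise distance between the n particles."""
    diffs = particles[:, None, :] - particles[None, :, :]  # (n, n, d)
    sq_dists = jnp.sum(diffs**2, axis=-1)                  # (n, n)
    h = jnp.median(sq_dists) / jnp.log(particles.shape[0])
    return jnp.exp(-sq_dists / h)
```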
@@ -164,39 +154,36 @@
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
) # use train data statistics when accessing generalization
preds = pred(
pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim
)["y_pred"]
) # Use train data statistics when accessing generalization.
n = xte.shape[0]
y_preds = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y_pred"]

y_pred = preds * ytr_std + ytr_mean
rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2))
mean_pred = y_preds.mean(0)
rmse = jnp.sqrt(jnp.mean((mean_pred - data.yte) ** 2))

print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
print(rf"RMSE: {rmse:.2f}")

# compute mean prediction and confidence interval around median
mean_prediction = y_pred.mean(0)

ran = np.arange(mean_prediction.shape[0])
percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0)
percentiles = jnp.percentile(y_preds, jnp.array([5.0, 95.0]), axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
ran = np.arange(mean_pred.shape[0])
ax.add_collection(
LineCollection(
zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors="lightblue"
)
)
ax.plot(data.yte, "kx", label="y true")
ax.plot(mean_prediction, "ko", label="y pred")
ax.plot(mean_pred, "ko", label="y pred")
ax.set(xlabel="example", ylabel="y", title="Mean predictions with 90% CI")
ax.legend()
fig.savefig("stein_bnn.pdf")


if __name__ == "__main__":
jax.config.update("jax_debug_nans", True)
config.update("jax_debug_nans", True)

parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=100)
@@ -212,6 +199,6 @@ def main(args):

args = parser.parse_args()

numpyro.set_platform(args.device)
set_platform(args.device)

main(args)
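On the "Use keyword params for static (shape etc.)" comment in main(): the reading is that stein.run traces its positional arguments as arrays inside the jitted update, while keyword arguments are closed over as plain Python values, so anything that controls shapes (hidden_dim, sub_size) should be passed by keyword. That mechanism is an inference from the comment, not something this diff shows; the underlying JAX pattern looks like this:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("hidden_dim",))
def forward(x, w, *, hidden_dim):
    # hidden_dim is static: it can fix shapes at trace time, at the cost
    # of recompiling whenever its value changes.
    return jnp.tanh(x @ w[:, :hidden_dim])
```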
16 changes: 9 additions & 7 deletions numpyro/contrib/einstein/__init__.py
@@ -12,17 +12,19 @@
RBFKernel,
)
from numpyro.contrib.einstein.stein_loss import SteinLoss
from numpyro.contrib.einstein.steinvi import SteinVI
from numpyro.contrib.einstein.steinvi import ASVGD, SVGD, SteinVI

__all__ = [
"SteinVI",
"SteinLoss",
"RBFKernel",
"ASVGD",
"GraphicalKernel",
"IMQKernel",
"LinearKernel",
"RandomFeatureKernel",
"GraphicalKernel",
"MixtureGuidePredictive",
"MixtureKernel",
"RandomFeatureKernel",
"RBFKernel",
"ProbabilityProductKernel",
"MixtureGuidePredictive",
"SVGD",
"SteinVI",
"SteinLoss",
]
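With ASVGD and SVGD now exported from numpyro.contrib.einstein, both can be imported directly alongside SteinVI. A hedged usage sketch follows; the constructor signatures are not shown in this diff, and the assumption below is that SVGD mirrors SteinVI's (model, optimizer, kernel) arguments while managing its own point-estimate guide:

```python
from jax import random

import numpyro
from numpyro.contrib.einstein import SVGD, RBFKernel
from numpyro.distributions import Normal
from numpyro.optim import Adagrad


def model():
    numpyro.sample("w", Normal(0.0, 1.0))


# Assumption: no explicit guide argument, unlike SteinVI.
svgd = SVGD(model, Adagrad(0.5), RBFKernel(), num_stein_particles=10)
result = svgd.run(random.PRNGKey(0), 100)
```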
22 changes: 18 additions & 4 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -4,21 +4,27 @@
from collections.abc import Callable, Sequence
from functools import partial
from typing import Optional
import warnings

import jax
from jax import numpy as jnp, random, vmap
from jax.tree_util import tree_flatten, tree_map
Review comment (Member): This is deprecated. You can use `jax.tree.flatten` and `jax.tree.map` now.


from numpyro.handlers import substitute
from numpyro.infer import Predictive
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _predictive
from numpyro.util import find_stack_level


class MixtureGuidePredictive:
"""(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for
:class:`numpyro.contrib.einstein.steinvi.SteinVI`
:class:`numpyro.contrib.einstein.steinvi.SteinVI`.

.. Note:: For single mixture component use numpyro.infer.Predictive.

.. Note:: For :class:`numpyro.contrib.einstein.steinvi.SVGD` and :class:`numpyro.contrib.einstein.steinvi.ASVGD` use
:class:`numpyro.infer.util.Predictive`.

.. warning::
The `MixtureGuidePredictive` is experimental and will likely be replaced by
:class:`numpyro.infer.util.Predictive` in the future.
@@ -44,6 +50,14 @@ def __init__(
return_sites: Optional[Sequence[str]] = None,
mixture_assignment_sitename="mixture_assignments",
):
if isinstance(guide, AutoGuide):
guide_name = guide.__class__.__name__
if guide_name == "AutoDelta":
warnings.warn(
"Use numpyro.inter.Predictive with `batch_ndims=1` for ASVGD and SVGD.",
stacklevel=find_stack_level(),
)

self.model_predictive = partial(
Predictive,
model=model,
Expand All @@ -63,7 +77,7 @@ def __init__(

self.guide = guide
self.return_sites = return_sites
self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
self.mixture_assignment_sitename = mixture_assignment_sitename

def _call_with_params(self, rng_key, params, args, kwargs):
@@ -99,7 +113,7 @@ def __call__(self, rng_key, *args, **kwargs):
minval=0,
maxval=self.num_mixture_components,
)
predictive_assign = jax.tree.map(
predictive_assign = tree_map(
lambda arr: vmap(lambda i, assign: arr[i, assign])(
jnp.arange(self._batch_shape[0]), assigns
),
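The new notes steer SVGD/ASVGD users to numpyro.infer.Predictive with batch_ndims=1, which treats the leading axis of the particle parameters as a batch dimension. A sketch of that path, under the assumptions that result.params holds per-particle parameters with the particle count as the leading axis and that the SVGD object exposes its guide:

```python
from jax import random

from numpyro.infer import Predictive

# batch_ndims=1: the leading axis of each parameter array indexes the
# Stein particles, so predictions are batched over particles.
predictive = Predictive(
    model,
    guide=svgd.guide,      # assumption: SVGD exposes the guide it builds
    params=result.params,
    num_samples=1,         # assumption: one draw per particle
    batch_ndims=1,
)
samples = predictive(random.PRNGKey(1))
```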