Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoGuideList #1644

Merged
merged 9 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,18 @@ def process_message(self, msg):
class block(Messenger):
"""
Given a callable `fn`, return another callable that selectively hides
primitive sites where `hide_fn` returns True from other effect handlers
on the stack.
primitive sites from other effect handlers on the stack. In the absence
of parameters, all primitive sites are blocked. `hide_fn` takes precedence
over `hide`, which has higher priority than `expose_types` followed by `expose`.
Only the single highest-precedence parameter that is provided is considered; the rest are ignored.

:param callable fn: Python callable with NumPyro primitives.
:param callable hide_fn: function which when given a dictionary containing
site-level metadata returns whether it should be blocked.
:param list hide: list of site names to hide.
:param list expose_types: list of site types to expose, e.g. `['param']`.
:param list expose: list of site names to expose.
:returns: Python callable with NumPyro primitives.

**Example:**

Expand All @@ -259,13 +263,22 @@ class block(Messenger):
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None):
def __init__(
self,
fn=None,
hide_fn=None,
hide=None,
expose_types=None,
expose=None,
):
if hide_fn is not None:
self.hide_fn = hide_fn
elif hide is not None:
self.hide_fn = lambda msg: msg.get("name") in hide
elif expose_types is not None:
self.hide_fn = lambda msg: msg.get("type") not in expose_types
elif expose is not None:
self.hide_fn = lambda msg: msg.get("name") not in expose
else:
self.hide_fn = lambda msg: True
super(block, self).__init__(fn)
Expand Down
97 changes: 97 additions & 0 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,103 @@ def quantiles(self, params, quantiles):
raise NotImplementedError


class AutoGuideList(AutoGuide):
    """
    Container class to combine multiple automatic guides.

    Example usage::

        rng_key_init = random.PRNGKey(0)
        guide = AutoGuideList(my_model)
        guide.append(
            AutoNormal(
                numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"])
            )
        )
        guide.append(
            AutoDelta(
                numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"])
            )
        )
        svi = SVI(model, guide, optim, Trace_ELBO())
        svi_state = svi.init(rng_key_init, data, labels)
        params = svi.get_params(svi_state)

    :param callable model: a NumPyro model
    """

    def __init__(
        self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
    ):
        self._guides = []
        super().__init__(
            model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
        )

    def append(self, part):
        """
        Add an automatic or custom guide for part of the model. The guide should
        have been created by blocking the model to restrict to a subset of
        sample sites. No two parts should operate on any one sample site.

        :param part: a partial guide to add
        :type part: AutoGuide
        :raises ValueError: if ``part`` is a DAIS-family guide, which is not
            supported inside an :class:`AutoGuideList`.
        """
        # DAIS-family guides are not known to give reasonable results when
        # combined with other partial guides, so reject them eagerly
        # (see review discussion on the PR introducing this class).
        if isinstance(
            part,
            (
                numpyro.infer.autoguide.AutoDAIS,
                numpyro.infer.autoguide.AutoSemiDAIS,
                numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS,
            ),
        ):
            raise ValueError(
                "AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS "
                "are not supported in AutoGuideList."
            )
        self._guides.append(part)

    def __call__(self, *args, **kwargs):
        if self.prototype_trace is None:
            # Run the model once to inspect the model structure.
            self._setup_prototype(*args, **kwargs)

        # Create all plates before delegating to the partial guides.
        self._create_plates(*args, **kwargs)

        # Run the sub-guides and merge their sample dicts; parts are assumed
        # to cover disjoint sets of sample sites.
        result = {}
        for part in self._guides:
            result.update(part(*args, **kwargs))
        return result

    def __getitem__(self, key):
        return self._guides[key]

    def __len__(self):
        return len(self._guides)

    def __iter__(self):
        yield from self._guides

    def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
        """
        Draw posterior samples from every partial guide and merge the results.

        :param jax.random.PRNGKey rng_key: random key for sampling.
        :param dict params: optimized parameters of the guides.
        :param tuple sample_shape: shape of the drawn samples.
        :return: dict mapping site names to posterior samples.
        """
        result = {}
        for part in self._guides:
            # AutoDelta and AutoSemiDAIS accept extra model args/kwargs in
            # sample_posterior; the remaining guides only take sample_shape.
            # The AutoSemiDAIS branch is kept for potential subclasses even
            # though append() rejects that guide.
            if isinstance(
                part,
                (
                    numpyro.infer.autoguide.AutoDelta,
                    numpyro.infer.autoguide.AutoSemiDAIS,
                ),
            ):
                samples = part.sample_posterior(
                    rng_key, params, *args, sample_shape=sample_shape, **kwargs
                )
            else:
                samples = part.sample_posterior(rng_key, params, sample_shape=sample_shape)
            result.update(samples)
        return result

    def median(self, params):
        """Return the merged medians of all partial guides."""
        result = {}
        for part in self._guides:
            result.update(part.median(params))
        return result

    def quantiles(self, params, quantiles):
        """Return the merged quantiles of all partial guides."""
        result = {}
        for part in self._guides:
            result.update(part.quantiles(params, quantiles))
        return result


class AutoNormal(AutoGuide):
"""
This implementation of :class:`AutoGuide` uses Normal distributions
Expand Down
170 changes: 163 additions & 7 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AutoDAIS,
AutoDelta,
AutoDiagonalNormal,
AutoGuideList,
AutoIAFNormal,
AutoLaplaceApproximation,
AutoLowRankMultivariateNormal,
Expand Down Expand Up @@ -62,6 +63,7 @@
AutoLowRankMultivariateNormal,
AutoNormal,
AutoDelta,
AutoGuideList,
],
)
def test_beta_bernoulli(auto_class):
Expand All @@ -76,6 +78,9 @@ def model(data):
adam = optim.Adam(0.01)
if auto_class == AutoDAIS:
guide = auto_class(model, init_loc_fn=init_strategy, base_dist="cholesky")
elif auto_class == AutoGuideList:
guide = AutoGuideList(model)
guide.append(AutoNormal(handlers.block(model, hide=[])))
else:
guide = auto_class(model, init_loc_fn=init_strategy)
svi = SVI(model, guide, adam, Trace_ELBO())
Expand All @@ -96,9 +101,14 @@ def body_fn(i, val):
posterior_mean = jnp.mean(posterior_samples["beta"], 0)
assert_allclose(posterior_mean, true_coefs, atol=0.05)

if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
assert quantiles["beta"].shape == (3, 2)
# assumes the AutoGuideList case contains none of AutoDAIS, AutoDelta, AutoIAFNormal, or AutoBNAFNormal
if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
median = guide.median(params)
assert median["beta"].shape == (2,)
# test .quantile method
if auto_class is not AutoDelta:
quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
assert quantiles["beta"].shape == (3, 2)

# Predictive can be instantiated from posterior samples...
predictive = Predictive(model, posterior_samples=posterior_samples)
Expand All @@ -123,6 +133,7 @@ def body_fn(i, val):
AutoLowRankMultivariateNormal,
AutoNormal,
AutoDelta,
AutoGuideList,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceMeanField_ELBO])
Expand All @@ -142,12 +153,19 @@ def model(data=None, labels=None):

adam = optim.Adam(0.01)
rng_key_init = random.PRNGKey(1)
guide = auto_class(model, init_loc_fn=init_strategy)
if auto_class == AutoGuideList:
guide = AutoGuideList(model)
guide.append(AutoNormal(handlers.block(model, hide=[])))
else:
guide = auto_class(model, init_loc_fn=init_strategy)
svi = SVI(model, guide, adam, Elbo())
svi_state = svi.init(rng_key_init, data, labels)

# smoke test if analytic KL is used
if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
# assume that AutoGuideList has AutoNormal
if (
auto_class is AutoNormal or auto_class is AutoGuideList
) and Elbo is TraceMeanField_ELBO:
_, mean_field_loss = svi.update(svi_state, data, labels)
svi.loss = Trace_ELBO()
_, elbo_loss = svi.update(svi_state, data, labels)
Expand All @@ -160,13 +178,14 @@ def body_fn(i, val):

svi_state = fori_loop(0, 2000, body_fn, svi_state)
params = svi.get_params(svi_state)
# assume that AutoGuideList does not have AutoDAIS, AutoIAFNormal, or AutoBNAFNormal
if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
median = guide.median(params)
assert_allclose(median["coefs"], true_coefs, rtol=0.1)
# test .quantile method
if auto_class is not AutoDelta:
median = guide.quantiles(params, [0.2, 0.5])
assert_allclose(median["coefs"][1], true_coefs, rtol=0.1)
quantiles = guide.quantiles(params, [0.2, 0.5])
assert_allclose(quantiles["coefs"][1], true_coefs, rtol=0.1)
# test .sample_posterior method
posterior_samples = guide.sample_posterior(
random.PRNGKey(1), params, sample_shape=(1000,)
Expand Down Expand Up @@ -997,3 +1016,140 @@ def model():
)
assert guide_samples["x"].shape == sample_shape + shape
assert guide_samples["x2"].shape == sample_shape + shape


@pytest.mark.parametrize(
    "auto_classes",
    [
        (AutoNormal, AutoDiagonalNormal),
        (AutoNormal, AutoIAFNormal),
        (AutoNormal, AutoBNAFNormal),
        (AutoNormal, AutoMultivariateNormal),
        (AutoNormal, AutoLaplaceApproximation),
        (AutoNormal, AutoLowRankMultivariateNormal),
        (AutoNormal, AutoNormal),
        (AutoNormal, AutoDelta),
    ],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceMeanField_ELBO])
def test_autoguidelist(auto_classes, Elbo):
    # Linear-regression model with two latent sites so that each partial
    # guide can own one of them.
    sigma = 0.05

    def model(x, y=None):
        a = numpyro.sample("a", dist.Normal(0, 1))
        b = numpyro.sample("b", dist.Normal(jnp.ones((2, 1)), 1))
        mu = a + x @ b
        with numpyro.plate("N", len(x), dim=-2):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    num_data = 500
    true_a = 1
    true_b = jnp.asarray([[-0.5], [-1]])
    x = random.normal(random.PRNGKey(0), (num_data, 2))
    y = true_a + x @ true_b + sigma * random.normal(random.PRNGKey(1), (num_data, 1))

    # First guide handles "a", second guide handles everything else.
    first_cls, second_cls = auto_classes
    guide = AutoGuideList(model)
    guide.append(
        first_cls(
            numpyro.handlers.block(
                numpyro.handlers.seed(model, rng_seed=0), expose=["a"]
            )
        )
    )
    guide.append(
        second_cls(
            numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"])
        )
    )

    svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.1), Elbo())
    svi_result = svi.run(random.PRNGKey(0), num_steps=500, x=x, y=y)
    params = svi_result.params

    posterior_samples = guide.sample_posterior(
        random.PRNGKey(0), params, x=x, sample_shape=(1_000,)
    )
    assert posterior_samples["a"].shape == (1_000,)
    assert posterior_samples["b"].shape == (1_000, 2, 1)
    assert_allclose(jnp.mean(posterior_samples["a"], 0), true_a, atol=0.05)
    assert_allclose(jnp.mean(posterior_samples["b"], 0), true_b, atol=0.05)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), x)
    assert predictive_samples["y"].shape == (1_000, num_data, 1)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1_000)
    predictive_samples = predictive(random.PRNGKey(1), x)
    assert predictive_samples["y"].shape == (1_000, num_data, 1)

    # median and quantiles from the combined guide; these raise when any
    # part lacks a quantiles() implementation.
    no_quantile_classes = (AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal)
    if any(cls in no_quantile_classes for cls in auto_classes):
        with pytest.raises(NotImplementedError):
            quantiles = guide.quantiles(params=params, quantiles=[0.2, 0.5, 0.8])
    else:
        quantiles = guide.quantiles(params=params, quantiles=[0.2, 0.5, 0.8])
        assert quantiles["a"].shape == (3,)
        assert quantiles["b"].shape == (3, 2, 1)

    # median and quantiles from the individual partial guides
    for cls, part in zip(auto_classes, guide):
        if cls not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
            part.median(params=params)
        if cls not in no_quantile_classes:
            part.quantiles(params=params, quantiles=[0.2, 0.5, 0.8])


@pytest.mark.parametrize(
    "auto_class",
    [
        AutoDiagonalNormal,
        AutoDAIS,
        AutoIAFNormal,
        AutoBNAFNormal,
        AutoMultivariateNormal,
        AutoLaplaceApproximation,
        AutoLowRankMultivariateNormal,
        AutoNormal,
        AutoDelta,
    ],
)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3)])
@pytest.mark.parametrize("sample_shape", [(), (1,), (2, 3)])
def test_autoguidelist_sample_posterior_with_sample_shape(
    auto_class, shape, sample_shape
):
    def model():
        x = numpyro.sample("x", dist.Normal().expand(shape))
        numpyro.deterministic("x2", x**2)

    guide = AutoGuideList(model)
    guide.append(auto_class(model))
    svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())

    def run_and_sample():
        # Fit briefly, then draw posterior samples with the requested shape.
        svi_result = svi.run(random.PRNGKey(0), num_steps=1_000)
        return guide.sample_posterior(
            rng_key=random.PRNGKey(1),
            params=svi_result.params,
            sample_shape=sample_shape,
        )

    if auto_class in (AutoIAFNormal, AutoBNAFNormal) and max(shape, default=0) <= 1:
        # Normalizing-flow guides reject models with a single latent dimension.
        with pytest.raises(
            ValueError,
            match="latent dim = 1. Consider using AutoDiagonalNormal instead",
        ):
            run_and_sample()
    else:
        guide_samples = run_and_sample()
        assert guide_samples["x"].shape == sample_shape + shape
        assert guide_samples["x2"].shape == sample_shape + shape
Loading
Loading