Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoGuideList #1644

Merged
merged 9 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"""

from collections import OrderedDict
from functools import partial
import warnings

import numpy as np
Expand Down Expand Up @@ -224,17 +225,83 @@ def process_message(self, msg):
msg["infer"] = guide_msg["infer"].copy()


def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
if msg.get("type") == "sample" and msg.get("is_observed"):
msg_type = "observe"
else:
msg_type = msg.get("type")

is_not_exposed = (msg.get("name") not in expose) and (msg_type not in expose_types)

if (
(msg.get("name") in hide)
or (msg_type in hide_types)
or (is_not_exposed and hide_all)
):
return True
else:
return False


def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types):
assert (hide_all is False and expose_all is False) or (
hide_all != expose_all
), "cannot hide and expose a site"

if hide is None:
hide = []
else:
hide_all = False

if expose is None:
expose = []
else:
hide_all = True

assert set(hide).isdisjoint(set(expose)), "cannot hide and expose a site"

if hide_types is None:
hide_types = []
else:
hide_all = False

if expose_types is None:
expose_types = []
else:
hide_all = True

assert set(hide_types).isdisjoint(
set(expose_types)
), "cannot hide and expose a site type"

return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


class block(Messenger):
"""
Given a callable `fn`, return another callable that selectively hides
primitive sites where `hide_fn` returns True from other effect handlers
on the stack.
primitive sites.

A site is hidden if at least one of the following holds:

0. ``hide_fn(msg) is True`` or ``(not expose_fn(msg)) is True``
1. ``msg["name"] in hide``
2. ``msg["type"] in hide_types``
3. ``msg["name"] not in expose and msg["type"] not in expose_types``
4. ``hide``, ``hide_types``, and ``expose_types`` are all ``None``

:param callable fn: Python callable with NumPyro primitives.
:param callable hide_fn: function which when given a dictionary containing
site-level metadata returns whether it should be blocked.
:param list hide: list of site names to hide.
:param list expose_types: list of site types to expose, e.g. `['param']`.
:param callable expose_fn: function which when given a dictionary containing
site-level metadata returns whether it should be exposed.
:param bool hide_all: whether to hide all sites.
:param bool expose_all: whether to expose all sites.
:param list expose: list of site names to hide.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you only introduce expose for simplicity? I'm not sure why we need other arguments.

typo: a list of site names to expose.

:param list hide_types: list of site types to hide, e.g. `['param']`.
:returns: Python callable with NumPyro primitives.

**Example:**

Expand All @@ -259,15 +326,28 @@ class block(Messenger):
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None):
def __init__(
self,
fn=None,
hide_fn=None,
hide=None,
expose_types=None,
expose_fn=None,
hide_all=True,
expose_all=False,
expose=None,
hide_types=None,
):
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
if hide_fn is not None:
self.hide_fn = hide_fn
elif hide is not None:
self.hide_fn = lambda msg: msg.get("name") in hide
elif expose_types is not None:
self.hide_fn = lambda msg: msg.get("type") not in expose_types
elif expose_fn is not None:
self.hide_fn = lambda msg: not expose_fn(msg.get("name"))
else:
self.hide_fn = lambda msg: True
self.hide_fn = _make_default_hide_fn(
hide_all, expose_all, hide, expose, hide_types, expose_types
)
super(block, self).__init__(fn)

def process_message(self, msg):
Expand Down
70 changes: 70 additions & 0 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,76 @@ def quantiles(self, params, quantiles):
raise NotImplementedError


class AutoGuideList(AutoGuide):
"""
Container class to combine multiple automatic guides.

Example usage::

rng_key_init = random.PRNGKey(0)
guide = AutoGuideList(my_model)
seeded_model = numpyro.handlers.seed(model, rng_seed=1)
guide.append(AutoNormal(numpyro.handlers.block(seeded_model, hide=["coefs"])))
guide.append(AutoDelta(numpyro.handlers.block(seeded_model, expose=["coefs"])))
svi = SVI(model, guide, optim, Trace_ELBO())
svi_state = svi.init(rng_key_init, data, labels)
params = svi.get_params(svi_state)

:param callable model: a NumPyro model
"""

def __init__(
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
):
self._guides = []
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
)

def append(self, part):
"""
Add an automatic or custom guide for part of the model. The guide should
have been created by blocking the model to restrict to a subset of
sample sites. No two parts should operate on any one sample site.

:param part: a partial guide to add
:type part: AutoGuide
"""
self._guides.append(part)

def __call__(self, *args, **kwargs):
if self.prototype_trace is None:
# run model to inspect the model structure
self._setup_prototype(*args, **kwargs)

# create all plates
self._create_plates(*args, **kwargs)

# run slave guides
result = {}
for part in self._guides:
result.update(part(*args, **kwargs))
return result

def sample_posterior(self, rng_key, params, sample_shape=()):
result = {}
for part in self._guides:
result.update(part.sample_posterior(rng_key, params, sample_shape))
return result

def median(self, *args, **kwargs):
result = {}
for part in self._guides:
result.update(part.median(*args, **kwargs))
return result

def quantiles(self, quantiles, *args, **kwargs):
result = {}
for part in self._guides:
result.update(part.quantiles(quantiles, *args, **kwargs))
return result


class AutoNormal(AutoGuide):
"""
This implementation of :class:`AutoGuide` uses Normal distributions
Expand Down