Skip to content

Commit

Permalink
Add initial AutoGuideList implementation
Browse files Browse the repository at this point in the history
Use Pyro implementation as an example to
- implement numpyro.infer.autoguide.AutoGuideList
- modify numpyro.handlers.block
  • Loading branch information
tare committed Sep 22, 2023
1 parent ca96eca commit 7a16de2
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 8 deletions.
96 changes: 88 additions & 8 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"""

from collections import OrderedDict
from functools import partial
import warnings

import numpy as np
Expand Down Expand Up @@ -224,17 +225,83 @@ def process_message(self, msg):
msg["infer"] = guide_msg["infer"].copy()


def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
if msg.get("type") == "sample" and msg.get("is_observed"):
msg_type = "observe"
else:
msg_type = msg.get("type")

is_not_exposed = (msg.get("name") not in expose) and (msg_type not in expose_types)

if (
(msg.get("name") in hide)
or (msg_type in hide_types)
or (is_not_exposed and hide_all)
):
return True
else:
return False


def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types):
assert (hide_all is False and expose_all is False) or (
hide_all != expose_all
), "cannot hide and expose a site"

if hide is None:
hide = []
else:
hide_all = False

if expose is None:
expose = []
else:
hide_all = True

assert set(hide).isdisjoint(set(expose)), "cannot hide and expose a site"

if hide_types is None:
hide_types = []
else:
hide_all = False

if expose_types is None:
expose_types = []
else:
hide_all = True

assert set(hide_types).isdisjoint(
set(expose_types)
), "cannot hide and expose a site type"

return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


class block(Messenger):
"""
Given a callable `fn`, return another callable that selectively hides
primitive sites where `hide_fn` returns True from other effect handlers
on the stack.
primitive sites.
A site is hidden if at least one of the following holds:
0. ``hide_fn(msg) is True`` or ``(not expose_fn(msg)) is True``
1. ``msg["name"] in hide``
2. ``msg["type"] in hide_types``
3. ``msg["name"] not in expose and msg["type"] not in expose_types``
4. ``hide``, ``hide_types``, and ``expose_types`` are all ``None``
:param callable fn: Python callable with NumPyro primitives.
:param callable hide_fn: function which when given a dictionary containing
site-level metadata returns whether it should be blocked.
:param list hide: list of site names to hide.
:param list expose_types: list of site types to expose, e.g. `['param']`.
:param callable expose_fn: function which when given a dictionary containing
site-level metadata returns whether it should be exposed.
:param bool hide_all: whether to hide all sites.
:param bool expose_all: whether to expose all sites.
:param list expose: list of site names to hide.
:param list hide_types: list of site types to hide, e.g. `['param']`.
:returns: Python callable with NumPyro primitives.
**Example:**
Expand All @@ -259,15 +326,28 @@ class block(Messenger):
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None):
def __init__(
self,
fn=None,
hide_fn=None,
hide=None,
expose_types=None,
expose_fn=None,
hide_all=True,
expose_all=False,
expose=None,
hide_types=None,
):
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
if hide_fn is not None:
self.hide_fn = hide_fn
elif hide is not None:
self.hide_fn = lambda msg: msg.get("name") in hide
elif expose_types is not None:
self.hide_fn = lambda msg: msg.get("type") not in expose_types
elif expose_fn is not None:
self.hide_fn = lambda msg: not expose_fn(msg.get("name"))
else:
self.hide_fn = lambda msg: True
self.hide_fn = _make_default_hide_fn(
hide_all, expose_all, hide, expose, hide_types, expose_types
)
super(block, self).__init__(fn)

def process_message(self, msg):
Expand Down
70 changes: 70 additions & 0 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,76 @@ def quantiles(self, params, quantiles):
raise NotImplementedError


class AutoGuideList(AutoGuide):
"""
Container class to combine multiple automatic guides.
Example usage::
rng_key_init = random.PRNGKey(0)
guide = AutoGuideList(my_model)
seeded_model = numpyro.handlers.seed(model, rng_seed=1)
guide.append(AutoNormal(numpyro.handlers.block(seeded_model, hide=["coefs"])))
guide.append(AutoDelta(numpyro.handlers.block(seeded_model, expose=["coefs"])))
svi = SVI(model, guide, optim, Trace_ELBO())
svi_state = svi.init(rng_key_init, data, labels)
params = svi.get_params(svi_state)
:param callable model: a NumPyro model
"""

def __init__(
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
):
self._guides = []
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
)

def append(self, part):
"""
Add an automatic or custom guide for part of the model. The guide should
have been created by blocking the model to restrict to a subset of
sample sites. No two parts should operate on any one sample site.
:param part: a partial guide to add
:type part: AutoGuide
"""
self._guides.append(part)

def __call__(self, *args, **kwargs):
if self.prototype_trace is None:
# run model to inspect the model structure
self._setup_prototype(*args, **kwargs)

# create all plates
self._create_plates(*args, **kwargs)

# run slave guides
result = {}
for part in self._guides:
result.update(part(*args, **kwargs))
return result

def sample_posterior(self, rng_key, params, sample_shape=()):
result = {}
for part in self._guides:
result.update(part.sample_posterior(rng_key, params, sample_shape))
return result

def median(self, *args, **kwargs):
result = {}
for part in self._guides:
result.update(part.median(*args, **kwargs))
return result

def quantiles(self, quantiles, *args, **kwargs):
result = {}
for part in self._guides:
result.update(part.quantiles(quantiles, *args, **kwargs))
return result


class AutoNormal(AutoGuide):
"""
This implementation of :class:`AutoGuide` uses Normal distributions
Expand Down

0 comments on commit 7a16de2

Please sign in to comment.