Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 13, 2024
1 parent d55d209 commit f609a75
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 35 deletions.
90 changes: 68 additions & 22 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType

from jaxlib.xla_extension import ArrayImpl

import jax
from jax import random
Expand All @@ -11,11 +14,17 @@
import numpyro.distributions as dist
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.autoguide import AutoGuide, AutoNormal
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

RunInferenceResult = (
dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
) # for mcmc or sdvi


class StochasticSupportInference(ABC):
"""
Expand Down Expand Up @@ -46,12 +55,14 @@ def model():
on more than `max_slps`.
"""

def __init__(self, model, num_slp_samples, max_slps):
self.model = model
self.num_slp_samples = num_slp_samples
self.max_slps = max_slps
def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None:
self.model: Callable = model
self.num_slp_samples: int = num_slp_samples
self.max_slps: int = max_slps

def _find_slps(self, rng_key, *args, **kwargs):
def _find_slps(
self, rng_key: ArrayImpl, *args: Any, **kwargs: Any
) -> dict[str, OrderedDictType]:
"""
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
This implementation assumes that all branching is done via discrete sampling sites
Expand All @@ -70,13 +81,17 @@ def _find_slps(self, rng_key, *args, **kwargs):

return branching_traces

def _get_branching_trace(self, tr):
def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
"""
Extract the sites from the trace that are annotated with `infer={"branching": True}`.
"""
branching_trace = OrderedDict()
for site in tr.values():
if site["type"] == "sample" and site["infer"].get("branching", False):
if (
site["type"] == "sample"
and site["infer"].get("branching", False)
and site["fn"].support is not None
):
if (
not isinstance(site["fn"], dist.Distribution)
or not site["fn"].support.is_discrete
Expand All @@ -92,13 +107,24 @@ def _get_branching_trace(self, tr):
return branching_trace

@abstractmethod
def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
raise NotImplementedError

@abstractmethod
def _combine_inferences(
self, rng_key, inferences, branching_traces, *args, **kwargs
):
self,
rng_key: ArrayImpl,
inferences: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> SDVIResult:
raise NotImplementedError

def run(self, rng_key, *args, **kwargs):
Expand Down Expand Up @@ -165,21 +191,27 @@ def model():

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
proposal_scale=1.0,
):
model: Callable,
mcmc_kwargs: dict[str, Any],
kernel_cls: Callable = NUTS,
num_slp_samples: int = 1000,
max_slps: int = 124,
proposal_scale: float = 1.0,
) -> None:
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.proposal_scale = proposal_scale

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
"""
Run MCMC on the model conditioned on the given branching trace.
"""
Expand All @@ -190,7 +222,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs):

return mcmc.get_samples()

def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs):
def _combine_inferences( # type: ignore[override]
self,
rng_key: ArrayImpl,
samples: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult:
"""
Weight each SLP proportional to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
Expand All @@ -202,7 +241,12 @@ def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwarg
Luca Martino, Victor Elvira, David Luengo, and Jukka Corander.
"""

def log_weight(rng_key, i, slp_model, slp_samples):
def log_weight(
rng_key: ArrayImpl,
i: int,
slp_model: Callable,
slp_samples: dict[str, Any],
) -> float:
trace = {k: v[i] for k, v in slp_samples.items()}
guide = AutoNormal(
slp_model,
Expand All @@ -215,7 +259,9 @@ def log_weight(rng_key, i, slp_model, slp_samples):
model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
return model_log_density - guide_log_density

log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))
log_weights: Callable[
[ArrayImpl, ArrayImpl, Callable, dict[str, Any]], ArrayImpl
] = jax.vmap(log_weight, in_axes=(None, 0, None, None))

log_Zs = {}
for bt, slp_samples in samples.items():
Expand Down
44 changes: 31 additions & 13 deletions numpyro/contrib/stochastic_support/sdvi.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType

from jaxlib.xla_extension import ArrayImpl

import jax
import jax.numpy as jnp

from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference
from numpyro.contrib.stochastic_support.dcc import (
RunInferenceResult,
SDVIResult,
StochasticSupportInference,
)
from numpyro.handlers import condition
from numpyro.infer import (
ELBO,
SVI,
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.autoguide import AutoNormal

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])
from numpyro.infer.autoguide import AutoGuide, AutoNormal

VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO)

Expand Down Expand Up @@ -69,13 +74,13 @@ def model():

def __init__(
self,
model,
model: Callable,
optimizer,
svi_num_steps=1000,
combine_elbo_particles=1000,
guide_init=AutoNormal,
loss=Trace_ELBO(),
svi_progress_bar=False,
svi_num_steps: int = 1000,
combine_elbo_particles: int = 1000,
guide_init: Callable = AutoNormal,
loss: ELBO = Trace_ELBO(),
svi_progress_bar: bool = False,
num_slp_samples=1000,
max_slps=124,
):
Expand All @@ -92,7 +97,13 @@ def __init__(

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
"""
Run SVI on a given SLP defined by its branching trace.
"""
Expand All @@ -108,7 +119,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
)
return guide, svi_result.params

def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs):
def _combine_inferences( # type: ignore[override]
self,
rng_key: ArrayImpl,
guides: dict[str, tuple[AutoGuide, dict[str, Any]]],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> SDVIResult:
"""Weight each SLP proportional to its estimated ELBO."""
elbos = {}
for bt, (guide, param_map) in guides.items():
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,6 @@ module = [
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
"numpyro.contrib.stochastic_support.*",
]
ignore_errors = false

0 comments on commit f609a75

Please sign in to comment.