From f609a75181912b1ab025e5f267318b4b1cfda67a Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Wed, 13 Nov 2024 23:18:29 +0100 Subject: [PATCH] init --- numpyro/contrib/stochastic_support/dcc.py | 90 ++++++++++++++++------ numpyro/contrib/stochastic_support/sdvi.py | 44 +++++++---- pyproject.toml | 1 + 3 files changed, 100 insertions(+), 35 deletions(-) diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index 13a4d5ce6..5ed7c9338 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -3,6 +3,9 @@ from abc import ABC, abstractmethod from collections import OrderedDict, namedtuple +from typing import Any, Callable, OrderedDict as OrderedDictType + +from jaxlib.xla_extension import ArrayImpl import jax from jax import random @@ -11,11 +14,17 @@ import numpyro.distributions as dist from numpyro.handlers import condition, seed, trace from numpyro.infer import MCMC, NUTS -from numpyro.infer.autoguide import AutoNormal +from numpyro.infer.autoguide import AutoGuide, AutoNormal from numpyro.infer.util import init_to_value, log_density DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"]) +SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"]) + +RunInferenceResult = ( + dict[str, Any] | tuple[AutoGuide, dict[str, Any]] +) # for mcmc or sdvi + class StochasticSupportInference(ABC): """ @@ -46,12 +55,14 @@ def model(): on more than `max_slps`. """ - def __init__(self, model, num_slp_samples, max_slps): - self.model = model - self.num_slp_samples = num_slp_samples - self.max_slps = max_slps + def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None: + self.model: Callable = model + self.num_slp_samples: int = num_slp_samples + self.max_slps: int = max_slps - def _find_slps(self, rng_key, *args, **kwargs): + def _find_slps( + self, rng_key: ArrayImpl, *args: Any, **kwargs: Any + ) -> dict[str, OrderedDictType]: """ Discover the straight-line programs (SLPs) in the model by sampling from the prior. This implementation assumes that all branching is done via discrete sampling sites @@ -70,13 +81,17 @@ def _find_slps(self, rng_key, *args, **kwargs): return branching_traces - def _get_branching_trace(self, tr): + def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType: """ Extract the sites from the trace that are annotated with `infer={"branching": True}`. """ branching_trace = OrderedDict() for site in tr.values(): - if site["type"] == "sample" and site["infer"].get("branching", False): + if ( + site["type"] == "sample" + and site["infer"].get("branching", False) + and site["fn"].support is not None + ): if ( not isinstance(site["fn"], dist.Distribution) or not site["fn"].support.is_discrete @@ -92,13 +107,24 @@ def _get_branching_trace(self, tr): return branching_trace @abstractmethod - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayImpl, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: raise NotImplementedError @abstractmethod def _combine_inferences( - self, rng_key, inferences, branching_traces, *args, **kwargs - ): + self, + rng_key: ArrayImpl, + inferences: dict[str, Any], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> SDVIResult: raise NotImplementedError def run(self, rng_key, *args, **kwargs): @@ -165,13 +191,13 @@ def model(): def __init__( self, - model, - mcmc_kwargs, - kernel_cls=NUTS, - num_slp_samples=1000, - max_slps=124, - proposal_scale=1.0, - ): + model: Callable, + mcmc_kwargs: dict[str, Any], + kernel_cls: Callable = NUTS, + num_slp_samples: int = 1000, + max_slps: int = 124, + proposal_scale: float = 1.0, + ) -> None: self.kernel_cls = kernel_cls self.mcmc_kwargs = mcmc_kwargs @@ -179,7 +205,13 @@ def __init__( super().__init__(model, num_slp_samples, max_slps) - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayImpl, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: """ Run MCMC on the model conditioned on the given branching trace. """ @@ -190,7 +222,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs): return mcmc.get_samples() - def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs): + def _combine_inferences( # type: ignore[override] + self, + rng_key: ArrayImpl, + samples: dict[str, Any], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> DCCResult: """ Weight each SLP proportional to its estimated normalization constant. The normalization constants are estimated using importance sampling with @@ -202,7 +241,12 @@ def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwarg Luca Martino, Victor Elvira, David Luengo, and Jukka Corander. """ - def log_weight(rng_key, i, slp_model, slp_samples): + def log_weight( + rng_key: ArrayImpl, + i: int, + slp_model: Callable, + slp_samples: dict[str, Any], + ) -> float: trace = {k: v[i] for k, v in slp_samples.items()} guide = AutoNormal( slp_model, @@ -215,7 +259,9 @@ def log_weight(rng_key, i, slp_model, slp_samples): model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace) return model_log_density - guide_log_density - log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None)) + log_weights: Callable[ + [ArrayImpl, ArrayImpl, Callable, dict[str, Any]], ArrayImpl + ] = jax.vmap(log_weight, in_axes=(None, 0, None, None)) log_Zs = {} for bt, slp_samples in samples.items(): diff --git a/numpyro/contrib/stochastic_support/sdvi.py b/numpyro/contrib/stochastic_support/sdvi.py index 82c0f2d01..2a390964e 100644 --- a/numpyro/contrib/stochastic_support/sdvi.py +++ b/numpyro/contrib/stochastic_support/sdvi.py @@ -1,23 +1,28 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple +from typing import Any, Callable, OrderedDict as OrderedDictType + +from jaxlib.xla_extension import ArrayImpl import jax import jax.numpy as jnp -from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference +from numpyro.contrib.stochastic_support.dcc import ( + RunInferenceResult, + SDVIResult, + StochasticSupportInference, +) from numpyro.handlers import condition from numpyro.infer import ( + ELBO, SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, ) -from numpyro.infer.autoguide import AutoNormal - -SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"]) +from numpyro.infer.autoguide import AutoGuide, AutoNormal VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) @@ -69,13 +74,13 @@ def model(): def __init__( self, - model, + model: Callable, optimizer, - svi_num_steps=1000, - combine_elbo_particles=1000, - guide_init=AutoNormal, - loss=Trace_ELBO(), - svi_progress_bar=False, + svi_num_steps: int = 1000, + combine_elbo_particles: int = 1000, + guide_init: Callable = AutoNormal, + loss: ELBO = Trace_ELBO(), + svi_progress_bar: bool = False, num_slp_samples=1000, max_slps=124, ): @@ -92,7 +97,13 @@ def __init__( super().__init__(model, num_slp_samples, max_slps) - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayImpl, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: """ Run SVI on a given SLP defined by its branching trace. """ @@ -108,7 +119,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs): ) return guide, svi_result.params - def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs): + def _combine_inferences( # type: ignore[override] + self, + rng_key: ArrayImpl, + guides: dict[str, tuple[AutoGuide, dict[str, Any]]], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> SDVIResult: """Weight each SLP proportional to its estimated ELBO.""" elbos = {} for bt, (guide, param_map) in guides.items(): diff --git a/pyproject.toml b/pyproject.toml index 847758329..e6859c122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,5 +112,6 @@ module = [ "numpyro.contrib.control_flow.*", # types missing "numpyro.contrib.funsor.*", # types missing "numpyro.contrib.hsgp.*", + "numpyro.contrib.stochastic_support.*", ] ignore_errors = false