
Add types for contrib.stochastic_support and improve the HSGP module #1907

Merged · 8 commits · Nov 15, 2024
Changes from 3 commits
97 changes: 73 additions & 24 deletions numpyro/contrib/stochastic_support/dcc.py
@@ -3,6 +3,9 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType

from jaxlib.xla_extension import ArrayImpl
@tillahoffmann (Contributor) commented on Nov 14, 2024:

Do you know what the implications of using ArrayImpl vs jax.numpy.ndarray are?

Contributor Author replied:

This is a great question, thanks for the feedback! We started adding type hints with ArrayImpl in #1819 (comment).

On the other hand, https://jax.readthedocs.io/en/latest/jax.typing.html suggests using jax.typing.ArrayLike. However, when I annotate x: ArrayLike and then access x.ndim inside the function, mypy complains with

error: Item "builtins.bool" of "Array | ndarray[Any, Any] | numpy.bool | number[Any] | builtins.bool | int | float | complex" has no attribute "ndim"  [union-attr]

That is why I am sticking with ArrayImpl for now and waiting for feedback :)

@fehiepsi (Member) commented on Nov 14, 2024:

I would like to avoid importing stuff from jaxlib. I guess it is better to use jax.Array (or jnp.ndarray)

Member commented:

Re ndim: you can use jnp.ndim(x), which will accept ArrayLike instances, including Python scalars.

Contributor Author replied:

Great input! I'll try that then 🙂!
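As a reference for the resolution above, here is a minimal sketch of the two suggestions from this thread — jax.typing.ArrayLike combined with jnp.ndim, and the public jax.Array annotation; the function names are illustrative only, not part of the PR:

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def num_dims(x: ArrayLike) -> int:
    # jnp.ndim accepts any ArrayLike (arrays, NumPy values, Python scalars),
    # so mypy does not raise the union-attr error that x.ndim triggers.
    return jnp.ndim(x)


def scale(x: jax.Array, factor: float) -> jax.Array:
    # Annotating with the public jax.Array type avoids importing ArrayImpl
    # from jaxlib, as suggested above.
    return factor * x
```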


import jax
from jax import random
@@ -11,11 +14,18 @@
import numpyro.distributions as dist
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.autoguide import AutoGuide, AutoNormal
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

RunInferenceResult = (
dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
) # for mcmc or sdvi


class StochasticSupportInference(ABC):
"""
@@ -46,12 +56,14 @@ def model():
on more than `max_slps`.
"""

def __init__(self, model, num_slp_samples, max_slps):
self.model = model
self.num_slp_samples = num_slp_samples
self.max_slps = max_slps
def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None:
self.model: Callable = model
self.num_slp_samples: int = num_slp_samples
self.max_slps: int = max_slps

def _find_slps(self, rng_key, *args, **kwargs):
def _find_slps(
self, rng_key: ArrayImpl, *args: Any, **kwargs: Any
) -> dict[str, OrderedDictType]:
"""
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
This implementation assumes that all branching is done via discrete sampling sites
@@ -70,13 +82,17 @@ def _find_slps(self, rng_key, *args, **kwargs):

return branching_traces

def _get_branching_trace(self, tr):
def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
"""
Extract the sites from the trace that are annotated with `infer={"branching": True}`.
"""
branching_trace = OrderedDict()
for site in tr.values():
if site["type"] == "sample" and site["infer"].get("branching", False):
if (
site["type"] == "sample"
and site["infer"].get("branching", False)
and site["fn"].support is not None
):
if (
not isinstance(site["fn"], dist.Distribution)
or not site["fn"].support.is_discrete
@@ -92,16 +108,29 @@ def _get_branching_trace(self, tr):
return branching_trace

@abstractmethod
def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
raise NotImplementedError

@abstractmethod
def _combine_inferences(
self, rng_key, inferences, branching_traces, *args, **kwargs
):
self,
rng_key: ArrayImpl,
inferences: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult | SDVIResult:
raise NotImplementedError

def run(self, rng_key, *args, **kwargs):
def run(
self, rng_key: ArrayImpl, *args: Any, **kwargs: Any
) -> DCCResult | SDVIResult:
"""
Run inference on each SLP separately and combine the results.

@@ -165,32 +194,45 @@ def model():

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
proposal_scale=1.0,
):
model: Callable,
mcmc_kwargs: dict[str, Any],
kernel_cls: type[MCMCKernel] = NUTS,
num_slp_samples: int = 1_000,
max_slps: int = 124,
proposal_scale: float = 1.0,
) -> None:
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.proposal_scale = proposal_scale

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
"""
Run MCMC on the model conditioned on the given branching trace.
"""
slp_model = condition(self.model, data=branching_trace)
kernel = self.kernel_cls(slp_model)
kernel = self.kernel_cls(slp_model) # type: ignore[call-arg]
mcmc = MCMC(kernel, **self.mcmc_kwargs)
mcmc.run(rng_key, *args, **kwargs)

return mcmc.get_samples()

def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs):
def _combine_inferences( # type: ignore[override]
self,
rng_key: ArrayImpl,
samples: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult:
"""
Weight each SLP proportional to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
@@ -202,7 +244,12 @@ def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwarg
Luca Martino, Victor Elvira, David Luengo, and Jukka Corander.
"""

def log_weight(rng_key, i, slp_model, slp_samples):
def log_weight(
rng_key: ArrayImpl,
i: int,
slp_model: Callable,
slp_samples: dict[str, Any],
) -> float:
trace = {k: v[i] for k, v in slp_samples.items()}
guide = AutoNormal(
slp_model,
@@ -215,7 +262,9 @@ def log_weight(rng_key, i, slp_model, slp_samples):
model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
return model_log_density - guide_log_density

log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))
log_weights: Callable[
[ArrayImpl, ArrayImpl, Callable, dict[str, Any]], ArrayImpl
] = jax.vmap(log_weight, in_axes=(None, 0, None, None))

log_Zs = {}
for bt, slp_samples in samples.items():
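As an aside, a minimal sketch of how the annotated API above might be exercised, assuming the DCC class exported from numpyro.contrib.stochastic_support.dcc; the model, observed value, and MCMC settings are illustrative only and not part of this PR:

```python
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC


def model():
    # A discrete site annotated with infer={"branching": True} defines the
    # straight-line programs (SLPs) that DCC discovers from the prior.
    branch = numpyro.sample("branch", dist.Bernoulli(0.5), infer={"branching": True})
    mean = 0.0 if branch == 0 else 5.0
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=1.0)


dcc = DCC(model, mcmc_kwargs={"num_warmup": 200, "num_samples": 500})
result = dcc.run(random.PRNGKey(0))
# result is a DCCResult namedtuple: per-SLP posterior samples and SLP weights.
print(result.slp_weights)
```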
48 changes: 33 additions & 15 deletions numpyro/contrib/stochastic_support/sdvi.py
@@ -1,23 +1,28 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType

from jaxlib.xla_extension import ArrayImpl

import jax
import jax.numpy as jnp

from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference
from numpyro.contrib.stochastic_support.dcc import (
RunInferenceResult,
SDVIResult,
StochasticSupportInference,
)
from numpyro.handlers import condition
from numpyro.infer import (
ELBO,
SVI,
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.autoguide import AutoNormal

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])
from numpyro.infer.autoguide import AutoGuide, AutoNormal

VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO)

@@ -69,15 +74,15 @@ def model():

def __init__(
self,
model,
model: Callable,
optimizer,
svi_num_steps=1000,
combine_elbo_particles=1000,
guide_init=AutoNormal,
loss=Trace_ELBO(),
svi_progress_bar=False,
num_slp_samples=1000,
max_slps=124,
svi_num_steps: int = 1_000,
combine_elbo_particles: int = 1_000,
guide_init: Callable = AutoNormal,
loss: ELBO = Trace_ELBO(),
svi_progress_bar: bool = False,
num_slp_samples: int = 1_000,
max_slps: int = 124,
):
self.guide_init = guide_init
self.optimizer = optimizer
@@ -92,7 +97,13 @@ def __init__(

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
def _run_inference(
self,
rng_key: ArrayImpl,
branching_trace: OrderedDictType,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
"""
Run SVI on a given SLP defined by its branching trace.
"""
@@ -108,7 +119,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
)
return guide, svi_result.params

def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs):
def _combine_inferences( # type: ignore[override]
self,
rng_key: ArrayImpl,
guides: dict[str, tuple[AutoGuide, dict[str, Any]]],
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> SDVIResult:
"""Weight each SLP proportional to its estimated ELBO."""
elbos = {}
for bt, (guide, param_map) in guides.items():
1 change: 1 addition & 0 deletions pyproject.toml
@@ -112,5 +112,6 @@ module = [
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
"numpyro.contrib.stochastic_support.*",
]
ignore_errors = false