From 66921bdde3902dc0acecbc04db4d7c1129da8fcd Mon Sep 17 00:00:00 2001
From: Juan Orduz
Date: Fri, 15 Nov 2024 14:27:09 +0100
Subject: [PATCH] Add types for contrib.stochastic_support and improve the HSGP module (#1907)

* init

* better return hints

* type improvements

* better hints for jax arrays

* initialize ell_

* undo change

* fix condition

* OMG stupid typo #facepalm!
---
 numpyro/contrib/hsgp/approximation.py      | 50 ++++++------
 numpyro/contrib/hsgp/laplacian.py          | 65 +++++++--------
 numpyro/contrib/hsgp/spectral_densities.py | 29 ++++---
 numpyro/contrib/hsgp/util.py               | 10 ---
 numpyro/contrib/stochastic_support/dcc.py  | 94 ++++++++++++++++------
 numpyro/contrib/stochastic_support/sdvi.py | 49 +++++++----
 pyproject.toml                             |  1 +
 test/contrib/hsgp/test_approximation.py    | 32 ++++----
 test/contrib/hsgp/test_laplacian.py        | 12 ++-
 9 files changed, 200 insertions(+), 142 deletions(-)
 delete mode 100644 numpyro/contrib/hsgp/util.py

diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py
index 955c6181a..11488ad10 100644
--- a/numpyro/contrib/hsgp/approximation.py
+++ b/numpyro/contrib/hsgp/approximation.py
@@ -7,9 +7,9 @@
 
 from __future__ import annotations
 
-from jaxlib.xla_extension import ArrayImpl
-
+from jax import Array
 import jax.numpy as jnp
+from jax.typing import ArrayLike
 
 import numpyro
 from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
@@ -22,8 +22,8 @@
 
 
 def _non_centered_approximation(
-    phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
-) -> ArrayImpl:
+    phi: ArrayLike, spd: ArrayLike, m: int | list[int]
+) -> Array:
     with numpyro.plate("basis", m):
         beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))
 
@@ -31,8 +31,8 @@ def _non_centered_approximation(
 
 
 def _centered_approximation(
-    phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
-) -> ArrayImpl:
+    phi: ArrayLike, spd: ArrayLike, m: int | list[int]
+) -> Array:
     with numpyro.plate("basis", m):
         beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))
 
@@ -40,8 +40,8 @@ def _centered_approximation(
 
 
 def linear_approximation(
-    phi: ArrayImpl, spd: ArrayImpl, m: int | list[int], non_centered: bool = True
-) -> ArrayImpl:
+    phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
+) -> Array:
     """
     Linear approximation formula of the Hilbert space Gaussian process.
 
@@ -52,13 +52,13 @@ def linear_approximation(
     1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
-    :param ArrayImpl phi: laplacian eigenfunctions
-    :param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at square
+    :param ArrayLike phi: laplacian eigenfunctions
+    :param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
        root of the first `m` eigenvalues.
    :param int | list[int] m: number of eigenfunctions in the approximation
    :param bool non_centered: whether to use a non-centered parameterization
    :return: The low-rank approximation linear model
-    :rtype: ArrayImpl
+    :rtype: Array
    """
    if non_centered:
        return _non_centered_approximation(phi, spd, m)
@@ -66,13 +66,13 @@ def linear_approximation(
 
 
 def hsgp_squared_exponential(
-    x: ArrayImpl,
+    x: ArrayLike,
     alpha: float,
     length: float,
     ell: float | int | list[float | int],
     m: int | list[int],
     non_centered: bool = True,
-) -> ArrayImpl:
+) -> Array:
     """
     Hilbert space Gaussian process approximation using the squared exponential kernel.
 
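For orientation, here is a minimal sketch (not part of the patch) of how `hsgp_squared_exponential` is typically called inside a NumPyro model, which is also where the new `ArrayLike`-in / `Array`-out annotations play out at runtime. The priors, observation model, and the hyperparameters `ell=1.5` and `m=20` are illustrative assumptions, not values taken from this diff:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential


def model(x, y=None):
    # kernel amplitude and length scale; these priors are arbitrary choices
    alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0))
    length = numpyro.sample("length", dist.LogNormal(0.0, 1.0))
    # x may be any ArrayLike (numpy or jax); the returned f is a jax Array
    f = hsgp_squared_exponential(x=x, alpha=alpha, length=length, ell=1.5, m=20)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
```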
@@ -88,7 +88,7 @@ def hsgp_squared_exponential(
     2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
-    :param ArrayImpl x: input data
+    :param ArrayLike x: input data
     :param float alpha: amplitude of the squared exponential kernel
     :param float length: length scale of the squared exponential kernel
     :param float | int | list[float | int] ell: positive value that parametrizes the length of the D-dimensional box so
@@ -99,9 +99,9 @@ def hsgp_squared_exponential(
         If an integer, the same number of eigenvalues is computed in each dimension.
     :param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True
     :return: the low-rank approximation linear model
-    :rtype: ArrayImpl
+    :rtype: Array
     """
-    dim = x.shape[-1] if x.ndim > 1 else 1
+    dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
     phi = eigenfunctions(x=x, ell=ell, m=m)
     spd = jnp.sqrt(
         diag_spectral_density_squared_exponential(
@@ -114,14 +114,14 @@ def hsgp_squared_exponential(
 
 
 def hsgp_matern(
-    x: ArrayImpl,
+    x: ArrayLike,
     nu: float,
     alpha: float,
     length: float,
     ell: float | int | list[float | int],
     m: int | list[int],
     non_centered: bool = True,
-):
+) -> Array:
     """
     Hilbert space Gaussian process approximation using the Matérn kernel.
 
@@ -137,7 +137,7 @@ def hsgp_matern(
     2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
-    :param ArrayImpl x: input data
+    :param ArrayLike x: input data
     :param float nu: smoothness parameter
     :param float alpha: amplitude of the squared exponential kernel
     :param float length: length scale of the squared exponential kernel
@@ -149,9 +149,9 @@ def hsgp_matern(
         If an integer, the same number of eigenvalues is computed in each dimension.
     :param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True.
     :return: the low-rank approximation linear model
-    :rtype: ArrayImpl
+    :rtype: Array
     """
-    dim = x.shape[-1] if x.ndim > 1 else 1
+    dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
     phi = eigenfunctions(x=x, ell=ell, m=m)
     spd = jnp.sqrt(
         diag_spectral_density_matern(
@@ -164,8 +164,8 @@ def hsgp_matern(
 
 
 def hsgp_periodic_non_centered(
-    x: ArrayImpl, alpha: float, length: float, w0: float, m: int
-) -> ArrayImpl:
+    x: ArrayLike, alpha: float, length: float, w0: float, m: int
+) -> Array:
     """
     Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.
 
@@ -176,13 +176,13 @@ def hsgp_periodic_non_centered(
     1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
-    :param ArrayImpl x: input data
+    :param ArrayLike x: input data
     :param float alpha: amplitude
     :param float length: length scale
     :param float w0: frequency of the periodic kernel
     :param int m: number of eigenvalues to compute and include in the approximation
     :return: the low-rank approximation linear model
-    :rtype: ArrayImpl
+    :rtype: Array
     """
     q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
     cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m)
diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py
index 16c1d6c31..ae0cbbd6b 100644
--- a/numpyro/contrib/hsgp/laplacian.py
+++ b/numpyro/contrib/hsgp/laplacian.py
@@ -7,16 +7,14 @@
 
 from __future__ import annotations
 
-from typing import get_args
-
-from jaxlib.xla_extension import ArrayImpl
+import numpy as np
 
+from jax import Array
 import jax.numpy as jnp
-
-from numpyro.contrib.hsgp.util import ARRAY_TYPE
+from jax.typing import ArrayLike
 
 
-def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
+def eigenindices(m: list[int] | int, dim: int) -> Array:
     """Returns the indices of the first :math:`D \\times m^\\star` eigenvalues of the laplacian operator.
 
     .. math::
@@ -35,7 +33,7 @@ def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
     :param int dim: The dimension of the space.
 
     :returns: An array of the indices of the first :math:`D \\times m^\\star` eigenvalues.
-    :rtype: ArrayImpl
+    :rtype: Array
 
     **Examples:**
 
@@ -78,8 +76,8 @@ def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
 
 
 def sqrt_eigenvalues(
-    ell: int | float | list[int | float], m: list[int] | int, dim: int
-) -> ArrayImpl:
+    ell: ArrayLike | list[int | float], m: list[int] | int, dim: int
+) -> Array:
     """
     The first :math:`m^\\star \\times D` square root of eigenvalues of the laplacian operator in
     :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`. See Eq. (56) in [1].
@@ -96,16 +94,14 @@ def sqrt_eigenvalues(
     :param int dim: The dimension of the space.
 
     :returns: An array of the first :math:`m^\\star \\times D` square root of eigenvalues.
-    :rtype: ArrayImpl
+    :rtype: Array
     """
     ell_ = _convert_ell(ell, dim)
     S = eigenindices(m, dim)
     return S * jnp.pi / 2 / ell_  # dim x prod(m) array of eigenvalues
 
 
-def eigenfunctions(
-    x: ArrayImpl, ell: float | list[float], m: int | list[int]
-) -> ArrayImpl:
+def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -> Array:
     """
     The first :math:`m^\\star` eigenfunctions of the laplacian operator in
     :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`
@@ -141,7 +137,7 @@ def eigenfunctions(
     1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process
        regression. Stat Comput 30, 419-446 (2020)
 
-    :param ArrayImpl x: The points at which to evaluate the eigenfunctions.
+    :param ArrayLike x: The points at which to evaluate the eigenfunctions.
         If `x` is 1D the problem is assumed unidimensional.
         Otherwise, the dimension of the input space is inferred as the last dimension of `x`.
         Other dimensions are treated as batch dimensions.
@@ -150,25 +146,27 @@ def eigenfunctions(
     :param float | list[float] ell: The length of the interval in each dimension divided by 2.
         If a float, the same length is used in each dimension.
     :param int | list[int] m: The number of eigenvalues to compute in each dimension.
         If an integer, the same number of eigenvalues is computed in each dimension.
     :returns: An array of the first :math:`m^\\star \\times D` eigenfunctions evaluated at `x`.
-    :rtype: ArrayImpl
+    :rtype: Array
     """
-    if x.ndim == 1:
-        x_ = x[..., None]
+    if jnp.ndim(x) == 1:
+        x_ = jnp.expand_dims(x, axis=-1)
     else:
-        x_ = x
-    dim = x_.shape[-1]  # others assumed batch dims
-    n_batch_dims = x_.ndim - 1
+        x_ = jnp.array(x)
+    dim = jnp.shape(x_)[-1]  # others assumed batch dims
+    n_batch_dims = jnp.ndim(x_) - 1
     ell_ = _convert_ell(ell, dim)
     a = jnp.expand_dims(ell_, tuple(range(n_batch_dims)))
     b = jnp.expand_dims(sqrt_eigenvalues(ell_, m, dim), tuple(range(n_batch_dims)))
-    return jnp.prod(jnp.sqrt(1 / a) * jnp.sin(b * (x_[..., None] + a)), axis=-2)
+    return jnp.prod(
+        jnp.sqrt(1 / a) * jnp.sin(b * (jnp.expand_dims(x_, axis=-1) + a)), axis=-2
+    )
 
 
-def eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int):
+def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Array]:
     """
     Basis functions for the approximation of the periodic kernel.
 
-    :param ArrayImpl x: The points at which to evaluate the eigenfunctions.
+    :param ArrayLike x: The points at which to evaluate the eigenfunctions.
     :param float w0: The frequency of the periodic kernel.
     :param int m: The number of eigenfunctions to compute.
 
@@ -178,11 +176,11 @@ def eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int):
     .. warning::
         Multidimensional inputs are not supported.
     """
-    if x.ndim > 1:
+    if jnp.ndim(x) > 1:
         raise ValueError(
             "Multidimensional inputs are not supported by the periodic kernel."
         )
-    m1 = jnp.tile(w0 * x[:, None], m)
+    m1 = jnp.tile(w0 * jnp.expand_dims(x, axis=-1), m)
     m2 = jnp.diag(jnp.arange(m, dtype=jnp.float32))
     mw0x = m1 @ m2
     cosines = jnp.cos(mw0x)
@@ -190,31 +188,30 @@ def eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int):
     return cosines, sines
 
 
-def _convert_ell(
-    ell: float | int | list[float | int] | ArrayImpl, dim: int
-) -> ArrayImpl:
+def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> Array:
     """
     Process the half-length of the approximation interval and return a `D \\times 1` array.
 
     If `ell` is a scalar, it is converted to a list of length dim, then transformed into an Array.
 
-    :param float | int | list[float | int] | ArrayImpl ell: The length of the interval in each dimension divided by 2.
+    :param float | int | list[float | int] | ArrayLike ell: The length of the interval in each dimension divided by 2.
         If a float or int, the same length is used in each dimension.
     :param int dim: The dimension of the space.
 
     :returns: A `D \\times 1` array of the half-lengths of the approximation interval.
-    :rtype: ArrayImpl
+    :rtype: Array
     """
+    ell_ = jnp.empty((dim, 1))
     if isinstance(ell, float) | isinstance(ell, int):
-        ell = [ell] * dim
+        ell = jnp.array([ell] * dim)[..., None]
     if isinstance(ell, list):
         if len(ell) != dim:
             raise ValueError(
                 "The length of ell must be equal to the dimension of the space."
             )
         ell_ = jnp.array(ell)[..., None]  # dim x 1 array
-    elif isinstance(ell, get_args(ARRAY_TYPE)):
-        ell_ = ell
-    if ell_.shape != (dim, 1):
+    elif isinstance(ell, Array) | isinstance(ell, np.ndarray):
+        ell_ = jnp.array(ell)
+    if jnp.shape(ell_) != (dim, 1):
         raise ValueError("ell must be a scalar or a list of length `dim`.")
     return ell_
diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py
index 118ff2401..383a57a61 100644
--- a/numpyro/contrib/hsgp/spectral_densities.py
+++ b/numpyro/contrib/hsgp/spectral_densities.py
@@ -7,11 +7,10 @@
 
 from __future__ import annotations
 
-from jaxlib.xla_extension import ArrayImpl
-
-from jax import vmap
+from jax import Array, vmap
 import jax.numpy as jnp
 from jax.scipy import special
+from jax.typing import ArrayLike
 
 from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues
 
@@ -21,8 +20,8 @@ def align_param(dim, param):
 
 
 def spectral_density_squared_exponential(
-    dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
-) -> ArrayImpl:
+    dim: int, w: ArrayLike, alpha: float, length: float | ArrayLike
+) -> Array:
     """
     Spectral density of the squared exponential kernel.
 
@@ -42,11 +41,11 @@ def spectral_density_squared_exponential(
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
     :param int dim: dimension
-    :param ArrayImpl w: frequency
+    :param ArrayLike w: frequency
     :param float alpha: amplitude
     :param float length: length scale
     :return: spectral density value
-    :rtype: ArrayImpl
+    :rtype: Array
     """
     length = align_param(dim, length)
     c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
@@ -55,7 +54,7 @@ def spectral_density_squared_exponential(
 
 
 def spectral_density_matern(
-    dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl
+    dim: int, nu: float, w: ArrayLike, alpha: float, length: float | ArrayLike
 ) -> float:
     """
     Spectral density of the Matérn kernel.
 
@@ -78,7 +77,7 @@ def spectral_density_matern(
 
     :param int dim: dimension
     :param float nu: smoothness
-    :param ArrayImpl w: frequency
+    :param ArrayLike w: frequency
     :param float alpha: amplitude
     :param float length: length scale
     :return: spectral density value
@@ -104,7 +103,7 @@ def diag_spectral_density_squared_exponential(
     ell: float | int | list[float | int],
     m: int | list[int],
     dim: int,
-) -> ArrayImpl:
+) -> Array:
     """
     Evaluates the spectral density of the squared exponential kernel at the first :math:`D \\times m^\\star`
     square root eigenvalues of the laplacian operator in :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`.
@@ -118,7 +117,7 @@ def diag_spectral_density_squared_exponential(
     :param int dim: The dimension of the space
 
     :return: spectral density vector evaluated at the first :math:`D \\times m^\\star` square root eigenvalues
-    :rtype: ArrayImpl
+    :rtype: Array
     """
 
     def _spectral_density(w):
@@ -138,7 +137,7 @@ def diag_spectral_density_matern(
     ell: float | int | list[float | int],
     m: int | list[int],
     dim: int,
-) -> ArrayImpl:
+) -> Array:
     """
     Evaluates the spectral density of the Matérn kernel at the first :math:`D \\times m^\\star`
     square root eigenvalues of the laplacian operator in :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`.
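As a small sketch of what these annotated functions produce at runtime (the values are arbitrary, and the shape comment is an assumption derived from the `m` and `dim` arguments, not something asserted by the patch):

```python
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)

# spectral density of the squared exponential kernel, evaluated at the
# square-root eigenvalues of the laplacian on [-1.5, 1.5] (dim=1, m=10)
spd = diag_spectral_density_squared_exponential(
    alpha=1.0, length=0.5, ell=1.5, m=10, dim=1
)
print(spd.shape)  # expected: (10,), one value per eigenfunction
```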
@@ -153,7 +152,7 @@ def diag_spectral_density_matern(
     :param int dim: The dimension of the space
 
     :return: spectral density vector evaluated at the first :math:`D \\times m^\\star` square root eigenvalues
-    :rtype: ArrayImpl
+    :rtype: Array
     """
 
     def _spectral_density(w):
@@ -176,7 +175,7 @@ def modified_bessel_first_kind(v, z):
     return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)
 
 
-def diag_spectral_density_periodic(alpha: float, length: float, m: int) -> ArrayImpl:
+def diag_spectral_density_periodic(alpha: float, length: float, m: int) -> Array:
     """
     Not actually a spectral density but these are used in the same
     way. These are simply the first `m` coefficients of the low rank
@@ -191,7 +190,7 @@ def diag_spectral_density_periodic(alpha: float, length: float, m: int) -> Array
     :param float length: length scale
     :param int m: number of eigenvalues
     :return: "spectral density" vector
-    :rtype: ArrayImpl
+    :rtype: Array
     """
     a = length ** (-2)
     j = jnp.arange(0, m)
diff --git a/numpyro/contrib/hsgp/util.py b/numpyro/contrib/hsgp/util.py
deleted file mode 100644
index 5afbcfb83..000000000
--- a/numpyro/contrib/hsgp/util.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright Contributors to the Pyro project.
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Union
-
-import numpy as np
-
-import jax
-
-ARRAY_TYPE = Union[jax.Array, np.ndarray]  # jax.Array covers tracers
diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py
index 13a4d5ce6..912ed46f2 100644
--- a/numpyro/contrib/stochastic_support/dcc.py
+++ b/numpyro/contrib/stochastic_support/dcc.py
@@ -3,19 +3,28 @@
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict, namedtuple
+from typing import Any, Callable, OrderedDict as OrderedDictType
 
 import jax
 from jax import random
 import jax.numpy as jnp
+from jax.typing import ArrayLike
 
 import numpyro.distributions as dist
 from numpyro.handlers import condition, seed, trace
 from numpyro.infer import MCMC, NUTS
-from numpyro.infer.autoguide import AutoNormal
+from numpyro.infer.autoguide import AutoGuide, AutoNormal
+from numpyro.infer.mcmc import MCMCKernel
 from numpyro.infer.util import init_to_value, log_density
 
 DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])
 
+SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])
+
+RunInferenceResult = (
+    dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
+)  # for mcmc or sdvi
+
 
 class StochasticSupportInference(ABC):
     """
@@ -46,12 +55,14 @@ def model():
         on more than `max_slps`.
     """
 
-    def __init__(self, model, num_slp_samples, max_slps):
-        self.model = model
-        self.num_slp_samples = num_slp_samples
-        self.max_slps = max_slps
+    def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None:
+        self.model: Callable = model
+        self.num_slp_samples: int = num_slp_samples
+        self.max_slps: int = max_slps
 
-    def _find_slps(self, rng_key, *args, **kwargs):
+    def _find_slps(
+        self, rng_key: ArrayLike, *args: Any, **kwargs: Any
+    ) -> dict[str, OrderedDictType]:
         """
         Discover the straight-line programs (SLPs) in the model by sampling from the prior.
         This implementation assumes that all branching is done via discrete sampling sites
@@ -70,13 +81,17 @@ def model():
 
         return branching_traces
 
-    def _get_branching_trace(self, tr):
+    def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
         """
         Extract the sites from the trace that are annotated with `infer={"branching": True}`.
""" branching_trace = OrderedDict() for site in tr.values(): - if site["type"] == "sample" and site["infer"].get("branching", False): + if ( + site["type"] == "sample" + and site["infer"].get("branching", False) + and site["fn"].support is not None + ): if ( not isinstance(site["fn"], dist.Distribution) or not site["fn"].support.is_discrete @@ -92,16 +107,29 @@ def _get_branching_trace(self, tr): return branching_trace @abstractmethod - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayLike, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: raise NotImplementedError @abstractmethod def _combine_inferences( - self, rng_key, inferences, branching_traces, *args, **kwargs - ): + self, + rng_key: ArrayLike, + inferences: dict[str, Any], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> DCCResult | SDVIResult: raise NotImplementedError - def run(self, rng_key, *args, **kwargs): + def run( + self, rng_key: ArrayLike, *args: Any, **kwargs: Any + ) -> DCCResult | SDVIResult: """ Run inference on each SLP separately and combine the results. @@ -165,13 +193,13 @@ def model(): def __init__( self, - model, - mcmc_kwargs, - kernel_cls=NUTS, - num_slp_samples=1000, - max_slps=124, - proposal_scale=1.0, - ): + model: Callable, + mcmc_kwargs: dict[str, Any], + kernel_cls: type[MCMCKernel] = NUTS, + num_slp_samples: int = 1_000, + max_slps: int = 124, + proposal_scale: float = 1.0, + ) -> None: self.kernel_cls = kernel_cls self.mcmc_kwargs = mcmc_kwargs @@ -179,18 +207,31 @@ def __init__( super().__init__(model, num_slp_samples, max_slps) - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayLike, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: """ Run MCMC on the model conditioned on the given branching trace. """ slp_model = condition(self.model, data=branching_trace) - kernel = self.kernel_cls(slp_model) + kernel = self.kernel_cls(slp_model) # type: ignore[call-arg] mcmc = MCMC(kernel, **self.mcmc_kwargs) mcmc.run(rng_key, *args, **kwargs) return mcmc.get_samples() - def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs): + def _combine_inferences( # type: ignore[override] + self, + rng_key: ArrayLike, + samples: dict[str, Any], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> DCCResult: """ Weight each SLP proportional to its estimated normalization constant. The normalization constants are estimated using importance sampling with @@ -202,7 +243,12 @@ def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwarg Luca Martino, Victor Elvira, David Luengo, and Jukka Corander. 
""" - def log_weight(rng_key, i, slp_model, slp_samples): + def log_weight( + rng_key: ArrayLike, + i: int, + slp_model: Callable, + slp_samples: dict[str, Any], + ) -> float: trace = {k: v[i] for k, v in slp_samples.items()} guide = AutoNormal( slp_model, @@ -215,7 +261,7 @@ def log_weight(rng_key, i, slp_model, slp_samples): model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace) return model_log_density - guide_log_density - log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None)) + log_weights: Callable = jax.vmap(log_weight, in_axes=(None, 0, None, None)) log_Zs = {} for bt, slp_samples in samples.items(): diff --git a/numpyro/contrib/stochastic_support/sdvi.py b/numpyro/contrib/stochastic_support/sdvi.py index 82c0f2d01..37abbed86 100644 --- a/numpyro/contrib/stochastic_support/sdvi.py +++ b/numpyro/contrib/stochastic_support/sdvi.py @@ -1,23 +1,27 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple +from typing import Any, Callable, OrderedDict as OrderedDictType import jax import jax.numpy as jnp +from jax.typing import ArrayLike -from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference +from numpyro.contrib.stochastic_support.dcc import ( + RunInferenceResult, + SDVIResult, + StochasticSupportInference, +) from numpyro.handlers import condition from numpyro.infer import ( + ELBO, SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, ) -from numpyro.infer.autoguide import AutoNormal - -SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"]) +from numpyro.infer.autoguide import AutoGuide, AutoNormal VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) @@ -69,16 +73,16 @@ def model(): def __init__( self, - model, + model: Callable, optimizer, - svi_num_steps=1000, - combine_elbo_particles=1000, - guide_init=AutoNormal, - loss=Trace_ELBO(), - svi_progress_bar=False, - num_slp_samples=1000, - max_slps=124, - ): + svi_num_steps: int = 1_000, + combine_elbo_particles: int = 1_000, + guide_init: Callable = AutoNormal, + loss: ELBO = Trace_ELBO(), + svi_progress_bar: bool = False, + num_slp_samples: int = 1_000, + max_slps: int = 124, + ) -> None: self.guide_init = guide_init self.optimizer = optimizer self.svi_num_steps = svi_num_steps @@ -92,7 +96,13 @@ def __init__( super().__init__(model, num_slp_samples, max_slps) - def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + def _run_inference( + self, + rng_key: ArrayLike, + branching_trace: OrderedDictType, + *args: Any, + **kwargs: Any, + ) -> RunInferenceResult: """ Run SVI on a given SLP defined by its branching trace. 
""" @@ -108,7 +118,14 @@ def _run_inference(self, rng_key, branching_trace, *args, **kwargs): ) return guide, svi_result.params - def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs): + def _combine_inferences( # type: ignore[override] + self, + rng_key: ArrayLike, + guides: dict[str, tuple[AutoGuide, dict[str, Any]]], + branching_traces: dict[str, OrderedDictType], + *args: Any, + **kwargs: Any, + ) -> SDVIResult: """Weight each SLP proportional to its estimated ELBO.""" elbos = {} for bt, (guide, param_map) in guides.items(): diff --git a/pyproject.toml b/pyproject.toml index 847758329..e6859c122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,5 +112,6 @@ module = [ "numpyro.contrib.control_flow.*", # types missing "numpyro.contrib.funsor.*", # types missing "numpyro.contrib.hsgp.*", + "numpyro.contrib.stochastic_support.*", ] ignore_errors = false diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 79ec1dd88..dea2c9383 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -11,9 +11,9 @@ import pytest from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, Matern -from jax import random -from jax._src.array import ArrayImpl +from jax import Array, random import jax.numpy as jnp +from jax.typing import ArrayLike import numpyro from numpyro.contrib.hsgp.approximation import ( @@ -32,8 +32,8 @@ def generate_synthetic_one_dim_data( - rng_key, start, stop, num, scale -) -> tuple[ArrayImpl, ArrayImpl]: + rng_key: ArrayLike, start: float, stop: float, num: int, scale: float +) -> tuple[Array, Array]: x = jnp.linspace(start=start, stop=stop, num=num) y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x) y_obs = y + scale * random.normal(rng_key, shape=(num,)) @@ -41,7 +41,7 @@ def generate_synthetic_one_dim_data( @pytest.fixture -def synthetic_one_dim_data() -> tuple[ArrayImpl, ArrayImpl]: +def synthetic_one_dim_data() -> tuple[Array, Array]: kwargs = { "rng_key": random.PRNGKey(0), "start": -0.2, @@ -53,8 +53,8 @@ def synthetic_one_dim_data() -> tuple[ArrayImpl, ArrayImpl]: def generate_synthetic_two_dim_data( - rng_key, start, stop, num, scale -) -> tuple[ArrayImpl, ArrayImpl]: + rng_key: ArrayLike, start: float, stop: float, num: int, scale: float +) -> tuple[Array, Array]: x = random.uniform(rng_key, shape=(num, 2), minval=start, maxval=stop) y = jnp.sin(4 * jnp.pi * x[:, 0]) + jnp.sin(7 * jnp.pi * x[:, 1]) y_obs = y + scale * random.normal(rng_key, shape=(num, num)) @@ -62,7 +62,7 @@ def generate_synthetic_two_dim_data( @pytest.fixture -def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: +def synthetic_two_dim_data() -> tuple[Array, Array]: kwargs = { "rng_key": random.PRNGKey(0), "start": -0.2, @@ -117,9 +117,9 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: ], ) def test_kernel_approx_squared_exponential( - x1: ArrayImpl, - x2: ArrayImpl, - length: Union[float, ArrayImpl], + x1: ArrayLike, + x2: ArrayLike, + length: Union[float, ArrayLike], ell: float, xfail: bool, ): @@ -201,7 +201,7 @@ def _exact_rbf(length): ], ) def test_kernel_approx_squared_matern( - x1: ArrayImpl, x2: ArrayImpl, nu: float, length: ArrayImpl, ell: float + x1: ArrayLike, x2: ArrayLike, nu: float, length: ArrayLike, ell: float ): """ensure that the approximation of the matern kernel is accurate, matching the exact kernel implementation from sklearn. 
@@ -243,8 +243,8 @@ def _exact_matern(length):
     ],
 )
 def test_kernel_approx_periodic(
-    x1: ArrayImpl,
-    x2: ArrayImpl,
+    x1: ArrayLike,
+    x2: ArrayLike,
     w0: float,
     length: float,
 ):
@@ -281,7 +281,7 @@ def test_kernel_approx_periodic(
     ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"],
 )
 def test_approximation_squared_exponential(
-    x: ArrayImpl,
+    x: ArrayLike,
     alpha: float,
     length: float,
     ell: Union[int, float, list[Union[int, float]]],
@@ -332,7 +332,7 @@ def model(x, alpha, length, ell, m, non_centered):
     ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"],
 )
 def test_approximation_matern(
-    x: ArrayImpl,
+    x: ArrayLike,
     nu: float,
     alpha: float,
     length: float,
diff --git a/test/contrib/hsgp/test_laplacian.py b/test/contrib/hsgp/test_laplacian.py
index 2749a2c7d..978947a43 100644
--- a/test/contrib/hsgp/test_laplacian.py
+++ b/test/contrib/hsgp/test_laplacian.py
@@ -9,8 +9,8 @@
 import numpy as np
 import pytest
 
-from jax._src.array import ArrayImpl
 import jax.numpy as jnp
+from jax.typing import ArrayLike
 
 from numpyro.contrib.hsgp.laplacian import (
     _convert_ell,
@@ -110,7 +110,7 @@ def test_sqrt_eigenvalues(ell: float | int, m: int | list[int], dim: int):
     ],
     ids=["x_pos", "x_contains_zero", "x_neg2", "x_pos2-large", "x_2d", "x_batch"],
 )
-def test_eigenfunctions(x: ArrayImpl, ell: float | int, m: int | list[int]):
+def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]):
     phi = eigenfunctions(x=x, ell=ell, m=m)
     if isinstance(m, int):
         m = [m]
@@ -131,8 +131,12 @@ def test_eigenfunctions(x: ArrayImpl, ell: float | int, m: int | list[int]):
         (1, 2, False),
         ([1, 1], 2, False),
         (np.array([1, 1])[..., None], 2, False),
+        (jnp.array([1, 1])[..., None], 2, False),
         (np.array([1, 1]), 2, True),
+        (jnp.array([1, 1]), 2, True),
         ([1, 1], 1, True),
+        (np.array([1, 1]), 1, True),
+        (jnp.array([1, 1]), 1, True),
     ],
     ids=[
         "ell-float",
         "ell-int",
         "ell-int-multdim",
         "ell-list",
         "ell-array",
+        "ell-jax-array",
         "ell-array-fail",
+        "ell-jax-array-fail",
         "dim-fail",
+        "dim-fail-array",
+        "dim-fail-jax",
     ],
 )
 def test_convert_ell(ell, dim, xfail):
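Finally, a compact sketch of the `DCC` entry point whose signatures this patch annotates. The branching model follows the pattern shown in the `DCC` class docstring; the observed value and MCMC settings here are arbitrary illustrative choices:

```python
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC


def model():
    # a discrete site marked with infer={"branching": True} defines the SLPs
    model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
    mean = 1.0 if model1 == 0 else 2.0
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)


mcmc_kwargs = {"num_warmup": 500, "num_samples": 1000}
dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
dcc_result = dcc.run(random.PRNGKey(0))  # -> DCCResult(samples=..., slp_weights=...)
```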