diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py
index 11488ad10..1b71361a1 100644
--- a/numpyro/contrib/hsgp/approximation.py
+++ b/numpyro/contrib/hsgp/approximation.py
@@ -21,18 +21,14 @@ import numpyro.distributions as dist
 
 
-def _non_centered_approximation(
-    phi: ArrayLike, spd: ArrayLike, m: int | list[int]
-) -> Array:
+def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
     with numpyro.plate("basis", m):
         beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))
 
     return phi @ (spd * beta)
 
 
-def _centered_approximation(
-    phi: ArrayLike, spd: ArrayLike, m: int | list[int]
-) -> Array:
+def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
     with numpyro.plate("basis", m):
         beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))
 
@@ -40,7 +36,7 @@ def _centered_approximation(
 
 
 def linear_approximation(
-    phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
+    phi: Array, spd: Array, m: int, non_centered: bool = True
 ) -> Array:
     """
     Linear approximation formula of the Hilbert space Gaussian process.
@@ -52,10 +48,10 @@ def linear_approximation(
     1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
        approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
 
-    :param ArrayLike phi: laplacian eigenfunctions
-    :param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
+    :param Array phi: Laplacian eigenfunctions
+    :param Array spd: square root of the diagonal of the spectral density evaluated at square
        root of the first `m` eigenvalues.
-    :param int | list[int] m: number of eigenfunctions in the approximation
+    :param int m: number of eigenfunctions in the approximation
     :param bool non_centered: whether to use a non-centered parameterization
     :return: The low-rank approximation linear model
     :rtype: Array
diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py
index a68631b6a..746289f2e 100644
--- a/numpyro/diagnostics.py
+++ b/numpyro/diagnostics.py
@@ -10,6 +10,7 @@
 from typing import Union
 
 import numpy as np
+from numpy.typing import NDArray
 
 import jax
 from jax import device_get
@@ -25,7 +26,7 @@
 ]
 
 
-def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+def _compute_chain_variance_stats(x: NDArray) -> tuple[NDArray, NDArray]:
     # compute within-chain variance and variance estimator
     # input has shape C x N x sample_shape
     C, N = x.shape[:2]
@@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray
     return var_within, var_estimator
 
 
-def gelman_rubin(x: np.ndarray) -> np.ndarray:
+def gelman_rubin(x: NDArray) -> NDArray:
     """
     Computes R-hat over chains of samples ``x``, where the first dimension of
     ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray:
     return rhat
 
 
-def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
+def split_gelman_rubin(x: NDArray) -> NDArray:
     """
     Computes split R-hat over chains of samples ``x``, where the first dimension
     of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int:
         target += 1
 
 
-def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
+def autocorrelation(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
     """
     Computes the autocorrelation of samples at dimension ``axis``.
 
@@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarr
         autocorr = autocorr / np.arange(N, 0.0, -1)
 
     with np.errstate(invalid="ignore", divide="ignore"):
-        autocorr = autocorr / autocorr[..., :1]
+        autocorr = (autocorr / autocorr[..., :1]).astype(np.float64)
     return np.swapaxes(autocorr, axis, -1)
 
 
-def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
+def autocovariance(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
     """
     Computes the autocovariance of samples at dimension ``axis``.
 
@@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarra
     return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)
 
 
-def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
+def effective_sample_size(x: NDArray, bias: bool = True) -> NDArray:
     """
     Computes effective sample size of input ``x``, where the first dimension of
     ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
     return n_eff
 
 
-def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
+def hpdi(x: NDArray, prob: float = 0.90, axis: int = 0) -> NDArray:
     """
     Computes "highest posterior density interval" (HPDI) which is the narrowest
     interval with probability mass ``prob``.
@@ -285,7 +286,7 @@ def summary(
 
 
 def print_summary(
-    samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
+    samples: Union[dict, NDArray], prob: float = 0.90, group_by_chain: bool = True
 ) -> None:
     """
     Prints a summary table displaying diagnostics of ``samples`` from the
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index c975f865d..e03f6aca3 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -24,11 +24,11 @@
 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.
-
 from collections import OrderedDict
 from contextlib import contextmanager
 import functools
 import inspect
+from typing import Any, Protocol, runtime_checkable
 import warnings
 
 import numpy as np
@@ -37,6 +37,7 @@
 from jax import lax, tree_util
 import jax.numpy as jnp
 from jax.scipy.special import logsumexp
+from jax.typing import ArrayLike
 
 from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
 from numpyro.distributions.util import (
@@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None:
             raise RuntimeError("Cannot validate arguments inside jitted code.")
 
     @property
-    def batch_shape(self):
+    def batch_shape(self) -> tuple[int, ...]:
         """
         Returns the shape over which the distribution parameters are batched.
 
@@ -280,7 +281,7 @@ def batch_shape(self):
         return self._batch_shape
 
     @property
-    def event_shape(self):
+    def event_shape(self) -> tuple[int, ...]:
         """
         Returns the shape of a single sample from the distribution without
         batching.
@@ -291,7 +292,7 @@ def event_shape(self):
         return self._event_shape
 
     @property
-    def event_dim(self):
+    def event_dim(self) -> int:
         """
         :return: Number of dimensions of individual events.
         :rtype: int
@@ -299,16 +300,16 @@
         return len(self.event_shape)
 
     @property
-    def has_rsample(self):
+    def has_rsample(self) -> bool:
         return set(self.reparametrized_params) == set(self.arg_constraints)
 
-    def rsample(self, key, sample_shape=()):
+    def rsample(self, key, sample_shape=()) -> ArrayLike:
         if self.has_rsample:
             return self.sample(key, sample_shape=sample_shape)
 
         raise NotImplementedError
 
-    def shape(self, sample_shape=()):
+    def shape(self, sample_shape=()) -> tuple[int, ...]:
         """
         The tensor shape of samples from this distribution.
 
@@ -323,7 +324,7 @@ def shape(self, sample_shape=()):
         """
         return sample_shape + self.batch_shape + self.event_shape
 
-    def sample(self, key, sample_shape=()):
+    def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
         """
         Returns a sample from the distribution having shape given by
         `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
@@ -361,14 +362,14 @@ def log_prob(self, value):
         raise NotImplementedError
 
     @property
-    def mean(self):
+    def mean(self) -> ArrayLike:
         """
         Mean of the distribution.
         """
         raise NotImplementedError
 
     @property
-    def variance(self):
+    def variance(self) -> ArrayLike:
         """
         Variance of the distribution.
         """
@@ -540,7 +541,7 @@ def infer_shapes(cls, *args, **kwargs):
             event_shape = ()
         return batch_shape, event_shape
 
-    def cdf(self, value):
+    def cdf(self, value: ArrayLike) -> ArrayLike:
         """
         The cumulative distribution function of this distribution.
 
@@ -549,7 +550,7 @@ def cdf(self, value):
         """
         raise NotImplementedError
 
-    def icdf(self, q):
+    def icdf(self, q: ArrayLike) -> ArrayLike:
         """
         The inverse cumulative distribution function of this distribution.
 
@@ -563,6 +564,43 @@ def is_discrete(self):
         return self.support.is_discrete
 
 
+@runtime_checkable
+class DistributionLike(Protocol):
+    """A protocol for typing distributions.
+
+    Used to type objects of type numpyro.distributions.Distribution, funsor.Funsor,
+    or tensorflow_probability.distributions.Distribution.
+    """
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return super().__call__(*args, **kwargs)
+
+    @property
+    def batch_shape(self) -> tuple[int, ...]: ...
+
+    @property
+    def event_shape(self) -> tuple[int, ...]: ...
+
+    @property
+    def event_dim(self) -> int: ...
+
+    def sample(
+        self, key: ArrayLike, sample_shape: tuple[int, ...] = ()
+    ) -> ArrayLike: ...
+
+    def log_prob(self, value: ArrayLike) -> ArrayLike: ...
+
+    @property
+    def mean(self) -> ArrayLike: ...
+
+    @property
+    def variance(self) -> ArrayLike: ...
+
+    def cdf(self, value: ArrayLike) -> ArrayLike: ...
+
+    def icdf(self, q: ArrayLike) -> ArrayLike: ...
+
+
 class ExpandedDistribution(Distribution):
     arg_constraints = {}
     pytree_data_fields = ("base_dist",)
diff --git a/numpyro/primitives.py b/numpyro/primitives.py
index cf84b9512..a9f0e5c4d 100644
--- a/numpyro/primitives.py
+++ b/numpyro/primitives.py
@@ -4,21 +4,28 @@
 from collections import namedtuple
 from contextlib import ExitStack, contextmanager
 import functools
+from typing import Any, Callable, Generator, Optional, Union, cast
 import warnings
 
 import jax
-from jax import lax, random
+from jax import Array, lax, random
 import jax.numpy as jnp
+from jax.typing import ArrayLike
 
 import numpyro
+from numpyro.distributions.distribution import DistributionLike
 from numpyro.util import find_stack_level, identity
 
-_PYRO_STACK = []
+# Type aliases
+Message = dict[str, Any]
+
+_PYRO_STACK: list = []
+
 CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
 
 
-def default_process_message(msg):
+def default_process_message(msg: Message) -> None:
     if msg["value"] is None:
         if msg["type"] == "sample":
             msg["value"], msg["intermediates"] = msg["fn"](
@@ -28,7 +35,7 @@
             msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
 
 
-def apply_stack(msg):
+def apply_stack(msg: Message) -> Message:
     """
     Execute the effect stack at a single site according to the following scheme:
 
@@ -61,19 +68,20 @@ class Messenger(object):
-    def __init__(self, fn=None):
+    def __init__(self, fn: Optional[Callable] = None) -> None:
         if fn is not None and not callable(fn):
             raise ValueError(
                 "Expected `fn` to be a Python callable object; "
                 "instead found type(fn) = {}.".format(type(fn))
             )
         self.fn = fn
-        functools.update_wrapper(self, fn, updated=[])
+        if fn is not None:
+            functools.update_wrapper(self, fn, updated=[])
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         _PYRO_STACK.append(self)
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
         if exc_type is None:
             assert _PYRO_STACK[-1] is self
             _PYRO_STACK.pop()
@@ -86,13 +94,15 @@ def __exit__(self, exc_type, exc_value, traceback):
             # then remove it and everything below it in the stack.
             if self in _PYRO_STACK:
                 loc = _PYRO_STACK.index(self)
-                for i in range(loc, len(_PYRO_STACK)):
+                for _ in range(loc, len(_PYRO_STACK)):
                     _PYRO_STACK.pop()
 
     def process_message(self, msg):
+        """To be implemented by subclasses."""
         pass
 
     def postprocess_message(self, msg):
+        """To be implemented by subclasses."""
         pass
 
     def __call__(self, *args, **kwargs):
@@ -105,7 +115,13 @@ def __call__(self, *args, **kwargs):
         return self.fn(*args, **kwargs)
 
 
-def _masked_observe(name, fn, obs, obs_mask, **kwargs):
+def _masked_observe(
+    name: str,
+    fn: DistributionLike,
+    obs: Optional[ArrayLike],
+    obs_mask,
+    **kwargs,
+) -> ArrayLike:
     # Split into two auxiliary sample sites.
     with numpyro.handlers.mask(mask=obs_mask):
         observed = sample(f"{name}_observed", fn, **kwargs, obs=obs)
@@ -120,8 +136,14 @@ def _masked_observe(name, fn, obs, obs_mask, **kwargs):
 
 
 def sample(
-    name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None
-):
+    name: str,
+    fn: DistributionLike,
+    obs: Optional[ArrayLike] = None,
+    rng_key: Optional[ArrayLike] = None,
+    sample_shape: tuple[int, ...] = (),
+    infer: Optional[dict] = None,
+    obs_mask: Optional[ArrayLike] = None,
+) -> ArrayLike:
     """
     Returns a random sample from the stochastic function `fn`.
     This can have additional side effects when wrapped inside effect handlers like
@@ -223,7 +245,9 @@ def sample(
     return msg["value"]
 
 
-def param(name, init_value=None, **kwargs):
+def param(
+    name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, **kwargs
+) -> Optional[ArrayLike]:
     """
     Annotate the given site as an optimizable parameter for use with
     :mod:`jax.example_libraries.optimizers`. For an example of how `param` statements
@@ -257,11 +281,11 @@ def param(
 
     if callable(init_value):
 
-        def fn(init_fn, *args, **kwargs):
+        def fn(init_fn: Callable, *args, **kwargs):
             return init_fn(prng_key())
 
     else:
-        fn = identity
+        fn = cast(Callable, identity)
 
     # Otherwise, we initialize a message...
     initial_msg = {
@@ -280,7 +304,7 @@ def fn(init_fn, *args, **kwargs):
     return msg["value"]
 
 
-def deterministic(name, value):
+def deterministic(name: str, value: ArrayLike) -> ArrayLike:
     """
     Used to designate deterministic sites in the model. Note that most effect
     handlers will not operate on deterministic sites (except
@@ -294,7 +318,7 @@ def deterministic(name, value):
     if not _PYRO_STACK:
         return value
 
-    initial_msg = {
+    initial_msg: Message = {
         "type": "deterministic",
         "name": name,
         "value": value,
@@ -306,7 +330,9 @@ def deterministic(name, value):
     return msg["value"]
 
 
-def mutable(name, init_value=None):
+def mutable(
+    name: str, init_value: Optional[ArrayLike] = None
+) -> Union[ArrayLike, None]:
     """
     This primitive is used to store a mutable value that can be changed
     during model execution::
@@ -338,7 +364,7 @@ def mutable(name, init_value=None):
     return msg["value"]
 
 
-def _inspect():
+def _inspect() -> dict:
     """
     EXPERIMENTAL Inspect the Pyro stack.
 
@@ -362,7 +388,7 @@ def _inspect():
     return msg
 
 
-def get_mask():
+def get_mask() -> Union[ArrayLike, None]:
     """
     Records the effects of enclosing ``handlers.mask`` handlers.
     This is useful for avoiding expensive ``numpyro.factor()`` computations during
@@ -381,7 +407,7 @@ def model():
     return _inspect()["mask"]
 
 
-def module(name, nn, input_shape=None):
+def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callable:
     """
     Declare a :mod:`~jax.example_libraries.stax` style neural network inside a
     model so that its parameters are registered for optimization via
@@ -408,7 +434,7 @@ def module(name, nn, input_shape=None):
     return functools.partial(nn_apply, nn_params)
 
 
-def _subsample_fn(size, subsample_size, rng_key=None):
+def _subsample_fn(size: int, subsample_size: int, rng_key: Optional[ArrayLike] = None):
     if rng_key is None:
         raise ValueError(
             "Missing random key to generate subsample indices."
@@ -457,7 +483,13 @@ class plate(Messenger):
     is allocated.
""" - def __init__(self, name, size, subsample_size=None, dim=None): + def __init__( + self, + name: str, + size: int, + subsample_size: Optional[int] = None, + dim: Optional[int] = None, + ) -> None: self.name = name assert size > 0, "size of plate should be positive" self.size = size @@ -513,14 +545,16 @@ def __enter__(self): return self._indices @staticmethod - def _get_batch_shape(cond_indep_stack): + def _get_batch_shape( + cond_indep_stack: list[CondIndepStackFrame], + ) -> tuple[int, ...]: n_dims = max(-f.dim for f in cond_indep_stack) batch_shape = [1] * n_dims for f in cond_indep_stack: batch_shape[f.dim] = f.size return tuple(batch_shape) - def process_message(self, msg): + def process_message(self, msg: Message) -> None: if msg["type"] not in ("param", "sample", "plate", "deterministic"): if msg["type"] == "control_flow": raise NotImplementedError( @@ -536,7 +570,7 @@ def process_message(self, msg): ): return - cond_indep_stack = msg["cond_indep_stack"] + cond_indep_stack: list[CondIndepStackFrame] = msg["cond_indep_stack"] frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) if msg["type"] == "deterministic": @@ -560,7 +594,7 @@ def process_message(self, msg): self.size / self.subsample_size if self.subsample_size else 1 ) - def postprocess_message(self, msg): + def postprocess_message(self, msg: Message) -> None: if msg["type"] in ("subsample", "param") and self.dim is not None: event_dim = msg["kwargs"].get("event_dim") if event_dim is not None: @@ -589,7 +623,9 @@ def postprocess_message(self, msg): @contextmanager -def plate_stack(prefix, sizes, rightmost_dim=-1): +def plate_stack( + prefix: str, sizes: list[int], rightmost_dim: int = -1 +) -> Generator[None, None, None]: """ Create a contiguous stack of :class:`plate` s with dimensions:: @@ -607,7 +643,7 @@ def plate_stack(prefix, sizes, rightmost_dim=-1): yield -def factor(name, log_factor): +def factor(name: str, log_factor: ArrayLike) -> None: """ Factor statement to add arbitrary log probability factor to a probabilistic model. @@ -620,7 +656,7 @@ def factor(name, log_factor): sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True}) -def prng_key(): +def prng_key() -> Union[Array, None]: """ A statement to draw a pseudo-random number generator key :func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler. @@ -632,7 +668,7 @@ def prng_key(): "Cannot generate JAX PRNG key outside of `seed` handler.", stacklevel=find_stack_level(), ) - return + return None initial_msg = { "type": "prng_key", @@ -646,7 +682,7 @@ def prng_key(): return msg["value"] -def subsample(data, event_dim): +def subsample(data: ArrayLike, event_dim: int) -> ArrayLike: """ EXPERIMENTAL Subsampling statement to subsample data based on enclosing :class:`~numpyro.primitives.plate` s. 
diff --git a/pyproject.toml b/pyproject.toml
index 37d24a27c..8e22e0c60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -115,6 +115,7 @@ module = [
     "numpyro.contrib.hsgp.*",
     "numpyro.contrib.stochastic_support.*",
     "numpyro.diagnostics.*",
+    "numpyro.primitives.*",
     "numpyro.patch.*",
     "numpyro.util.*",
 ]
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 61dd23182..03c5813d6 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -7,7 +7,7 @@
 from itertools import product
 import math
 import os
-from typing import Callable
+from typing import Any, Callable
 
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
@@ -20,6 +20,7 @@
 from jax import grad, lax, vmap
 import jax.numpy as jnp
 import jax.random as random
+from jax.random import PRNGKey
 from jax.scipy.special import expit, logsumexp
 from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm
 
@@ -32,6 +33,7 @@
 )
 from numpyro.distributions.batch_util import vmap_over
 from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom
+from numpyro.distributions.distribution import DistributionLike
 from numpyro.distributions.flows import InverseAutoregressiveTransform
 from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit
 from numpyro.distributions.transforms import (
@@ -3534,3 +3536,148 @@ def make_dist():
     # Run scan which validates that pytree structures are consistent.
     jax.lax.scan(lambda *_: (make_dist(), None), init, jnp.arange(7))
+
+
+# Test class that properly implements DistributionLike
+class ValidDistributionLike:
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return ()
+
+    @property
+    def batch_shape(self) -> tuple[int, ...]:
+        return ()
+
+    @property
+    def event_shape(self) -> tuple[int, ...]:
+        return ()
+
+    @property
+    def event_dim(self) -> int:
+        return 0
+
+    def sample(self, key, sample_shape=()):
+        return jnp.array(0.0)
+
+    def log_prob(self, value):
+        return jnp.array(0.0)
+
+    @property
+    def mean(self):
+        return jnp.array(0.0)
+
+    @property
+    def variance(self):
+        return jnp.array(1.0)
+
+    def cdf(self, value):
+        return jnp.array(0.5)
+
+    def icdf(self, q):
+        return jnp.array(0.0)
+
+
+# Test class missing required methods
+class InvalidDistributionLike:
+    @property
+    def batch_shape(self):
+        return ()
+
+
+def test_valid_distribution_implementations():
+    """Test that valid implementations are recognized as DistributionLike"""
+    # Test standard NumPyro distribution
+    assert isinstance(dist.Normal(0, 1), DistributionLike)
+
+    # Test custom implementation
+    assert isinstance(ValidDistributionLike(), DistributionLike)
+
+
+def test_invalid_distribution_implementations():
+    """Test that invalid implementations are not recognized as DistributionLike"""
+    assert not isinstance(InvalidDistributionLike(), DistributionLike)
+    assert not isinstance(object(), DistributionLike)
+
+
+def test_distribution_like_interface():
+    """Test that we can use a custom DistributionLike implementation where a Distribution is expected"""
+    my_dist = ValidDistributionLike()
+
+    # Test basic properties
+    assert my_dist() == ()
+    assert my_dist.batch_shape == ()
+    assert my_dist.event_shape == ()
+    assert my_dist.event_dim == 0
+
+    # Test methods
+    key = PRNGKey(0)
+    sample = my_dist.sample(key)
+    assert isinstance(sample, jnp.ndarray)
+
+    log_prob = my_dist.log_prob(0.0)
+    assert isinstance(log_prob, jnp.ndarray)
+
+    mean = my_dist.mean
+    assert isinstance(mean, jnp.ndarray)
+
+    var = my_dist.variance
+    assert isinstance(var, jnp.ndarray)
+
+    cdf = my_dist.cdf(0.0)
+    assert isinstance(cdf, jnp.ndarray)
+
+    icdf = my_dist.icdf(0.5)
+    assert isinstance(icdf, jnp.ndarray)
+
+
+def test_distribution_like_with_shapes():
+    """Test a DistributionLike implementation with non-trivial shapes"""
+
+    class ShapedDistributionLike:
+        @property
+        def batch_shape(self):
+            return (2, 3)
+
+        @property
+        def event_shape(self):
+            return (4,)
+
+        @property
+        def event_dim(self):
+            return 1
+
+        def sample(self, key, sample_shape=()):
+            shape = sample_shape + self.batch_shape + self.event_shape
+            return jnp.zeros(shape)
+
+        def log_prob(self, value):
+            return jnp.zeros(self.batch_shape)
+
+        @property
+        def mean(self):
+            return jnp.zeros(self.batch_shape + self.event_shape)
+
+        @property
+        def variance(self):
+            return jnp.ones(self.batch_shape + self.event_shape)
+
+        def cdf(self, value):
+            return jnp.full(self.batch_shape, 0.5)
+
+        def icdf(self, q):
+            return jnp.zeros(self.batch_shape + self.event_shape)
+
+    my_dist = ShapedDistributionLike()
+
+    assert my_dist.batch_shape == (2, 3)
+    assert my_dist.event_shape == (4,)
+    assert my_dist.event_dim == 1
+
+    key = PRNGKey(0)
+    sample = my_dist.sample(key, sample_shape=(5,))
+    assert sample.shape == (5, 2, 3, 4)
+
+    log_prob = my_dist.log_prob(jnp.zeros((2, 3, 4)))
+    assert log_prob.shape == (2, 3)
+
+    assert my_dist.mean.shape == (2, 3, 4)
+    assert my_dist.variance.shape == (2, 3, 4)
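
The tests above exercise `DistributionLike` through `isinstance`; the practical payoff is static typing of user code. Below is a minimal sketch, not part of the patch (`summarize` is a hypothetical helper), of a function that accepts a NumPyro distribution or any structurally compatible object:

from jax import Array, random

import numpyro.distributions as dist
from numpyro.distributions.distribution import DistributionLike


def summarize(d: DistributionLike, key: Array) -> dict:
    # Only protocol members are touched, so any DistributionLike satisfies this.
    return {
        "batch_shape": d.batch_shape,
        "event_shape": d.event_shape,
        "draw": d.sample(key),
    }


print(summarize(dist.Normal(0.0, 1.0), random.PRNGKey(0)))

Because the protocol is declared `runtime_checkable`, the same object can also be screened dynamically with `isinstance(d, DistributionLike)`, exactly as the tests above do.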