From 6274eee010be5a0098aec264e24f1af3a7958c3b Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 20 Dec 2024 19:28:49 +0100 Subject: [PATCH] revert --- numpyro/distributions/distribution.py | 6 ++++-- numpyro/primitives.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 93698390f..52c620fe1 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -34,7 +34,7 @@ import numpy as np import jax -from jax import Array, lax, tree_util +from jax import lax, tree_util import jax.numpy as jnp from jax.scipy.special import logsumexp from jax.typing import ArrayLike @@ -581,7 +581,9 @@ def event_shape(self) -> tuple[int, ...]: ... @property def event_dim(self) -> int: ... - def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ... + def sample( + self, key: ArrayLike, sample_shape: tuple[int, ...] = () + ) -> ArrayLike: ... def log_prob(self, value: ArrayLike) -> ArrayLike: ... diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 2bc715347..b1e653484 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -133,7 +133,7 @@ def sample( sample_shape: tuple[int, ...] = (), infer: Optional[dict] = None, obs_mask: Optional[ArrayLike] = None, -) -> Array: +) -> ArrayLike: """ Returns a random sample from the stochastic function `fn`. This can have additional side effects when wrapped inside effect handlers like @@ -208,7 +208,7 @@ def sample( if obs is None: return fn(rng_key=rng_key, sample_shape=sample_shape) else: - return jnp.asarray(obs) + return obs if obs_mask is not None: return _masked_observe( @@ -670,7 +670,7 @@ def prng_key() -> Union[Array, None]: return msg["value"] -def subsample(data: ArrayLike, event_dim: int) -> Array: +def subsample(data: ArrayLike, event_dim: int) -> ArrayLike: """ EXPERIMENTAL Subsampling statement to subsample data based on enclosing :class:`~numpyro.primitives.plate` s. @@ -698,7 +698,7 @@ def model(data): :rtype: ~jnp.ndarray """ if not _PYRO_STACK: - return jnp.asarray(data) + return data assert isinstance(event_dim, int) and event_dim >= 0 initial_msg = {