Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent a810776 commit 6274eee
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import numpy as np

import jax
from jax import Array, lax, tree_util
from jax import lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike
Expand Down Expand Up @@ -581,7 +581,9 @@ def event_shape(self) -> tuple[int, ...]: ...
@property
def event_dim(self) -> int: ...

def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ...
def sample(
self, key: ArrayLike, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...

def log_prob(self, value: ArrayLike) -> ArrayLike: ...

Expand Down
8 changes: 4 additions & 4 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sample(
sample_shape: tuple[int, ...] = (),
infer: Optional[dict] = None,
obs_mask: Optional[ArrayLike] = None,
) -> Array:
) -> ArrayLike:
"""
Returns a random sample from the stochastic function `fn`. This can have
additional side effects when wrapped inside effect handlers like
Expand Down Expand Up @@ -208,7 +208,7 @@ def sample(
if obs is None:
return fn(rng_key=rng_key, sample_shape=sample_shape)
else:
return jnp.asarray(obs)
return obs

if obs_mask is not None:
return _masked_observe(
Expand Down Expand Up @@ -670,7 +670,7 @@ def prng_key() -> Union[Array, None]:
return msg["value"]


def subsample(data: ArrayLike, event_dim: int) -> Array:
def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
Expand Down Expand Up @@ -698,7 +698,7 @@ def model(data):
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return jnp.asarray(data)
return data

assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
Expand Down

0 comments on commit 6274eee

Please sign in to comment.