Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent a810776 commit b4dacc0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
23 changes: 14 additions & 9 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import numpy as np

import jax
from jax import Array, lax, tree_util
from jax import lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike
Expand Down Expand Up @@ -303,13 +303,13 @@ def event_dim(self) -> int:
def has_rsample(self) -> bool:
return set(self.reparametrized_params) == set(self.arg_constraints)

def rsample(self, key, sample_shape=()):
def rsample(self, key, sample_shape=()) -> ArrayLike:
if self.has_rsample:
return self.sample(key, sample_shape=sample_shape)

raise NotImplementedError

def shape(self, sample_shape=()):
def shape(self, sample_shape=()) -> tuple[int, ...]:
"""
The tensor shape of samples from this distribution.
Expand All @@ -324,7 +324,7 @@ def shape(self, sample_shape=()):
"""
return sample_shape + self.batch_shape + self.event_shape

def sample(self, key, sample_shape=()):
def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Returns a sample from the distribution having shape given by
`sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
Expand Down Expand Up @@ -362,14 +362,14 @@ def log_prob(self, value):
raise NotImplementedError

@property
def mean(self):
def mean(self) -> ArrayLike:
"""
Mean of the distribution.
"""
raise NotImplementedError

@property
def variance(self):
def variance(self) -> ArrayLike:
"""
Variance of the distribution.
"""
Expand Down Expand Up @@ -541,7 +541,7 @@ def infer_shapes(cls, *args, **kwargs):
event_shape = ()
return batch_shape, event_shape

def cdf(self, value):
def cdf(self, value: ArrayLike) -> ArrayLike:
"""
The cumulative distribution function of this distribution.
Expand All @@ -550,7 +550,7 @@ def cdf(self, value):
"""
raise NotImplementedError

def icdf(self, q):
def icdf(self, q: ArrayLike) -> ArrayLike:
"""
The inverse cumulative distribution function of this distribution.
Expand All @@ -572,6 +572,9 @@ class DistributionLike(Protocol):
or tensorflow_probability.distributions.Distribution.
"""

def __call__(self, *args: functools.Any, **kwds: functools.Any) -> functools.Any:
return super().__call__(*args, **kwds)

@property
def batch_shape(self) -> tuple[int, ...]: ...

Expand All @@ -581,7 +584,9 @@ def event_shape(self) -> tuple[int, ...]: ...
@property
def event_dim(self) -> int: ...

def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ...
def sample(
self, key: ArrayLike, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...

def log_prob(self, value: ArrayLike) -> ArrayLike: ...

Expand Down
8 changes: 4 additions & 4 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sample(
sample_shape: tuple[int, ...] = (),
infer: Optional[dict] = None,
obs_mask: Optional[ArrayLike] = None,
) -> Array:
) -> ArrayLike:
"""
Returns a random sample from the stochastic function `fn`. This can have
additional side effects when wrapped inside effect handlers like
Expand Down Expand Up @@ -208,7 +208,7 @@ def sample(
if obs is None:
return fn(rng_key=rng_key, sample_shape=sample_shape)
else:
return jnp.asarray(obs)
return obs

if obs_mask is not None:
return _masked_observe(
Expand Down Expand Up @@ -670,7 +670,7 @@ def prng_key() -> Union[Array, None]:
return msg["value"]


def subsample(data: ArrayLike, event_dim: int) -> Array:
def subsample(data: ArrayLike, event_dim: int) -> ArrayLike:
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
Expand Down Expand Up @@ -698,7 +698,7 @@ def model(data):
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return jnp.asarray(data)
return data

assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
Expand Down

0 comments on commit b4dacc0

Please sign in to comment.