Skip to content

Commit

Permalink
distribution like type
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent ebc510c commit ba9a07a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
45 changes: 39 additions & 6 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from collections import OrderedDict
from contextlib import contextmanager
import functools
import inspect
from typing import Protocol, runtime_checkable
import warnings

import numpy as np

import jax
from jax import lax, tree_util
from jax import Array, lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
from numpyro.distributions.util import (
Expand Down Expand Up @@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None:
raise RuntimeError("Cannot validate arguments inside jitted code.")

@property
def batch_shape(self):
def batch_shape(self) -> tuple[int, ...]:
"""
Returns the shape over which the distribution parameters are batched.
Expand All @@ -280,7 +281,7 @@ def batch_shape(self):
return self._batch_shape

@property
def event_shape(self):
def event_shape(self) -> tuple[int, ...]:
"""
Returns the shape of a single sample from the distribution without
batching.
Expand All @@ -291,15 +292,15 @@ def event_shape(self):
return self._event_shape

@property
def event_dim(self):
def event_dim(self) -> int:
"""
:return: Number of dimensions of individual events.
:rtype: int
"""
return len(self.event_shape)

@property
def has_rsample(self):
def has_rsample(self) -> bool:
return set(self.reparametrized_params) == set(self.arg_constraints)

def rsample(self, key, sample_shape=()):
Expand Down Expand Up @@ -563,6 +564,38 @@ def is_discrete(self):
return self.support.is_discrete


@runtime_checkable
class DistributionLike(Protocol):
"""A protocol for typing distributions.
Used to type object of type numpyro.distributions.Distribution, funsor.Funsor
or tensorflow_probability.distributions.Distribution.
"""

@property
def batch_shape(self) -> tuple[int, ...]: ...

@property
def event_shape(self) -> tuple[int, ...]: ...

@property
def event_dim(self) -> int: ...

def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ...

def log_prob(self, value: ArrayLike) -> ArrayLike: ...

@property
def mean(self) -> ArrayLike: ...

@property
def variance(self) -> ArrayLike: ...

def cdf(self, value: ArrayLike) -> ArrayLike: ...

def icdf(self, q: ArrayLike) -> ArrayLike: ...


class ExpandedDistribution(Distribution):
arg_constraints = {}
pytree_data_fields = ("base_dist",)
Expand Down
12 changes: 7 additions & 5 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Callable, Generator, Optional, Union, cast
import warnings

from distributions.distribution import DistributionLike

import jax
from jax import Array, lax, random
import jax.numpy as jnp
Expand Down Expand Up @@ -126,7 +128,7 @@ def _masked_observe(name, fn, obs, obs_mask, **kwargs) -> Array:

def sample(
name: str,
fn,
fn: DistributionLike,
obs: Optional[ArrayLike] = None,
rng_key: Optional[ArrayLike] = None,
sample_shape: tuple[int, ...] = (),
Expand Down Expand Up @@ -207,7 +209,7 @@ def sample(
if obs is None:
return fn(rng_key=rng_key, sample_shape=sample_shape)
else:
return cast(Array, obs)
return jnp.asarray(obs)

if obs_mask is not None:
return _masked_observe(
Expand Down Expand Up @@ -266,7 +268,7 @@ def param(
assert not callable(
init_value
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
return cast(Array, init_value)
return jnp.asarray(init_value)

if callable(init_value):

Expand Down Expand Up @@ -305,7 +307,7 @@ def deterministic(name: str, value: ArrayLike) -> Array:
:param jnp.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
return cast(Array, value)
return jnp.asarray(value)

initial_msg: dict = {
"type": "deterministic",
Expand Down Expand Up @@ -697,7 +699,7 @@ def model(data):
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return cast(Array, data)
return jnp.asarray(data)

assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
Expand Down

0 comments on commit ba9a07a

Please sign in to comment.