Skip to content

Commit

Permalink
add tests for the distributions type
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent a14adaa commit a810776
Showing 1 changed file with 143 additions and 0 deletions.
143 changes: 143 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax import grad, lax, vmap
import jax.numpy as jnp
import jax.random as random
from jax.random import PRNGKey
from jax.scipy.special import expit, logsumexp
from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm

Expand All @@ -32,6 +33,7 @@
)
from numpyro.distributions.batch_util import vmap_over
from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom
from numpyro.distributions.distribution import DistributionLike
from numpyro.distributions.flows import InverseAutoregressiveTransform
from numpyro.distributions.gof import InvalidTest, auto_goodness_of_fit
from numpyro.distributions.transforms import (
Expand Down Expand Up @@ -3534,3 +3536,144 @@ def make_dist():

# Run scan which validates that pytree structures are consistent.
jax.lax.scan(lambda *_: (make_dist(), None), init, jnp.arange(7))


# Test class that properly implements DistributionLike
class ValidDistributionLike:
@property
def batch_shape(self) -> tuple[int, ...]:
return ()

@property
def event_shape(self) -> tuple[int, ...]:
return ()

@property
def event_dim(self) -> int:
return 0

def sample(self, key, sample_shape=()):
return jnp.array(0.0)

def log_prob(self, value):
return jnp.array(0.0)

@property
def mean(self):
return jnp.array(0.0)

@property
def variance(self):
return jnp.array(1.0)

def cdf(self, value):
return jnp.array(0.5)

def icdf(self, q):
return jnp.array(0.0)


# Test class missing required methods
class InvalidDistributionLike:
@property
def batch_shape(self):
return ()


def test_valid_distribution_implementations():
"""Test that valid implementations are recognized as DistributionLike"""
# Test standard NumPyro distribution
assert isinstance(dist.Normal(0, 1), DistributionLike)

# Test custom implementation
assert isinstance(ValidDistributionLike(), DistributionLike)


def test_invalid_distribution_implementations():
"""Test that invalid implementations are not recognized as DistributionLike"""
assert not isinstance(InvalidDistributionLike(), DistributionLike)
assert not isinstance(object(), DistributionLike)


def test_distribution_like_interface():
"""Test that we can use a custom DistributionLike implementation where a Distribution is expected"""
my_dist = ValidDistributionLike()

# Test basic properties
assert my_dist.batch_shape == ()
assert my_dist.event_shape == ()
assert my_dist.event_dim == 0

# Test methods
key = PRNGKey(0)
sample = my_dist.sample(key)
assert isinstance(sample, jnp.ndarray)

log_prob = my_dist.log_prob(0.0)
assert isinstance(log_prob, jnp.ndarray)

mean = my_dist.mean
assert isinstance(mean, jnp.ndarray)

var = my_dist.variance
assert isinstance(var, jnp.ndarray)

cdf = my_dist.cdf(0.0)
assert isinstance(cdf, jnp.ndarray)

icdf = my_dist.icdf(0.5)
assert isinstance(icdf, jnp.ndarray)


def test_distribution_like_with_shapes():
"""Test a DistributionLike implementation with non-trivial shapes"""

class ShapedDistributionLike:
@property
def batch_shape(self):
return (2, 3)

@property
def event_shape(self):
return (4,)

@property
def event_dim(self):
return 1

def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape + self.event_shape
return jnp.zeros(shape)

def log_prob(self, value):
return jnp.zeros(self.batch_shape)

@property
def mean(self):
return jnp.zeros(self.batch_shape + self.event_shape)

@property
def variance(self):
return jnp.ones(self.batch_shape + self.event_shape)

def cdf(self, value):
return jnp.full(self.batch_shape, 0.5)

def icdf(self, q):
return jnp.zeros(self.batch_shape + self.event_shape)

my_dist = ShapedDistributionLike()

assert my_dist.batch_shape == (2, 3)
assert my_dist.event_shape == (4,)
assert my_dist.event_dim == 1

key = PRNGKey(0)
sample = my_dist.sample(key, sample_shape=(5,))
assert sample.shape == (5, 2, 3, 4)

log_prob = my_dist.log_prob(jnp.zeros((2, 3, 4)))
assert log_prob.shape == (2, 3)

assert my_dist.mean.shape == (2, 3, 4)
assert my_dist.variance.shape == (2, 3, 4)

0 comments on commit a810776

Please sign in to comment.