Skip to content

Commit

Permalink
Add method to retrieve distribution arguments. (#1913)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Nov 19, 2024
1 parent c110053 commit 5c2eafa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
20 changes: 14 additions & 6 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,29 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None):
self.validate_args(strict=False)
super(Distribution, self).__init__()

def get_args(self) -> dict:
"""
Get arguments of the distribution.
"""
return {
param: getattr(self, param)
for param in self.arg_constraints
if param in self.__dict__
or not isinstance(getattr(type(self), param), lazy_property)
}

def validate_args(self, strict: bool = True) -> None:
"""
Validate the arguments of the distribution.
:param strict: Require strict validation, raising an error if the function is
called inside jitted code.
"""
for param, constraint in self.arg_constraints.items():
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue
for param, value in self.get_args().items():
constraint = self.arg_constraints[param]
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
is_valid = constraint(getattr(self, param))
is_valid = constraint(value)
if not_jax_tracer(is_valid):
if not np.all(is_valid):
raise ValueError(
Expand Down
7 changes: 7 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3357,6 +3357,13 @@ def test_explicit_validate_args():
jitted(d, True)


def test_get_args():
# Test that we only pick up parameters that were supplied or derived by the
# constructor.
d = dist.MultivariateNormal(precision_matrix=jnp.eye(3))
assert set(d.get_args()) == {"loc", "precision_matrix", "scale_tril"}


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)
Expand Down

0 comments on commit 5c2eafa

Please sign in to comment.