diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 8540b6e5f..0a34991c4 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -232,6 +232,17 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None): self.validate_args(strict=False) super(Distribution, self).__init__() + def get_args(self) -> dict: + """ + Get arguments of the distribution. + """ + return { + param: getattr(self, param) + for param in self.arg_constraints + if param in self.__dict__ + or not isinstance(getattr(type(self), param), lazy_property) + } + def validate_args(self, strict: bool = True) -> None: """ Validate the arguments of the distribution. @@ -239,14 +250,11 @@ def validate_args(self, strict: bool = True) -> None: :param strict: Require strict validation, raising an error if the function is called inside jitted code. """ - for param, constraint in self.arg_constraints.items(): - if param not in self.__dict__ and isinstance( - getattr(type(self), param), lazy_property - ): - continue + for param, value in self.get_args().items(): + constraint = self.arg_constraints[param] if constraints.is_dependent(constraint): continue # skip constraints that cannot be checked - is_valid = constraint(getattr(self, param)) + is_valid = constraint(value) if not_jax_tracer(is_valid): if not np.all(is_valid): raise ValueError( diff --git a/test/test_distributions.py b/test/test_distributions.py index 437913328..ca6cf6fa1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3357,6 +3357,13 @@ def test_explicit_validate_args(): jitted(d, True) +def test_get_args(): + # Test that we only pick up parameters that were supplied or derived by the + # constructor. + d = dist.MultivariateNormal(precision_matrix=jnp.eye(3)) + assert set(d.get_args()) == {"loc", "precision_matrix", "scale_tril"} + + def test_multinomial_abstract_total_count(): probs = jnp.array([0.2, 0.5, 0.3]) key = random.PRNGKey(0)