From b74c0e958b207575197cca2febdaaf404e9b9cce Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 5 Dec 2024 05:04:10 -0500 Subject: [PATCH] Only include `arg_constraints` in `pytree_data_fields` if they are not `lazy_property`s. (#1929) --- numpyro/distributions/distribution.py | 18 +++++++++++------- test/test_distributions.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 0a34991c4..b0e6ae431 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -147,16 +147,20 @@ def __init_subclass__(cls, **kwargs): def gather_pytree_data_fields(cls): bases = inspect.getmro(cls) - all_pytree_data_fields = () + all_pytree_data_fields = set() for base in bases: if issubclass(base, Distribution): - all_pytree_data_fields += base.__dict__.get( - "pytree_data_fields", - tuple(base.__dict__.get("arg_constraints", {}).keys()), + all_pytree_data_fields.update( + base.__dict__.get( + "pytree_data_fields", + tuple( + arg + for arg in base.__dict__.get("arg_constraints", {}) + if not isinstance(getattr(cls, arg, None), lazy_property) + ), + ) ) - # remove duplicates - all_pytree_data_fields = tuple(set(all_pytree_data_fields)) - return all_pytree_data_fields + return tuple(all_pytree_data_fields) @classmethod def gather_pytree_aux_fields(cls) -> tuple: diff --git a/test/test_distributions.py b/test/test_distributions.py index ff7037a92..87b81219c 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3521,3 +3521,16 @@ def test_gaussian_random_walk_state_space_equivalence(): assert jnp.allclose(x1, jnp.squeeze(x2, axis=-1)) assert jnp.allclose(d1.log_prob(x1), d2.log_prob(x2)) + + +def test_consistent_pytree() -> None: + def make_dist(): + return dist.MultivariateNormal(precision_matrix=jnp.eye(2)) + + init = make_dist() + # Access the covariance matrix to evaluate the lazy property. + init.covariance_matrix + assert "covariance_matrix" in init.__dict__ + + # Run scan which validates that pytree structures are consistent. + jax.lax.scan(lambda *_: (make_dist(), None), init, jnp.arange(7))