From 31c175bae70984754bcabdd3d75e657a4c7adff1 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Sat, 25 Nov 2023 03:37:51 +0000 Subject: [PATCH] Fix faulty interaction between `jax.vmap` and `validate_args=True` (#1686) * add initial bug reproducer * disable arg validation during `tree_unflatten` --- numpyro/distributions/distribution.py | 8 ++++++-- test/test_distributions.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 4f27f0022..ade4e9910 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -158,10 +158,10 @@ def gather_pytree_data_fields(cls): return all_pytree_data_fields @classmethod - def gather_pytree_aux_fields(cls): + def gather_pytree_aux_fields(cls) -> tuple: bases = inspect.getmro(cls) - all_pytree_aux_fields = () + all_pytree_aux_fields = ("_validate_args",) for base in bases: if issubclass(base, Distribution): all_pytree_aux_fields += base.__dict__.get("pytree_aux_fields", ()) @@ -203,11 +203,15 @@ def tree_unflatten(cls, aux_data, params): for k, v in pytree_aux_fields_dict.items(): setattr(d, k, v) + # disable args validation during `tree_unflatten` it is called by jax with + # placeholder attributes that would make validation fail + d._validate_args = False Distribution.__init__( d, pytree_aux_fields_dict["_batch_shape"], pytree_aux_fields_dict["_event_shape"], ) + d._validate_args = pytree_aux_fields_dict["_validate_args"] return d @staticmethod diff --git a/test/test_distributions.py b/test/test_distributions.py index 4e278646b..714fe2871 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3092,6 +3092,21 @@ def sample(d: dist.Distribution): assert samples_batched_dist.shape == (1, *samples_dist.shape) +def test_vmap_validate_args(): + # Test for #1684: vmapping distributions whould work when `validate_args=True` + v_dist = jax.vmap( + lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True), + in_axes=(0, 0), + )(jnp.zeros((2,)), jnp.zeros((2,))) + + # non-regression test + v_dist = jax.vmap( + lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=False), + in_axes=(0, 0), + )(jnp.zeros((2,)), jnp.zeros((2,))) + assert not v_dist._validate_args + + def test_multinomial_abstract_total_count(): probs = jnp.array([0.2, 0.5, 0.3]) key = random.PRNGKey(0)