Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix faulty interaction between jax.vmap and validate_args=True #1686

Merged
merged 2 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def gather_pytree_data_fields(cls):
return all_pytree_data_fields

@classmethod
def gather_pytree_aux_fields(cls):
def gather_pytree_aux_fields(cls) -> tuple:
bases = inspect.getmro(cls)

all_pytree_aux_fields = ()
all_pytree_aux_fields = ("_validate_args",)
for base in bases:
if issubclass(base, Distribution):
all_pytree_aux_fields += base.__dict__.get("pytree_aux_fields", ())
Expand Down Expand Up @@ -203,11 +203,15 @@ def tree_unflatten(cls, aux_data, params):
for k, v in pytree_aux_fields_dict.items():
setattr(d, k, v)

# disable args validation during `tree_unflatten` it is called by jax with
# placeholder attributes that would make validation fail
d._validate_args = False
Distribution.__init__(
d,
pytree_aux_fields_dict["_batch_shape"],
pytree_aux_fields_dict["_event_shape"],
)
d._validate_args = pytree_aux_fields_dict["_validate_args"]
return d

@staticmethod
Expand Down
15 changes: 15 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3092,6 +3092,21 @@ def sample(d: dist.Distribution):
assert samples_batched_dist.shape == (1, *samples_dist.shape)


def test_vmap_validate_args():
# Test for #1684: vmapping distributions whould work when `validate_args=True`
v_dist = jax.vmap(
lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True),
in_axes=(0, 0),
)(jnp.zeros((2,)), jnp.zeros((2,)))

# non-regression test
v_dist = jax.vmap(
lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=False),
in_axes=(0, 0),
)(jnp.zeros((2,)), jnp.zeros((2,)))
assert not v_dist._validate_args


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)
Expand Down
Loading