Skip to content

Commit

Permalink
fix tracer leak in contrib
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Dec 12, 2024
1 parent bf9c715 commit 32d5736
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 2 additions & 0 deletions numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def flax_module(
def apply_with_state(params, *args, **kwargs):
params = {"params": params, **nn_state}
out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down Expand Up @@ -202,6 +203,7 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw

def apply_with_state(params, *args, **kwargs):
out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down
2 changes: 0 additions & 2 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def model(data, labels):
)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_haiku_state_dropout_smoke(dropout, batchnorm):
Expand Down Expand Up @@ -264,7 +263,6 @@ def model():
svi.run(random.PRNGKey(100), 10)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_flax_state_dropout_smoke(dropout, batchnorm):
Expand Down

0 comments on commit 32d5736

Please sign in to comment.