Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tracer leak in contrib.module #1935

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def flax_module(
def apply_with_state(params, *args, **kwargs):
params = {"params": params, **nn_state}
out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down Expand Up @@ -202,6 +203,7 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw

def apply_with_state(params, *args, **kwargs):
out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down
2 changes: 0 additions & 2 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def model(data, labels):
)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_haiku_state_dropout_smoke(dropout, batchnorm):
Expand Down Expand Up @@ -264,7 +263,6 @@ def model():
svi.run(random.PRNGKey(100), 10)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_flax_state_dropout_smoke(dropout, batchnorm):
Expand Down
Loading