From 32d5736155eba69c8ac2b4f821dfbde1384c0122 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 12 Dec 2024 11:44:02 -0500 Subject: [PATCH] fix tracer leak in contrib --- numpyro/contrib/module.py | 2 ++ test/contrib/test_module.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index 3f40363e6..f87d96d5f 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -114,6 +114,7 @@ def flax_module( def apply_with_state(params, *args, **kwargs): params = {"params": params, **nn_state} out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs) + new_state = jax.lax.stop_gradient(new_state) nn_state.update(**new_state) return out @@ -202,6 +203,7 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw def apply_with_state(params, *args, **kwargs): out, new_state = nn_module.apply(params, nn_state, *args, **kwargs) + new_state = jax.lax.stop_gradient(new_state) nn_state.update(**new_state) return out diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 584cd457b..a1342507f 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -224,7 +224,6 @@ def model(data, labels): ) -@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_haiku_state_dropout_smoke(dropout, batchnorm): @@ -264,7 +263,6 @@ def model(): svi.run(random.PRNGKey(100), 10) -@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_flax_state_dropout_smoke(dropout, batchnorm):