Skip to content

Commit

Permalink
Fix tracer leak in contrib.module (#1935)
Browse files Browse the repository at this point in the history
* fix tracer leak in contrib

* fix coverage

* fix typo in ci
  • Loading branch information
fehiepsi authored Dec 12, 2024
1 parent 788b1cc commit 4b33db1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ jobs:
uses: coverallsapp/github-action@v2
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel: true
flag-name: test-modeling


test-inference:
Expand Down Expand Up @@ -134,6 +136,8 @@ jobs:
uses: coverallsapp/github-action@v2
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel: true
flag-name: test-inference


examples:
Expand Down Expand Up @@ -166,3 +170,19 @@ jobs:
uses: coverallsapp/github-action@v2
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel: true
flag-name: examples


finish:

needs: [test-modeling, test-inference, examples]
runs-on: ubuntu-latest
steps:
- name: Coveralls finished
uses: coverallsapp/github-action@v2
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
parallel-finished: true
carryforward: "test-modeling,test-inference,examples"

2 changes: 2 additions & 0 deletions numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def flax_module(
def apply_with_state(params, *args, **kwargs):
params = {"params": params, **nn_state}
out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down Expand Up @@ -202,6 +203,7 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw

def apply_with_state(params, *args, **kwargs):
out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
new_state = jax.lax.stop_gradient(new_state)
nn_state.update(**new_state)
return out

Expand Down
2 changes: 0 additions & 2 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def model(data, labels):
)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_haiku_state_dropout_smoke(dropout, batchnorm):
Expand Down Expand Up @@ -264,7 +263,6 @@ def model():
svi.run(random.PRNGKey(100), 10)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_flax_state_dropout_smoke(dropout, batchnorm):
Expand Down

0 comments on commit 4b33db1

Please sign in to comment.