diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffcd94b88..753639535 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,6 +83,8 @@ jobs: uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: test-modeling test-inference: @@ -134,6 +136,8 @@ jobs: uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: test-inference examples: @@ -166,3 +170,19 @@ jobs: uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: examples + + + finish: + + needs: [test-modeling, test-inference, examples] + runs-on: ubuntu-latest + steps: + - name: Coveralls finished + uses: coverallsapp/github-action@v2 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel-finished: true + carryforward: "test-modeling,test-inference,examples" + diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index 3f40363e6..f87d96d5f 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -114,6 +114,7 @@ def flax_module( def apply_with_state(params, *args, **kwargs): params = {"params": params, **nn_state} out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs) + new_state = jax.lax.stop_gradient(new_state) nn_state.update(**new_state) return out @@ -202,6 +203,7 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw def apply_with_state(params, *args, **kwargs): out, new_state = nn_module.apply(params, nn_state, *args, **kwargs) + new_state = jax.lax.stop_gradient(new_state) nn_state.update(**new_state) return out diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 584cd457b..a1342507f 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -224,7 +224,6 @@ def model(data, labels): ) -@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_haiku_state_dropout_smoke(dropout, batchnorm): @@ -264,7 +263,6 @@ def model(): svi.run(random.PRNGKey(100), 10) -@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_flax_state_dropout_smoke(dropout, batchnorm):