Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tracer leak in contrib.module #1935

Merged
merged 3 commits into from
Dec 12, 2024
Merged

Fix tracer leak in contrib.module #1935

merged 3 commits into from
Dec 12, 2024

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Dec 12, 2024

Fixes #1934

We don't want NN's mutable state to hold gradient values. Under jax.grad, the intermediate values like mutable states of flax/haiku modules can hold intermediate grad tracers. We don't want to store them in the trace so we will stop_gradient after the nn_module.apply update.

@fehiepsi fehiepsi merged commit 4b33db1 into pyro-ppl:master Dec 12, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

flax and haiku contrib tests fail under jax 0.4.36.
2 participants