Consolidate material on debugging NaNs. #24989
"If you see this error, consider opening a bug report at "
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg)
# TODO(emilyaf): Re-enable the below when https://github.com/jax-ml/jax/issues/24955 is fixed.
Let's not comment this out. We should do a proper fix instead of this change.
I agree this isn't a substitute for a proper fix. @mattjj and I are working on one. In the meantime, I think commenting this out temporarily is better than the status quo of giving users wrong information (i.e. that their code doesn't produce NaNs without JIT, when it in fact does).
I thought about instead adding a note to flags.md explaining that this error message is broken/misleading, but that seems worse than just omitting it for now.
Maybe we can hold off on this until the actual fix lands? Or maybe your doc change PR can be put on hold for some time?
I feel that this still provides some value rather than raising no error.
The FloatingPointError is still raised with the changes in this PR. This PR just removes the addendum to the error message about failing to produce NaNs without JIT, which currently appears regardless of whether the non-jitted code produces NaNs or not and is therefore often wrong.
The docs are out of sync with how the code behaves, since there have been a couple of regressions in debug_nans (the error message addendum always appearing, and the stack trace stopping at the call site of a jitted function instead of the line in the function that produced the NaN). This PR makes the docs/code consistent, so I think it'd be good to merge this as we continue to work on a fix, and update the docs again when the fix is in.
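To make the first regression concrete, here is a minimal pure-Python sketch of the control flow the addendum is meant to follow: the extra text should be attached only when the un-jitted re-run comes back clean. This is not JAX's implementation. `run_with_nan_check`, `check_nan`, `ADDENDUM`, and the toy functions are hypothetical names invented for illustration.

```python
import math

# Hypothetical stand-in for the addendum text; not the real JAX message.
ADDENDUM = ("The de-optimized function did not produce invalid values. "
            "Consider opening a bug report.")


def check_nan(value, where):
    """Raise FloatingPointError if `value` is NaN, naming the op that made it."""
    if math.isnan(value):
        raise FloatingPointError(f"invalid value (nan) encountered in {where}")


def run_with_nan_check(compiled_fn, op_by_op_fn, *args):
    """Run compiled_fn; on a NaN output, re-run op_by_op_fn to localize it."""
    out = compiled_fn(*args)
    if not math.isnan(out):
        return out
    # The op-by-op version checks every intermediate value and raises at the
    # first NaN, so a precise error propagates from here when one exists.
    op_by_op_fn(*args)
    # Reaching this line means the re-run was clean, so only now is the
    # addendum accurate.
    raise FloatingPointError(
        "invalid value (nan) encountered in jit output. " + ADDENDUM)


def compiled_sub(a, b):
    # Toy "compiled" function: no per-op checks.
    return a - b


def checked_sub(a, b):
    # Toy "de-optimized" function: checks the result of each op.
    out = a - b
    check_nan(out, "sub")
    return out
```

With these toy functions, `run_with_nan_check(compiled_sub, checked_sub, float("inf"), float("inf"))` raises the precise "sub" error without the addendum, because the re-run itself hits the NaN. The addendum appears only when the compiled function returns NaN and the checked re-run does not.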
I wanna find out what it takes to fix this. If it's a couple of days or hours, then I would wait. If it's weeks' worth of work, then merging this sounds fine.
It's probably unlikely to be fixed before Thanksgiving (I'm out next week and have other stuff to focus on the next couple days). @mattjj please correct me if I'm wrong and you see a quicker fix.
Updating the code examples uncovered a couple regressions in jax_debug_nans, described in #24955. This PR comments out parts of docs/error messages that are inconsistent with how the code currently behaves.
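For readers unfamiliar with the flag, the following is a small sketch of enabling `jax_debug_nans` and observing the `FloatingPointError` discussed in this thread. The helper `raises_fpe` and the function `ratio` are illustrative names, not part of JAX or of this PR. The sketch checks only that the error is raised, since the exact message text and stack trace are what #24955 reports as currently unreliable.

```python
import jax
import jax.numpy as jnp

# Enable NaN checking on the results of JAX computations.
jax.config.update("jax_debug_nans", True)


@jax.jit
def ratio(x, y):
    # 0.0 / 0.0 evaluates to NaN under IEEE floating-point rules.
    return x / y


def raises_fpe(fn, *args):
    """Return True if calling fn(*args) raises FloatingPointError."""
    try:
        fn(*args)
    except FloatingPointError:
        return True
    return False


# Eager call: the NaN is caught as soon as the division runs.
eager_caught = raises_fpe(jnp.divide, 0.0, 0.0)
# Jitted call: the NaN is caught when the compiled function's output is checked.
jit_caught = raises_fpe(ratio, jnp.float32(0.0), jnp.float32(0.0))
# A call that produces no NaN should pass through without an error.
clean_caught = raises_fpe(ratio, jnp.float32(1.0), jnp.float32(2.0))

# Restore the default so later code is unaffected.
jax.config.update("jax_debug_nans", False)
```

The flag is reset at the end because it is global configuration, and leaving it on would add checking overhead to every later computation in the same process.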