
debug_nans error always says the de-optimized function did not produce NaNs #24955

Open
emilyfertig opened this issue Nov 18, 2024 · 5 comments

@emilyfertig (Collaborator)

Description

I'm working on documentation for debug_nans and I wrote the following function, which for certain input values calls jnp.log on a negative number, producing a nan value.

import jax
import jax.numpy as jnp
jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x, y):
  w = jnp.sin(x) - y**2
  z = jnp.log(w)
  return z*2

print(f(0.5, 0))
print(f(-2., 5))  # ==> FloatingPointError with note that the NaN doesn't appear without jit

print(jnp.log(-5.))  # ==> same error with note

It fails with an error indicating that a NaN was returned from the compiled function but not from fun.call_wrapped (the de-optimized version). The result is the same if I replace log with sqrt, if I remove the jit decorator, or if I just call jnp.log on a negative value without jit.

The error message is misleading because the de-optimized function returns NaNs as well, since it takes the log of a negative value. I think something is going wrong in the code path taken in _pjit_call_impl_python, but I can't tell what.
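To make this concrete, here's a minimal sketch (the name f_nojit is just for illustration, and jax_debug_nans is assumed to be left disabled so the computation runs to completion): the un-jitted math already produces a NaN for these inputs, which is why the "de-optimized function did not produce invalid values" note can't be right here.

import jax.numpy as jnp

def f_nojit(x, y):
  w = jnp.sin(x) - y**2   # sin(-2.) - 5**2 is negative
  z = jnp.log(w)          # log of a negative number -> nan
  return z*2

print(f_nojit(-2., 5))    # ==> nan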

cc @yashk2810 since it looks like you've worked on this area of the code a fair amount.

System info (python version, jaxlib version, accelerator, etc.)

Reproducible across a few different environments, but e.g.:

jax: 0.4.36
jaxlib: 0.4.36
numpy: 2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (f58ce1152703ca753794b8cef36da30bd2668d0f)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='b6e5614622812f47-3e7e1adbbf9.borgtask.google.com', release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')

$ nvidia-smi
Mon Nov 18 12:42:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla V100-SXM2-16GB Off | 00000000:B3:00.0 Off | 0 |
| N/A 41C P0 72W / 300W | 12433MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 829280 C ...fb3717c109/mount/server/ml_notebook 12430MiB |
+---------------------------------------------------------------------------------------+

@emilyfertig added the bug label Nov 18, 2024
@mattjj self-assigned this Nov 18, 2024
@mattjj (Collaborator) commented Nov 18, 2024

Thanks, @emilyfertig! I noticed this a few months ago, started a PR to fix it, and then let it languish. This regressed when we did the jit/pjit merge more than a year ago. Let me see if I can revive the PR...

@emilyfertig (Collaborator, Author)

I noticed something else that might be related: the error message with debug_nans used to say which line inside a jitted function produced a NaN, and now it just reports the call site. Here's an example from the NaN Debugging section of The Sharp Bits:

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

And here's the same code, run with 0.4.36:

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
    [... skipping hidden 1 frame]

google3/third_party/py/jax/_src/profiler.py in wrapper(*args, **kwargs)
    332     with TraceAnnotation(name, **decorator_kwargs):
--> 333       return func(*args, **kwargs)
    334     return wrapper

4 frames
google3/third_party/py/jax/_src/interpreters/pxla.py in __call__(self, *args)
   1302         for arrays in out_arrays:
-> 1303           dispatch.check_special(self.name, arrays)
   1304         out = self.out_handler(out_arrays)

google3/third_party/py/jax/_src/dispatch.py in check_special(name, bufs)
    315     for buf in bufs:
--> 316       _check_special(name, buf.dtype, buf)
    317 

google3/third_party/py/jax/_src/dispatch.py in _check_special(name, dtype, buf)
    320     if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 321       raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    322     if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):

FloatingPointError: invalid value (nan) encountered in jit(f)

During handling of the above exception, another exception occurred:

FloatingPointError                        Traceback (most recent call last)
<ipython-input-18-9911e10902e9> in <cell line: 0>()
     15 b = jnp.array([3., 9])
     16 
---> 17 print(f(x, y))

    [... skipping hidden 3 frame]

google3/third_party/py/jax/_src/pjit.py in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, *args)
   1692            "If you see this error, consider opening a bug report at "
   1693            "https://github.com/jax-ml/jax.")
-> 1694     raise FloatingPointError(msg)
   1695 
   1696 

FloatingPointError: invalid value (nan) encountered in jit(f). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

Also, the current version no longer prints "Invalid value encountered in the output of a jit function. Calling the de-optimized version." (It sometimes does, but I haven't figured out how to reproduce that consistently; I tried flushing the log buffer, so I don't think that's the cause.)

@mattjj If you have a start on a PR, I'd be happy to take it over (especially if you think it'd be a good way to learn about this part of the code and wouldn't be too much to bite off as I'm getting ramped up).

@emilyfertig (Collaborator, Author)

The above behavior (printing the call site only and not the line in the function where the NaN occurred) is more recent. 0.4.35 (released 10/22) still prints the exact line.
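A trivial check for anyone comparing the two behaviors, just to confirm which release is installed (per the above, the difference is between 0.4.35 and 0.4.36):

import jax
print(jax.__version__)  # 0.4.35 still points at the offending line; 0.4.36 reports only the call site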

@emilyfertig (Collaborator, Author)

For now #24989 comments out parts of the docs/error message that aren't consistent with how the code behaves.

@emilyfertig (Collaborator, Author)

The culprit for the second issue appears to be 32bf19a.
