Installing a fresh copy of jax 0.4.35 with
pip install "jax[cuda12]==0.4.35" "nvidia-cudnn-cu12<9.4"
(this installs nvidia-cudnn-cu12==9.3.0.75) leads to a broken installation that fails with the error shown in the traceback below.
The same error also appears if you remove the constraint nvidia-cudnn-cu12<9.4 and install nvidia-cudnn-cu12==9.5.1.17.
The last version of jax that works correctly is 0.4.33.
>>> jax.numpy.ones((3,4))
E1123 19:13:25.034265 814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1123 19:13:25.071507 814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5949, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1615, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>>
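For anyone trying to reproduce this, the one-line failure above can be wrapped in a small script. This is only a sketch: the function name `try_small_array_op` is made up for illustration, and it assumes that setting `JAX_TRACEBACK_FILTERING=off` before importing jax (as the message at the end of the traceback suggests) is enough to get unfiltered frames.

```python
import os

# Assumption: setting this before importing jax turns off traceback filtering,
# as suggested by the message at the end of the traceback above.
os.environ.setdefault("JAX_TRACEBACK_FILTERING", "off")


def try_small_array_op():
    """Run the computation from the report and describe what happened.

    Hypothetical helper for illustration; it is not part of jax.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return "jax is not installed"
    try:
        x = jnp.ones((3, 4))
        x.block_until_ready()  # wait until the array is actually materialized
    except Exception as err:  # on the broken install this is an XlaRuntimeError
        return f"failed: {type(err).__name__}: {err}"
    return f"ok on {jax.default_backend()}: shape={x.shape}"


if __name__ == "__main__":
    print(try_small_array_op())
```

On a working installation this should report `ok on gpu` (or `ok on cpu` if no GPU backend is picked up); on the installation described here it would be expected to report the `FAILED_PRECONDITION` error instead.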
>>> import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ]
device info: NVIDIA A100-SXM-64GB-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='lrdn3434.leonardo.local', release='4.18.0-425.19.2.el8_7.x86_64', version='#1 SMP Fri Mar 17 01:52:38 EDT 2023', machine='x86_64')

$ nvidia-smi
Sat Nov 23 19:13:17 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM-64GB            On | 00000000:1D:00.0 Off |                    0 |
| N/A   44C    P0               80W / 475W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM-64GB            On | 00000000:56:00.0 Off |                    0 |
| N/A   44C    P0               76W / 473W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM-64GB            On | 00000000:8F:00.0 Off |                    0 |
| N/A   44C    P0               73W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM-64GB            On | 00000000:C8:00.0 Off |                    0 |
| N/A   43C    P0               74W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    1   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    2   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    3   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
+---------------------------------------------------------------------------------------+
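Since the report hinges on which package versions pip actually resolved, a short script can list them next to the environment info. The sketch below is illustrative only: the helper names are made up, the list of distribution names is an assumption about what a `jax[cuda12]` install pulls in, and the version comparison is a naive numeric one that only handles purely dotted-number versions such as `9.3.0.75`.

```python
from importlib import metadata

# Assumption: these are the distributions relevant to a `jax[cuda12]` install.
PACKAGES = (
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "nvidia-cudnn-cu12",
)


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def version_tuple(text):
    """Turn a dotted numeric version like '9.3.0.75' into (9, 3, 0, 75)."""
    return tuple(int(part) for part in text.split("."))


def satisfies_cudnn_pin(version, upper=(9, 4)):
    """Check the `nvidia-cudnn-cu12<9.4` constraint used in the pip command."""
    return version_tuple(version) < upper


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print(satisfies_cudnn_pin("9.3.0.75"))  # the version pip picked under the pin
    print(satisfies_cudnn_pin("9.5.1.17"))  # the version tried without the pin
```

For real dependency resolution the `packaging` library's version handling would be more robust; the tuple comparison here is just enough to show which of the two cuDNN versions mentioned above falls under the pin.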