
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error #21950

Closed
seishiroono opened this issue Jun 18, 2024 · 7 comments

@seishiroono

Description

I am trying to use JAX version 0.4.29 with CUDA 12.4. When I ran a simple linear algebra computation, I got the error RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error.

Error

Running the following reproduces the error:

>>> import jax
>>> import jax.numpy as jnp
>>> c = jnp.array([[ 0.,  0.,  0.,  1.],[-0.,  0., -1.,  0.],[ 0.,  1.,  0.,  0.],[-1.,  0., -0.,  0.]])
>>> jnp.linalg.det(c)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/s.ono/.conda/envs/jax_cuda12/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 692, in det
    sign, logdet = slogdet(a)
  File "/home/s.ono/.conda/envs/jax_cuda12/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 534, in slogdet
    return SlogdetResult(*_slogdet_lu(a))
  File "/home/s.ono/.conda/envs/jax_cuda12/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 455, in _slogdet_lu
    lu, pivot, _ = lax_linalg.lu(a)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/s.ono/.conda/envs/jax_cuda12/lib/python3.11/site-packages/jaxlib/gpu_solver.py", line 102, in _getrf_hlo
    lwork, opaque = gpu_solver.build_getrf_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

On the other hand, the following computation works fine:

>>> a = jnp.array([[1., 2.],[3., 4.]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
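
For what it's worth, the 2x2 case presumably succeeds because jnp.linalg.det uses a direct formula for small matrices, while the 4x4 case goes through an LU factorization on cuSolver (the traceback below shows lax_linalg.lu). A minimal sketch that exercises the same LU path directly, assuming it reproduces the failure in this environment:

# Sketch: call the LU decomposition directly; for matrices larger than 3x3
# this is the same cuSolver getrf path that jnp.linalg.det relies on, so it
# should raise the same cuSolver internal error in this environment.
import jax.numpy as jnp
from jax import lax

c = jnp.eye(4)
lu, pivots, permutation = lax.linalg.lu(c)  # expected: RuntimeError from gpusolverDnCreate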

System info (python version, jaxlib version, accelerator, etc.)

>>> jax.print_environment_info()
jax:    0.4.29
jaxlib: 0.4.29
numpy:  1.26.4
python: 3.11.7 (main, Dec 15 2023, 18:12:31) [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gnode05', release='3.10.0-1160.53.1.el7.x86_64', version='#1 SMP Fri Jan 14 13:59:45 UTC 2022', machine='x86_64')


$ nvidia-smi
Tue Jun 18 21:26:33 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-PCIE-40GB          Off |   00000000:2B:00.0 Off |                    0 |
| N/A   29C    P0             36W /  250W |     575MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-PCIE-40GB          Off |   00000000:A2:00.0 Off |                    0 |
| N/A   27C    P0             37W /  250W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      4978      C   python                                        416MiB |
|    0   N/A  N/A     15535      G   /usr/bin/X                                    108MiB |
|    0   N/A  N/A     15637      G   /usr/bin/gnome-shell                           23MiB |
|    1   N/A  N/A      4978      C   python                                        416MiB |
+-----------------------------------------------------------------------------------------+
seishiroono added the bug label on Jun 18, 2024
@dfm
Member

dfm commented Jun 18, 2024

Can you describe exactly how you installed jax (and any relevant drivers, etc.) so that we can try to reproduce?

@seishiroono
Author

@dfm Thanks for your message. I am using miniconda for the virtual environment. The following is what I did in the virtual environment.

$ conda create -n jax_cuda12_test python=3.11.7
$ conda activate jax_cuda12_test
$ python3 -m pip install nvidia-cuda-cccl-cu12==12.4.127 nvidia-cuda-cupti-cu12==12.4.127 nvidia-cuda-nvcc-cu12==12.4.131 nvidia-cuda-opencl-cu12==12.4.127 nvidia-cuda-nvrtc-cu12==12.4.127 nvidia-cublas-cu12==12.4.5.8 nvidia-cuda-sanitizer-api-cu12==12.4.127 nvidia-cufft-cu12 nvidia-curand-cu12 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.4.1.24 nvidia-npp-cu12 nvidia-nvfatbin-cu12==12.4.127 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvjpeg-cu12 nvidia-nvml-dev-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 nvidia-cuda-runtime-cu12==12.4.127
$ pip install --upgrade pip
$ pip install --upgrade "jax[cuda12]"

The resulting conda list is as follows.

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.3.11            h06a4308_0  
jax                       0.4.29                   pypi_0    pypi
jax-cuda12-pjrt           0.4.29                   pypi_0    pypi
jax-cuda12-plugin         0.4.29                   pypi_0    pypi
jaxlib                    0.4.29                   pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
ml-dtypes                 0.4.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cccl-cu12     12.4.127                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.4.131                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-opencl-cu12   12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cuda-sanitizer-api-cu12 12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.1.17                 pypi_0    pypi
nvidia-cufft-cu12         11.2.3.18                pypi_0    pypi
nvidia-curand-cu12        10.3.6.39                pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.4.1.24                pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-npp-cu12           12.3.0.116               pypi_0    pypi
nvidia-nvfatbin-cu12      12.4.127                 pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvjpeg-cu12        12.3.2.38                pypi_0    pypi
nvidia-nvml-dev-cu12      12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
openssl                   3.0.13               h7f8727e_2  
opt-einsum                3.3.0                    pypi_0    pypi
pip                       24.0                     pypi_0    pypi
python                    3.11.7               h955ad1f_0  
readline                  8.2                  h5eee18b_0  
scipy                     1.13.1                   pypi_0    pypi
setuptools                69.5.1                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
tk                        8.6.14               h39e8969_0  
tzdata                    2024a                h04d1e81_0  
wheel                     0.43.0                   pypi_0    pypi
xz                        5.4.6                h5eee18b_1  
zlib                      1.2.13               h5eee18b_1

@dfm
Member

dfm commented Jun 18, 2024

I'll look into this a little later, but you shouldn't need to install all those nvidia pip packages manually. What happens if you just run pip install "jax[cuda12]" in a fresh environment?

@seishiroono
Author

seishiroono commented Jun 19, 2024

@dfm Thanks for your message. In a new environment, I ran pip install --upgrade "jax[cuda12]". It looks like pip install --upgrade "jax[cuda12]" does not install the correct versions of the nvidia packages. In fact, the reason I installed the nvidia packages myself in the first place is that I ran into the same warning.

>>> import jax.numpy as jnp
>>> a = jnp.array(1.)
2024-06-19 10:47:17.191520: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
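
To make the mismatch concrete, the pip-installed CUDA component versions can be compared against the driver's CUDA version from nvidia-smi (12.4 here). A minimal sketch using only the standard library; the package names are the ones installed above:

# Sketch: print the pip-installed NVIDIA component versions so they can be
# compared with the driver's CUDA version reported by nvidia-smi.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("nvidia-cuda-nvcc-cu12",   # ships ptxas; 12.5.40 triggers the warning
            "nvidia-cusolver-cu12",
            "nvidia-cublas-cu12"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")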

On the other hand, the original error did not appear:

>>> import jax.numpy as jnp
>>> c = jnp.array([[ 0.,  0.,  0.,  1.],[-0.,  0., -1.,  0.],[ 0.,  1.,  0.,  0.],[-1.,  0., -0.,  0.]])
>>> jnp.linalg.det(c)
Array(1., dtype=float32)

Just in case, here is my conda list as well.

_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.3.11            h06a4308_0  
jax                       0.4.30                   pypi_0    pypi
jax-cuda12-pjrt           0.4.30                   pypi_0    pypi
jax-cuda12-plugin         0.4.30                   pypi_0    pypi
jaxlib                    0.4.30                   pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
ml-dtypes                 0.4.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
numpy                     2.0.0                    pypi_0    pypi
nvidia-cublas-cu12        12.5.2.13                pypi_0    pypi
nvidia-cuda-cupti-cu12    12.5.39                  pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.5.40                  pypi_0    pypi
nvidia-cuda-runtime-cu12  12.5.39                  pypi_0    pypi
nvidia-cudnn-cu12         9.1.1.17                 pypi_0    pypi
nvidia-cufft-cu12         11.2.3.18                pypi_0    pypi
nvidia-cusolver-cu12      11.6.2.40                pypi_0    pypi
nvidia-cusparse-cu12      12.4.1.24                pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
openssl                   3.0.14               h5eee18b_0  
opt-einsum                3.3.0                    pypi_0    pypi
pip                       24.0                     pypi_0    pypi
python                    3.11.7               h955ad1f_0  
readline                  8.2                  h5eee18b_0  
scipy                     1.13.1                   pypi_0    pypi
setuptools                69.5.1                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
tk                        8.6.14               h39e8969_0  
tzdata                    2024a                h04d1e81_0  
wheel                     0.43.0                   pypi_0    pypi
xz                        5.4.6                h5eee18b_1  
zlib                      1.2.13               h5eee18b_1

P.S. JAX was updated yesterday, so the output of jax.print_environment_info() has changed as well.

>>> jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.11.7 (main, Dec 15 2023, 18:12:31) [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gnode05', release='3.10.0-1160.53.1.el7.x86_64', version='#1 SMP Fri Jan 14 13:59:45 UTC 2022', machine='x86_64')


$ nvidia-smi
Wed Jun 19 10:45:13 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-PCIE-40GB          Off |   00000000:2B:00.0 Off |                    0 |
| N/A   28C    P0             36W /  250W |     574MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-PCIE-40GB          Off |   00000000:A2:00.0 Off |                    0 |
| N/A   27C    P0             36W /  250W |     425MiB /  40960MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     14770      C   python                                        416MiB |
|    0   N/A  N/A     15535      G   /usr/bin/X                                    108MiB |
|    0   N/A  N/A     15637      G   /usr/bin/gnome-shell                           22MiB |
|    1   N/A  N/A     14770      C   python                                        416MiB |
+-----------------------------------------------------------------------------------------+

@dfm
Member

dfm commented Jun 19, 2024

Thanks for the info. I can reproduce the warning that you're seeing about the ptxas version. I was able to work around this by simply downgrading the nvidia-cuda-nvcc-cu12 pip package. So, from a fresh virtual environment, I was able to get a working installation with:

pip install "jax[cuda12]" "nvidia-cuda-nvcc-cu12<12.5"

Want to see if that fixes the issue for you?

Edited to add: I also wouldn't worry too much about the warning. It may make jit compilation a little bit slower, but I wouldn't expect it to be a major issue!
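
A quick way to verify the fix after reinstalling is to re-run the reproducer from the original report; this is just a sketch, with the expected output taken from the earlier comments:

# Sketch: re-run the original reproducer after
#   pip install "jax[cuda12]" "nvidia-cuda-nvcc-cu12<12.5"
import jax
import jax.numpy as jnp

c = jnp.array([[ 0.,  0.,  0.,  1.],
               [-0.,  0., -1.,  0.],
               [ 0.,  1.,  0.,  0.],
               [-1.,  0., -0.,  0.]])
print(jnp.linalg.det(c))      # expected: 1.0 on a working install
jax.print_environment_info()  # sanity-check the jax/jaxlib/CUDA versions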

@seishiroono
Author

@dfm Thank you for your reply. The command you provided works in my environment, and my program seems to run correctly. Let me keep testing to make sure nothing else comes up.

dfm self-assigned this on Jun 19, 2024
@dfm
Member

dfm commented Jun 28, 2024

I'm going to close this as completed - please feel free to comment if there are other issues.

dfm closed this as completed on Jun 28, 2024