Version Conflict Between Torch and JAX for NVIDIA cuDNN-cu12 #858

Open
apivovarov opened this issue Nov 26, 2024 · 1 comment

apivovarov commented Nov 26, 2024

I am trying to run all the pytests on a GPU instance.

To set up the environment, I installed the [dev] and [gpu] dependencies, but encountered the following issue:

pip install -e .[dev]
torchvision-0.16.1 requires torch-2.1.1
torch-2.1.1 requires nvidia_cudnn_cu12-8.9.2.26
pip install -e .[gpu]
jax[cuda12]-0.4.33 needs nvidia-cudnn-cu12 9.5.1.17

This leads to the following conflict:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.5.1.17 which is incompatible.
Successfully installed nvidia-cudnn-cu12-9.5.1.17

I am unable to have both torch and jax installed simultaneously.

When nvidia-cudnn-cu12-9.5.1.17 (the newer version) is installed, torch-2.1.1 crashes with the following error:

import torch

ImportError: libcudnn.so.8: cannot open shared object file: No such file or directory

When nvidia-cudnn-cu12-8.9.2.26 (the older version) is installed, jax crashes with this error:

import jax.numpy as jnp
x = jnp.ones((1000, 1000))

FAILED_PRECONDITION: DNN library initialization failed
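
To see which installed packages constrain nvidia-cudnn-cu12, and which version is actually present, something like the following sketch may help. It only uses the standard-library importlib.metadata module; the helper names (requirement_on, who_requires, installed_version) are made up for this illustration, and the requirement-string parsing is a simplification rather than a full PEP 508 parser.

```python
import re
from importlib import metadata
from typing import Dict, Optional


def _normalize(name: str) -> str:
    # Package names compare equal regardless of case and of "-", "_", "." separators.
    return re.sub(r"[-_.]+", "-", name).lower()


def requirement_on(requirement: str, package: str) -> Optional[str]:
    """Return the version specifier if `requirement` targets `package`, else None."""
    # Drop the environment marker (everything after ";").
    spec = requirement.split(";", 1)[0].strip()
    match = re.match(r"^([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*(.*)$", spec)
    if match is None or _normalize(match.group(1)) != _normalize(package):
        return None
    return match.group(2).strip() or "(any version)"


def who_requires(package: str) -> Dict[str, str]:
    """Map each installed distribution that declares `package` to its specifier."""
    found: Dict[str, str] = {}
    for dist in metadata.distributions():
        for req in dist.requires or []:
            spec = requirement_on(req, package)
            if spec is not None:
                found[dist.metadata["Name"] or "(unknown)"] = spec
    return found


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    target = "nvidia-cudnn-cu12"
    print("installed:", installed_version(target))
    for name, spec in sorted(who_requires(target).items()):
        print(f"{name} requires {target} {spec}")
```

If two distributions report specifiers that no single version can satisfy (for example an exact 8.x pin next to a 9.x requirement), pip cannot resolve both in the same environment.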

Approximately 30 test files use the torch package.

I am confused about how to run all pytests on the GPU instance, as I cannot have both torch/torchvision and jax[cuda12] installed at the same time due to these conflicts.

OS: Ubuntu 22.04

apivovarov (Author) commented

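Possible workaround: install the CPU build of Torch after the AXLearn [gpu] dependencies are installed. The CPU wheels do not depend on nvidia-cudnn-cu12, so the cuDNN version that jax[cuda12] needs stays in place.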

pip install -e .[dev]
pip install -e .[gpu]
pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cpu
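
A quick way to confirm the resulting environment is to import both libraries and report what each one sees. The sketch below is only an illustration (the function names check_torch, check_jax, and check_environment are hypothetical); it relies on torch.version.cuda being None for CPU-only builds and on jax.default_backend() reporting the active backend.

```python
from typing import Dict


def check_torch() -> str:
    """Report the torch version and whether it is a CPU-only build."""
    try:
        import torch
    except ImportError as err:  # also covers a missing libcudnn.so.8
        return f"import failed: {err}"
    cuda = torch.version.cuda  # None for CPU-only wheels
    build = "CPU-only build" if cuda is None else f"CUDA {cuda} build"
    return f"{torch.__version__} ({build})"


def check_jax() -> str:
    """Report the jax version and backend after running a small computation."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        return f"import failed: {err}"
    try:
        # Same allocation as in the report above; forces backend initialization.
        jnp.ones((1000, 1000)).block_until_ready()
    except Exception as err:  # e.g. FAILED_PRECONDITION from the DNN library
        return f"initialization failed: {err}"
    return f"{jax.__version__} (backend: {jax.default_backend()})"


def check_environment() -> Dict[str, str]:
    return {"torch": check_torch(), "jax": check_jax()}


if __name__ == "__main__":
    for name, status in check_environment().items():
        print(f"{name}: {status}")
```

With the workaround applied, the expectation is a CPU-only torch build alongside a jax backend of "gpu"; the torch-based tests then run on CPU.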
