jaxlib or libtpu not detected on TPU Pod #22070
Comments
How did you install JAX? |
@yashk2810 I ran these commands on all hosts:
python3.12 -m venv ~/venv
. ~/venv/bin/activate
pip install -U pip
pip install -U wheel
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
Can you show me your |
$ pip freeze
certifi==2024.6.2
charset-normalizer==3.3.2
idna==3.7
jax==0.4.30
jaxlib==0.4.30
libtpu-nightly==0.1.dev20240617
ml-dtypes==0.4.0
numpy==2.0.0
opt-einsum==3.3.0
requests==2.32.3
scipy==1.14.0
urllib3==2.2.2
wheel==0.43.0 |
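The listing above shows jax, jaxlib, and libtpu-nightly all installed, so a missing package does not explain the "not detected" message. A minimal sketch of the same check done programmatically, using only the standard library; the helper name `installed_versions` is hypothetical, and the default package names are simply the ones that appear in the `pip freeze` output:

```python
from importlib import metadata


def installed_versions(names=("jax", "jaxlib", "libtpu-nightly")):
    """Return {distribution name: version string, or None if not installed}.

    Only inspects the packaging metadata of the active environment; it does
    not import the packages or touch the TPU.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def missing_packages(names=("jax", "jaxlib", "libtpu-nightly")):
    """Return the sorted list of names with no installed distribution."""
    return sorted(n for n, v in installed_versions(names).items() if v is None)
```

Running `missing_packages()` inside the venv on each host and getting an empty list would suggest the installation itself is fine and the cause lies elsewhere.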
Can you try running |
From the logs I realised the actual reason is that the TPU is used by another process. It works after the process is killed. |
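One way to spot such a process before launching is to look for PIDs that hold the accelerator device files open. The sketch below is an assumption-laden illustration, not part of JAX: it assumes a Linux host with `/proc`, assumes the TPU is exposed as `/dev/accel*` (as in the `ls` output in this issue), and assumes the owning process keeps those files open. The function name `find_device_holders` is hypothetical. File descriptors of other users' processes are only visible when run as root.

```python
import glob
import os


def find_device_holders(pattern="/dev/accel*", proc_root="/proc"):
    """Return {pid: sorted list of matching device paths that pid has open}."""
    devices = set(glob.glob(pattern))
    holders = {}
    if not devices or not os.path.isdir(proc_root):
        return holders
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue  # not a process directory
        fd_dir = os.path.join(proc_root, entry, "fd")
        try:
            fds = os.listdir(fd_dir)
        except OSError:
            continue  # process exited, or permission denied
        for fd in fds:
            try:
                target = os.readlink(os.path.join(fd_dir, fd))
            except OSError:
                continue  # descriptor closed while scanning
            if target in devices:
                holders.setdefault(int(entry), set()).add(target)
    return {pid: sorted(paths) for pid, paths in holders.items()}


if __name__ == "__main__":
    for pid, paths in sorted(find_device_holders().items()):
        print(pid, " ".join(paths))
```

If this prints a PID on any host, that process would need to exit before a new JAX program can claim the TPU.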
Ah. This is supposed to be raised as an exception instead of falling back to CPU. That functionality must have regressed. Now to figure out why... |
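Until the exception is restored, a script can guard against the silent CPU fallback itself. This is a minimal sketch under stated assumptions: `require_platform` and `check_jax_backend` are hypothetical helper names, and the only JAX call used is `jax.default_backend()`, which returns the platform name of the default backend (for example `"cpu"` or `"tpu"`).

```python
def require_platform(actual, expected="tpu"):
    """Raise RuntimeError if the reported platform is not the expected one."""
    if actual.lower() != expected.lower():
        raise RuntimeError(
            f"expected JAX backend {expected!r} but got {actual!r}; "
            "the accelerator may be held by another process"
        )


def check_jax_backend(expected="tpu"):
    """Query JAX for its default backend and fail loudly on a mismatch.

    JAX is imported lazily so the pure check above stays usable without it.
    """
    import jax

    require_platform(jax.default_backend(), expected)
```

Calling `check_jax_backend()` at the top of the program on every host turns a quiet fallback into an immediate failure.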
Description
I installed the TPU version of JAX (including jaxlib and libtpu) inside a venv on every host of a TPU Pod. I then ran the following command on all hosts:
I got this error:
System info (python version, jaxlib version, accelerator, etc.)
$ ls /dev/accel*
/dev/accel0 /dev/accel1 /dev/accel2 /dev/accel3