This combination worked for me:

module load python
mamba create -n jaxtest python=3.13 -y
mamba activate jaxtest
module load cuda/12.9.1-fasrc01
module load cudnn/9.10.2.21_cuda12-fasrc01
pip install -U "jax[cuda12]"

Test it:

python -c "import jax; print(jax.devices())"
[CudaDevice(id=0)]
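Listing devices only shows that JAX can see the GPU. For a check that goes one step further, the following is a minimal sketch, not part of the original recipe: it prints the backend JAX selected and runs a small jitted matrix multiply. Run it inside the activated jaxtest environment on a GPU node. `jax.default_backend()` should report "gpu" when the CUDA plugin loaded; if it reports "cpu", JAX is not using the GPU.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX picked ("gpu" when the CUDA plugin is active).
print("backend:", jax.default_backend())
print("devices:", jax.devices())


@jax.jit
def matmul_sum(x):
    # Small compiled workload: a matrix product followed by a reduction.
    return (x @ x.T).sum()


# A 1024x1024 matrix of ones: every entry of x @ x.T is 1024,
# so the total is 1024 * 1024 * 1024.
x = jnp.ones((1024, 1024))

# block_until_ready() waits for the asynchronous computation to finish,
# so any runtime error on the device surfaces here.
result = matmul_sum(x).block_until_ready()
print("result:", float(result))
```

The result should be 1024**3 (1073741824) on any backend, so an error or a wrong value during this step points at the CUDA/cuDNN setup rather than at device discovery.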


Naeemkh (Maintainer, Author), Jul 18, 2025

Answer selected by Naeemkh