Tested and working configuration for JAX with GPU on a cluster with NVIDIA driver 525 (CUDA 12.0 max).
Running JAX on GPU clusters can be challenging due to driver and library compatibility issues. Version mismatches often lead to cryptic errors during import or runtime, or performance regressions where JAX runs but compilation takes an excessive amount of time. This repository provides a verified configuration for CUDA 12.0 and highlights common pitfalls.
The included tests verify JAX functionality by comparing CPU and GPU outputs for complex operations, such as differentiation through loops with eigensolvers. These tests ensure numerical consistency across devices, regardless of whether you used the provided installation scripts.
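As a sketch of that idea (illustrative only, not the repo's actual tests/ code), such a consistency test differentiates through a loop containing an eigensolver and compares device outputs, falling back to a CPU-only smoke test when no GPU backend is present:

```python
import jax
import jax.numpy as jnp

def spectral_loss(a):
    # Loop + eigensolver + scalar reduction: the kind of computation
    # the repo's stress tests differentiate through.
    sym = a @ a.T + jnp.eye(a.shape[0])
    sym = jax.lax.fori_loop(0, 5, lambda i, m: m + 0.1 * (a @ a.T), sym)
    return jnp.sum(jnp.log(jnp.linalg.eigvalsh(sym)))

grad_fn = jax.jit(jax.grad(spectral_loss))
# Tridiagonal input -> distinct eigenvalues, so the eigh gradient is well-defined.
x = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])

cpu = jax.devices("cpu")[0]
try:
    accel = jax.devices("gpu")[0]
except RuntimeError:
    accel = cpu  # no GPU backend available: degrade to CPU-vs-CPU

g_cpu = grad_fn(jax.device_put(x, cpu))
g_acc = grad_fn(jax.device_put(x, accel))
assert jnp.allclose(g_cpu, g_acc, atol=1e-5), "CPU/GPU gradients diverge"
```

On a GPU node the second call runs on the A100 and the assertion catches numerical divergence between backends.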
```bash
ssh node07                      # must be on a GPU node
cd /path/to/jax-cluster-setup
./install.sh                    # one-time setup
source activate.sh              # every session
pytest tests/                   # verify everything works
pytest tests/ -m "not stress"   # quick verification only
```

You can also just source `activate_jax.sh` in an existing environment and run `pytest tests/` or `python test_jax_gpu.py` to verify everything works.
You may also want to look at `requirements-jax-cluster-works-gpu.txt` for a list of dependencies that work together on the cluster (for instance, when installing the qex library).
I also added the following to my .bashrc to set the correct environment automatically; the exact paths will depend on where the modules are installed on your cluster:
```bash
########################################################
# FIX for JAX with GPU
########################################################

# --- Step 1: Load correct Python ---
module load devel/python/3.11.13 2>/dev/null

# --- Step 2: FORCE CUDA 12.0.1 for XLA (ptxas 12.0 matches driver 525) ---
# We MUST override whatever the cuda-sdk module set
export CUDA_HOME="/softs/nvidia/sdk/12.0.1"
export XLA_FLAGS="--xla_gpu_cuda_data_dir=/softs/nvidia/sdk/12.0.1"

# Put 12.0.1 ptxas FIRST in PATH, before any 12.8.1 from modules
export PATH="/softs/nvidia/sdk/12.0.1/bin:${PATH}"

# --- Step 3: Runtime libs from pip wheels + system cuDNN ---
CUDNN_ROOT="/softs/nvidia/cudnn/9.10.1.4_cuda12"
CUPTI_PATH="/softs/nvidia/sdk/12.0.1/extras/CUPTI/lib64"
SITE_PACKAGES=$(python3 -c 'import site; print(site.getsitepackages()[0])' 2>/dev/null)
NVIDIA_PATH="${SITE_PACKAGES}/nvidia"
NVIDIA_LIBS=""
if [ -d "${NVIDIA_PATH}" ]; then
    for pkg in cusparse cusolver cufft cublas cudnn cuda_runtime cuda_nvrtc nvjitlink nccl cuda_cupti curand; do
        if [ -d "${NVIDIA_PATH}/${pkg}/lib" ]; then
            NVIDIA_LIBS="${NVIDIA_PATH}/${pkg}/lib:${NVIDIA_LIBS}"
        fi
    done
fi
export LD_LIBRARY_PATH="${NVIDIA_LIBS}${CUDNN_ROOT}/lib:${CUPTI_PATH}:/softs/nvidia/sdk/12.0.1/lib64:${LD_LIBRARY_PATH}"

# --- Verify ---
PTXAS_ACTUAL=$(which ptxas 2>/dev/null)
PTXAS_VER=$("${PTXAS_ACTUAL}" --version 2>/dev/null | grep -oP 'release \K[0-9.]+')
echo "============================================"
echo " JAX GPU Environment Activated"
echo "============================================"
echo " CUDA_HOME: $CUDA_HOME"
echo " XLA_FLAGS: $XLA_FLAGS"
echo " ptxas: ${PTXAS_ACTUAL} (v${PTXAS_VER})"
echo " cuDNN: ${CUDNN_ROOT}"
echo " Python: $(python3 --version 2>&1)"
if [ "$PTXAS_VER" = "12.0" ]; then
    echo " GREAT! ptxas 12.0 matches driver CUDA 12.0"
else
    echo " TROUBLE! ptxas ${PTXAS_VER} -- may cause slow compilation!"
fi
echo "============================================"
```

Driver 525 only supports CUDA 12.0. Using a newer CUDA toolkit (12.4, 12.8) gives XLA a ptxas version newer than what the driver supports. XLA detects this mismatch and disables parallel JIT compilation, making compilation 10-100x slower.
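The mismatch condition is easy to check programmatically. A stdlib-only sketch (hypothetical helpers, not repo code) that parses the outputs of `nvidia-smi` and `ptxas --version` and compares them at major.minor granularity:

```python
import re

def parse_mm(text, pattern):
    """Extract a (major, minor) version tuple via regex; None if absent."""
    m = re.search(pattern, text)
    return tuple(int(x) for x in m.group(1).split(".")[:2]) if m else None

def ptxas_newer_than_driver(nvidia_smi_out, ptxas_out):
    """True when ptxas is newer than the driver's max CUDA version --
    the condition under which XLA disables parallel JIT compilation."""
    driver = parse_mm(nvidia_smi_out, r"CUDA Version:\s*([0-9.]+)")
    ptxas = parse_mm(ptxas_out, r"release\s+([0-9.]+)")
    return None not in (driver, ptxas) and ptxas > driver

# Canned strings standing in for the real command outputs:
print(ptxas_newer_than_driver("CUDA Version: 12.0", "release 12.8, V12.8.93"))  # True  (slow path)
print(ptxas_newer_than_driver("CUDA Version: 12.0", "release 12.0, V12.0.76"))  # False (fast path)
```

On the cluster you would feed it the actual output of `nvidia-smi` and `ptxas --version` (e.g. via `subprocess.run`); canned strings are shown here for clarity.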
Additionally, `pip install "jax[cuda12]"` pulls in `nvidia-cuda-nvcc-cu12`, which bundles ptxas 12.9 -- the same problem even when the system toolkit is correct.
JAX 0.4.30+ switched to a plugin architecture that pulls in nvidia-* pip packages with CUDA 12.6-12.9 libraries, which conflict with driver 525. JAX 0.4.29 is therefore the maximum usable version.
- Use CUDA 12.0.1 toolkit for ptxas and libdevice (matches driver)
- Use JAX 0.4.29 with jaxlib 0.4.29+cuda12.cudnn91 (last standalone wheel)
- Disable the pip-bundled ptxas 12.9 (rename it)
- Runtime libraries (cuBLAS, cuSPARSE, etc.) come from pip nvidia-* wheels -- these are minor-version compatible and work fine
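The last point is what the activate script's LD_LIBRARY_PATH loop implements. The same collection logic as a stdlib Python sketch (`nvidia_lib_dirs` is a hypothetical helper, not repo code):

```python
from pathlib import Path

# Same wheel packages the activate script scans.
PKGS = ["cusparse", "cusolver", "cufft", "cublas", "cudnn", "cuda_runtime",
        "cuda_nvrtc", "nvjitlink", "nccl", "cuda_cupti", "curand"]

def nvidia_lib_dirs(site_packages):
    """Collect <site-packages>/nvidia/<pkg>/lib dirs that actually exist,
    mirroring the loop in activate.sh."""
    root = Path(site_packages) / "nvidia"
    return [str(root / pkg / "lib") for pkg in PKGS if (root / pkg / "lib").is_dir()]
```

Joining the result with `":"` gives the wheel portion of LD_LIBRARY_PATH, to which the system cuDNN, CUPTI, and toolkit lib64 paths are appended.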
| Metric | CUDA 12.8 ptxas | CUDA 12.0 ptxas |
|---|---|---|
| Simple JIT compile | Minutes | ~1s |
| SCF stress test | Minutes | ~5s |
| Parallel compilation | Disabled | Enabled |
| Component | Version | Path |
|---|---|---|
| NVIDIA Driver | 525.85.12 / 525.147.05 | - |
| Driver's max CUDA | 12.0 | - |
| CUDA Toolkit (for XLA) | 12.0.1 | /softs/nvidia/sdk/12.0.1 |
| cuDNN | 9.10.1.4 | /softs/nvidia/cudnn/9.10.1.4_cuda12 |
| Python | 3.11.13 | module load devel/python/3.11.13 |
| GPU | A100-SXM4-40GB | node07 |
| JAX | 0.4.29 | pip (standalone CUDA wheel) |
| jaxlib | 0.4.29+cuda12.cudnn91 | pip -f jax_cuda_releases |
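Parts of this table can be asserted from Python before running anything expensive. A minimal stdlib sketch (`check_env` is a hypothetical helper; the paths are this cluster's):

```python
import os, sys

def check_env():
    """Compare the live environment against the tested configuration.
    Returns a list of problems (empty means the checks pass)."""
    problems = []
    if "/12.0.1" not in os.environ.get("CUDA_HOME", ""):
        problems.append("CUDA_HOME does not point at the 12.0.1 toolkit")
    if sys.version_info[:2] != (3, 11):
        problems.append("Python %d.%d (tested: 3.11)" % sys.version_info[:2])
    return problems
```

Adapt the expected values via config.sh for a different cluster.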
| File | Purpose |
|---|---|
| config.sh | All cluster-specific paths and versions -- edit this for your cluster |
| install.sh | One-time: creates venv, installs JAX, disables bad ptxas, installs test deps |
| activate.sh | Source each session: sets CUDA_HOME, XLA_FLAGS, PATH, LD_LIBRARY_PATH |
| activate_jax.sh | Source each session: sets CUDA_HOME, XLA_FLAGS, PATH, LD_LIBRARY_PATH in an existing environment |
| diagnose.sh | Troubleshooting: checks GPU, ptxas, env vars, libraries, JAX |
| tests/ | pytest test suite (device, basic ops, stress compilation, CPU/GPU consistency) |
Edit config.sh to adapt to a different cluster:

```bash
CUDA_VERSION="12.0.1"                             # Must match driver's max CUDA
CUDA_SDK_BASE="/softs/nvidia/sdk"                 # Where CUDA toolkits live
CUDNN_PATH="/softs/nvidia/cudnn/9.10.1.4_cuda12"
PYTHON_MODULE="devel/python/3.11.13"              # For `module load`
JAX_VERSION="0.4.29"                              # Don't change unless driver is upgraded
```

Custom venv location:

```bash
JAX_VENV_DIR=/scratch/user/.jax_venv ./install.sh
```

| JAX Version | Status | Reason |
|---|---|---|
| 0.4.29 | Works | Last version with standalone CUDA wheels |
| 0.4.30+ | Fails | Requires nvidia-* pip packages (CUDA 12.6+) that conflict |
| 0.5.x-0.6.x | Fails | Same plugin architecture, same conflict |
| 0.7.x | Fails | Requires CUDA 13 |
To use newer JAX, the cluster needs a driver upgrade to >=535.
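If you script environment checks, the table above collapses to a small version guard (illustrative only; it mirrors the table rows, nothing more):

```python
def jax_version_status(version):
    """Map a JAX version tuple, e.g. (0, 4, 29), to the compatibility
    table for driver 525 (max CUDA 12.0)."""
    if version >= (0, 7):
        return "fails: requires CUDA 13"
    if version >= (0, 4, 30):
        return "fails: plugin architecture needs CUDA 12.6+ nvidia-* wheels"
    return "works: standalone CUDA wheels"

print(jax_version_status((0, 4, 29)))  # works: standalone CUDA wheels
```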
The driver (525) reports max CUDA 12.0. When XLA finds a ptxas with a higher version, it warns:

```
The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas
CUDA version (12.8.xx). XLA is disabling parallel compilation, which may
slow down compilation.
```
CUDA 12.0.1 has ptxas 12.0 which matches the driver -> parallel compilation stays enabled -> fast JIT.
The pip-installed runtime libraries (cuBLAS 12.6, cuSPARSE 12.5, etc.) are forward-compatible and work fine with driver 525. Only ptxas needs to match.
Run diagnostics:

```bash
source activate.sh
./diagnose.sh
```

Cause: ptxas version mismatch. Check:

```bash
ptxas --version   # Should say 12.0
```

Fix: Ensure activate.sh was sourced (it puts the CUDA 12.0.1 bin first in PATH). Also check that the pip-bundled ptxas is disabled:

```bash
./diagnose.sh   # Check items 2 and 3
```

Cause: XLA can't find the CUDA toolkit.
Fix: source activate.sh (sets XLA_FLAGS).
Cause: Missing from LD_LIBRARY_PATH.
Fix: source activate.sh (sets LD_LIBRARY_PATH).
Cause: Conflicting nvidia-* pip packages.
Fix:
```bash
pip uninstall -y nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \
    nvidia-cuda-nvcc-cu12 nvidia-cuda-nvrtc-cu12 \
    nvidia-cuda-runtime-cu12 nvidia-cudnn-cu12 nvidia-cufft-cu12 \
    nvidia-cusolver-cu12 nvidia-cusparse-cu12 nvidia-nccl-cu12 \
    nvidia-nvjitlink-cu12
pip install jax==0.4.29 jaxlib==0.4.29+cuda12.cudnn91 \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Cause: Not on a GPU node.
Fix: `ssh node07`

If needed, clear the JAX compilation cache:

```bash
rm -rf ~/.cache/jax
```

If a pip install pulls in nvidia-cuda-nvcc-cu12 again:
```bash
SITE=$(python3 -c 'import site; print(site.getsitepackages()[0])')
mv ${SITE}/nvidia/cuda_nvcc/bin/ptxas ${SITE}/nvidia/cuda_nvcc/bin/ptxas.disabled
```

When installing packages that depend on JAX, use --no-deps to avoid overwriting the working JAX version:

```bash
pip install some-package --no-deps
```

Then install missing dependencies manually.
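The ptxas rename above can also be scripted, e.g. as a post-install hook. A stdlib sketch (`disable_bundled_ptxas` is a hypothetical helper, not part of the repo):

```python
from pathlib import Path

def disable_bundled_ptxas(site_packages):
    """Rename the pip-bundled ptxas (from nvidia-cuda-nvcc-cu12) so XLA
    falls back to the system CUDA 12.0.1 ptxas. Returns True if renamed."""
    ptxas = Path(site_packages) / "nvidia" / "cuda_nvcc" / "bin" / "ptxas"
    if ptxas.is_file():
        ptxas.rename(ptxas.with_suffix(".disabled"))
        return True
    return False  # nothing bundled: nothing to do
```

Calling it with `site.getsitepackages()[0]` after any pip install keeps the guard in place.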
| Benchmark | Result |
|---|---|
| Matrix multiply (4096x4096, 10 iters) | ~104,000 GFLOPS |
| SCF stress test (64x64, 15 iters, grad) | ~5s compile |
| NN+SCF stress test (vmap+grad) | ~10s compile |
| JIT / Autodiff / vmap | All working |
- CUDA 12.8.1 toolkit -- ptxas 12.8 causes slow compilation
- CUDA 12.4.1 toolkit -- ptxas 12.4, same problem
- `pip install jax[cuda12]` -- pulls nvidia-* packages with ptxas 12.9
- JAX >= 0.4.30 -- requires bundled CUDA 12.6+ libraries
- JAX >= 0.7.x -- requires CUDA 13
