Hello, I'm using lineax as a linear solver backend for some JAX-based Finite Element code I'm developing (to the lineax developers, thank you! lineax, equinox, etc are amazing IMO), and I'm getting some weird behavior when using a GPU vs. a CPU to solve a linear system using lineax.bicgstab. While there may be a bug hiding in my code, and I might need to use a better preconditioner, I am still confused by the difference in behavior.
Context
I'm trying to replicate a structural dynamics demo with some minor modifications (a slightly different time integration scheme, and Newton's method + BiCGSTAB are used). I am able to reproduce the results of the demo on a CPU, and on a GPU only if I am lucky - Newton iterations fail at random time steps. After enabling more output from lineax, it's clear that a form of iterative breakdown occurs during nearly every incremental linear solve within a Newton step, and eventually a NaN appears at a random timestep. Importantly, this lineax throw/exception is not replicated when using a CPU to solve the same problem - no iterative solve breakdown occurs. Additionally, the solution I'm obtaining seems correct when compared to the analytical eigenmode/frequency on the CPU (and on the GPU, when I was able to skirt the numerical issues I'm seeing).
A Minimal Example
After layering the dynamics-related updates (where I first saw the issue) onto a previously working problem, I was able to reproduce a breakdown in the iterative BiCGSTAB solve by using a buggy material model:
```python
import jax.numpy as np  # jax.numpy aliased as np; lmbda, mu are Lamé parameters defined elsewhere

def stress(u_grad):
    I = np.eye(3)
    F = u_grad + I              # deformation gradient
    C = F.T @ F                 # right Cauchy-Green tensor
    E = (1 / 2) * (C - I)       # Green-Lagrange strain
    # Piola-Kirchhoff stress - BUGGY on GPU only, and only when dynamics are non-negligible
    # P = (lmbda * np.trace(E) * I + 2 * mu * E)
    # correct Piola-Kirchhoff stress - NOT 'buggy' on GPU or CPU
    P = F @ (lmbda * np.trace(E) * I + 2 * mu * E)
    return P
```
While this might imply that there is a bug in my code, the confusing part is that the lineax linear solves converge on a CPU (the Newton solve does not converge, due to the poorly defined material model). In contrast, on a GPU the linear solve breaks down and occasionally produces a NaN. The attached logs show the difference in solver behavior. This is the same issue I'm seeing in my larger structural dynamics code (when using a 'correct' material model).
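To see why the commented-out expression only misbehaves once deformations become significant, here is a standalone numpy sketch (placeholder Lamé constants, not the actual simulation code): the buggy `P` is the second Piola-Kirchhoff stress `S`, which agrees with the correct first Piola-Kirchhoff stress `P = F @ S` to first order in `u_grad` but diverges at finite strain - which may be why the bug only surfaces when dynamics are non-negligible.

```python
import numpy as np

lmbda, mu = 1.0, 1.0  # placeholder Lamé constants

def stresses(u_grad):
    I = np.eye(3)
    F = u_grad + I
    E = 0.5 * (F.T @ F - I)
    S = lmbda * np.trace(E) * I + 2 * mu * E  # second Piola-Kirchhoff ("buggy" P)
    P = F @ S                                 # first Piola-Kirchhoff (correct P)
    return S, P

rng = np.random.default_rng(0)
G = rng.standard_normal((3, 3))

# Small strain: F ≈ I, so the two expressions nearly coincide.
S, P = stresses(1e-6 * G)
small_diff = np.linalg.norm(P - S) / np.linalg.norm(P)

# Finite strain: the missing factor of F matters.
S, P = stresses(0.5 * G)
large_diff = np.linalg.norm(P - S) / np.linalg.norm(P)

print(small_diff, large_diff)  # small_diff is tiny; large_diff is O(1)
```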
gpu_out.log
cpu_out.log
Questions
Even if there is a small bug in my code somewhere, why is there a difference in CPU and GPU behavior during the linear solve? I have done everything I can to force all computations to use float64 (see 'Debugging Attempts').
Other questions:
- Any tips for how to continue to debug this? I'm hitting my JAX debugging knowledge limits!
- Is this observed difference expected for ill-conditioned systems? Should I just be looking for a better preconditioner? (I'm using a Jacobi preconditioner at the moment...)
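On the preconditioner question, a rough numpy illustration (synthetic matrix, not the lineax API) of what Jacobi can and cannot do: diagonal scaling repairs ill-conditioning that comes from badly scaled degrees of freedom, but does little about ill-conditioning from off-diagonal coupling.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50

# Stand-in for an FE tangent whose ill-conditioning comes from badly scaled
# degrees of freedom: A = sqrt(D) @ C @ sqrt(D), with C well conditioned.
R = rng.standard_normal((n, n))
C = R @ R.T / n + np.eye(n)       # well-conditioned SPD core
D = np.logspace(0, 8, n)          # eight orders of magnitude of dof scaling
A = np.sqrt(D)[:, None] * C * np.sqrt(D)[None, :]

# Jacobi preconditioning, applied symmetrically: M = D_A^{-1/2} A D_A^{-1/2}.
s = 1.0 / np.sqrt(np.diag(A))
M = s[:, None] * A * s[None, :]

print(f"cond(A) = {np.linalg.cond(A):.2e}")
print(f"cond(M) = {np.linalg.cond(M):.2e}")  # dramatically smaller
```

If the condition number survives diagonal scaling (i.e. the bad conditioning lives in the coupling, not the diagonal), the usual next step for FE tangents is an incomplete factorization (ILU/IC) or a multigrid-type preconditioner rather than Jacobi.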
Debugging Attempts
While trying to diagnose the problem, I attempted or noticed the following (nothing has solved the problem):
- setting jax_default_matmul_precision to 'highest' to force GPU matmuls to use full precision rather than TF32 (I am on an A100, FYI; I'm trying to use 'highest' per the discussion in JAX #19444): JAX #22557, tracked by JAX #18934
- GPU vs. CPU behavior differences: JAX #22382. There are some in-place updates made via .at[] that I haven't looked into changing yet, but it seems like this has been fixed in XLA #19716.
Other debugging attempts/important mentions:
- 64-bit precision is enabled in JAX (jax_enable_x64)
- jax_debug_nans doesn't catch the issue. Un-jitting the linear solve and using a breakpoint to re-run the problematic/NaN-inducing iteration, with the same LHS, RHS, and initial guess, does NOT reproduce the NaN.
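For completeness, the configuration flags mentioned above can be set together at startup (a sketch of my setup, not a guaranteed fix; these must run before any JAX computation):

```python
import jax

# Force 64-bit floats everywhere.
jax.config.update("jax_enable_x64", True)

# On GPU, float32 matmuls may otherwise be dispatched to lower-precision
# (TF32) tensor-core paths.
jax.config.update("jax_default_matmul_precision", "highest")

# Error out on the first NaN produced inside jitted code.
jax.config.update("jax_debug_nans", True)
```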
JAX version
note: I'm using lineax 0.0.8 and equinox 0.13.1
jax: 0.6.2
jaxlib: 0.6.2
numpy: 2.2.2
python: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]
device info: NVIDIA A100 80GB PCIe-2, 2 local devices
process_count: 1
platform: uname_result(system='Linux', node='tralfamadore', release='6.8.0-85-generic', version='#85~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 16:18:59 UTC 2', machine='x86_64')
$ nvidia-smi
Wed Oct 15 14:14:58 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06 Driver Version: 580.65.06 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100 80GB PCIe Off | 00000000:17:00.0 Off | 0 |
| N/A 38C P0 72W / 300W | 445MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100 80GB PCIe Off | 00000000:65:00.0 Off | 0 |
| N/A 37C P0 72W / 300W | 441MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 2176 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 656991 C python 422MiB |
| 1 N/A N/A 2176 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 656991 C python 418MiB |
+-----------------------------------------------------------------------------------------+