Different Results on GPU and CPU for BiCGSTAB linear solve #174

@kenneth-meyer

Description

Hello, I'm using lineax as the linear-solver backend for some JAX-based finite element code I'm developing (to the lineax developers: thank you! lineax, equinox, etc. are amazing, IMO), and I'm getting some weird behavior when solving a linear system with lineax.BiCGStab on a GPU vs. a CPU. While there may be a bug hiding in my code, and I might need a better preconditioner, I am still confused by the difference in behavior.

Context

I'm trying to replicate a structural dynamics demo with some minor modifications (a slightly different time integration scheme, and Newton's method + BiCGSTAB for the solves). I can reproduce the results of the demo on a CPU, but on a GPU only if I'm lucky: Newton iterations fail at random time steps. After enabling more output from lineax, it's clear that iterative breakdown occurs in nearly every incremental linear solve within a Newton step, and eventually a NaN appears at a random time step. Importantly, this lineax throw/exception is not reproduced when solving the same problem on a CPU; no iterative breakdown occurs there. Additionally, the solution I obtain on the CPU (and on the GPU, when I was able to skirt the numerical issues) seems correct when compared to the analytical eigenmode/frequency.

A Minimal Example

After iterating through dynamics-related updates (when I first saw the issue occur) on a previously working problem, I was able to reproduce a breakdown in the iterative BiCGSTAB solve by using a buggy material model:

# Assumes `import jax.numpy as np` and the Lame parameters `lmbda`, `mu`
# are defined elsewhere, as in the original demo.
def stress(u_grad):
    I = np.eye(3)
    F = u_grad + I           # deformation gradient
    C = F.T @ F              # right Cauchy-Green tensor
    E = (1 / 2) * (C - I)    # Green-Lagrange strain

    # Piola-Kirchhoff stress - BUGGY on GPU only, and only when
    # dynamics are non-negligible:
    # P = lmbda * np.trace(E) * I + 2 * mu * E

    # correct Piola-Kirchhoff stress - NOT 'buggy' on GPU or CPU:
    P = F @ (lmbda * np.trace(E) * I + 2 * mu * E)
    return P

While this might imply that there is a bug in my code, the confusing part is that the lineax linear solves still converge on a CPU (the Newton solve does not converge, due to the poorly defined material model). In contrast, on a GPU the linear solve breaks down and occasionally produces a NaN. The attached logs show the difference in solver behavior. This is the same issue I'm seeing in my larger structural dynamics code (when using a 'correct' material model).

gpu_out.log
cpu_out.log

Questions

Even if there is a small bug in my code somewhere, why is there a difference between CPU and GPU behavior during the linear solve? I have done everything I can to force all computations to use float64 (see 'Debugging Attempts' below).

Other questions:

  1. Any tips for how to continue debugging this? I'm hitting the limits of my JAX debugging knowledge!
  2. Is this observed difference expected for ill-conditioned systems? Should I just be looking for a better preconditioner? (I'm using a Jacobi preconditioner at the moment.)

Debugging Attempts

While trying to diagnose the problem, I attempted or noticed the following (nothing has solved the problem):

  • setting jax_default_matmul_precision to 'highest' to force GPU matmul kernels to use full precision (I am on an A100, FYI; I'm trying 'highest' per the discussion in JAX #19444): JAX #22557, tracked by JAX #18934
  • GPU vs. CPU behavior differences: JAX #22382. There are some in-place updates made via .at[] that I haven't looked into changing yet, but it seems this has been fixed in XLA #19716.
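For concreteness, the precision-related flags above are set with the standard JAX config knobs, before any arrays are created (a sketch):

```python
import jax

# Enable 64-bit floats globally (must run before creating arrays).
jax.config.update("jax_enable_x64", True)

# Ask GPU matmuls for the highest available precision
# (e.g. avoid TF32 on Ampere-class GPUs such as the A100).
jax.config.update("jax_default_matmul_precision", "highest")

import jax.numpy as jnp
x = jnp.ones(3)
```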

Other debugging attempts/important mentions:

  • 64-bit precision is enabled in JAX
  • jax_debug_nans doesn't catch the issue. Un-jitting the linear solve and using a breakpoint to re-run the problematic/NaN-inducing iteration, with the same LHS, RHS, and initial guess, does NOT reproduce the NaN
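The un-jitting mentioned above is done roughly like this; `solve_step` is a placeholder for the incremental linear solve, and `jax.disable_jit()` makes jit a no-op so intermediates are concrete and a Python breakpoint() can inspect them:

```python
import jax
import jax.numpy as jnp

@jax.jit
def solve_step(A, b):
    # Placeholder for the incremental linear solve inside a Newton step.
    return jnp.linalg.solve(A, b)

A = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([1.0, 1.0])

# Re-run the suspect iteration eagerly; inside this context the jitted
# function runs op-by-op, so breakpoint()/print show concrete values.
with jax.disable_jit():
    x = solve_step(A, b)
```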

JAX version

note: I'm using lineax 0.0.8 and equinox 0.13.1

jax:    0.6.2
jaxlib: 0.6.2
numpy:  2.2.2
python: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]
device info: NVIDIA A100 80GB PCIe-2, 2 local devices
process_count: 1
platform: uname_result(system='Linux', node='tralfamadore', release='6.8.0-85-generic', version='#85~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 16:18:59 UTC 2', machine='x86_64')

$ nvidia-smi
Wed Oct 15 14:14:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          Off |   00000000:17:00.0 Off |                    0 |
| N/A   38C    P0             72W /  300W |     445MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00000000:65:00.0 Off |                    0 |
| N/A   37C    P0             72W /  300W |     441MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2176      G   /usr/lib/xorg/Xorg                        4MiB |
|    0   N/A  N/A          656991      C   python                                  422MiB |
|    1   N/A  N/A            2176      G   /usr/lib/xorg/Xorg                        4MiB |
|    1   N/A  N/A          656991      C   python                                  418MiB |
+-----------------------------------------------------------------------------------------+
