
Mamba3 backward pass crashes with misaligned address when nheads % 4 != 0 and seqlen % 4 != 0 #886

@Dolfik1


Disclaimer: This issue was prepared with the help of AI and may be incorrect. If it is not relevant or accurate, please feel free to close it.

Summary

The Mamba3 backward pass crashes with CUDA error: misaligned address in mamba3_siso_bwd_kernel_dqkv when the number of heads is not divisible by 4 and the sequence length is not divisible by 4. The forward pass works in all tested configurations.

Environment

  • mamba-ssm: 2.3.1 (installed from main branch, commit 126bbf2)
  • torch: 2.10.0+cu130
  • triton: 3.6.0
  • GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (SM120)
  • CUDA: 13.0
  • OS: Linux (Ubuntu 24.04)

Minimal reproduction

import subprocess, sys, os

configs = [
    (128, 64, 50),   # nheads=4,  seqlen=50 (not %4)
    (192, 64, 50),   # nheads=6,  seqlen=50
    (256, 64, 50),   # nheads=8,  seqlen=50
    (320, 64, 50),   # nheads=10, seqlen=50
    (320, 64, 64),   # nheads=10, seqlen=64 (divisible by 4)
    (320, 64, 100),  # nheads=10, seqlen=100 (divisible by 4)
]

for d_model, headdim, seqlen in configs:
    nheads = (d_model * 2) // headdim
    label = f"d_model={d_model} nheads={nheads} seqlen={seqlen}"
    code = f"""
import torch
from mamba_ssm.modules.mamba3 import Mamba3
m = Mamba3(d_model={d_model}, d_state=64, expand=2, headdim={headdim}).cuda().bfloat16()
x = torch.randn(2, {seqlen}, {d_model}, device="cuda", dtype=torch.bfloat16)
y = m(x)
y.sum().backward()
print("OK")
"""
    # Each test in a separate process to avoid GPU error state contamination
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True, text=True, timeout=120,
        env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"},
    )
    status = "OK" if "OK" in result.stdout else "CRASH"
    print(f"  {status}  {label}")

Observed output (on the setup above):

  OK     d_model=128 nheads=4 seqlen=50
  CRASH  d_model=192 nheads=6 seqlen=50
  OK     d_model=256 nheads=8 seqlen=50
  CRASH  d_model=320 nheads=10 seqlen=50
  OK     d_model=320 nheads=10 seqlen=64
  OK     d_model=320 nheads=10 seqlen=100

Crash pattern

In every configuration we tested, the crash occurs if and only if both conditions hold:

  1. nheads % 4 != 0 (e.g., nheads = 6, 10)
  2. seqlen % 4 != 0 (e.g., seqlen = 10, 25, 30, 49, 50, 51, 63, 65, 95, ...)

If either nheads or seqlen is divisible by 4, backward works fine.
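The pattern above can be written as a one-line predicate and checked against the sweep results reported in this issue. This is only a restatement of the observed behavior, not something derived from the kernel source, and `predicted_crash` is a hypothetical helper name.

```python
def predicted_crash(nheads: int, seqlen: int) -> bool:
    """Observed pattern: backward crashes only when neither value is a multiple of 4."""
    return nheads % 4 != 0 and seqlen % 4 != 0


# (nheads, seqlen, observed_crash) rows copied from the nheads and seqlen sweeps in this issue.
OBSERVED = [
    (4, 50, False), (6, 50, True), (8, 50, False), (10, 50, True),
    (12, 50, False), (16, 50, False),
    (10, 4, False), (10, 8, False), (10, 10, True), (10, 16, False),
    (10, 20, False), (10, 25, True), (10, 30, True), (10, 32, False),
    (10, 48, False), (10, 50, True), (10, 60, False), (10, 63, True),
    (10, 64, False), (10, 65, True), (10, 96, False), (10, 100, False),
    (10, 128, False),
]

# Any row where the predicate disagrees with what was actually observed.
mismatches = [row for row in OBSERVED if predicted_crash(row[0], row[1]) != row[2]]
print(f"{len(OBSERVED)} rows checked, {len(mismatches)} mismatches")  # 23 rows checked, 0 mismatches
```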

nheads sweep (seqlen=50, headdim=64):

  d_model  nheads  nheads % 4  Result
  128      4       0           OK
  192      6       2           CRASH
  256      8       0           OK
  320      10      2           CRASH
  384      12      0           OK
  512      16      0           OK
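The nheads column follows from the same formula the reproduction script uses, nheads = (d_model * 2) // headdim, i.e. d_model * expand // headdim with expand=2. The sketch below recomputes the column and, for a given d_model, lists which candidate headdim values would keep nheads % 4 == 0. `headdims_avoiding_crash` is a hypothetical helper, and whether Mamba3 accepts every headdim in the candidate list is an assumption that has not been checked here.

```python
def nheads_for(d_model: int, headdim: int, expand: int = 2) -> int:
    """Number of heads, using the same formula as the reproduction script."""
    d_inner = d_model * expand
    if d_inner % headdim != 0:
        raise ValueError(f"d_inner={d_inner} is not divisible by headdim={headdim}")
    return d_inner // headdim


def headdims_avoiding_crash(d_model, expand=2, candidates=(16, 32, 64, 128)):
    """Hypothetical helper: candidate headdims that divide d_inner and give nheads % 4 == 0."""
    d_inner = d_model * expand
    return [h for h in candidates if d_inner % h == 0 and (d_inner // h) % 4 == 0]


# Recompute the nheads column of the sweep above (headdim=64, expand=2).
for d_model in (128, 192, 256, 320, 384, 512):
    n = nheads_for(d_model, headdim=64)
    print(f"d_model={d_model:<4} nheads={n:<3} nheads % 4 = {n % 4}")

print(headdims_avoiding_crash(320))  # [16, 32]
```

This only addresses the nheads side of the pattern; a different headdim may also change performance or hit other kernel constraints.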

seqlen sweep (d_model=320, nheads=10):

  seqlen  seqlen % 4  Result
  4       0           OK
  8       0           OK
  10      2           CRASH
  16      0           OK
  20      0           OK
  25      1           CRASH
  30      2           CRASH
  32      0           OK
  48      0           OK
  50      2           CRASH
  60      0           OK
  63      3           CRASH
  64      0           OK
  65      1           CRASH
  96      0           OK
  100     0           OK
  128     0           OK

Error traceback

File "mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py", line 726, in compute_dqkv
    mamba3_siso_bwd_kernel_dqkv[grid](
...
File "triton/runtime/autotuner.py", line 164, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
...
RuntimeError: Triton Error [CUDA]: misaligned address

The crash happens during Triton autotuner benchmarking of the mamba3_siso_bwd_kernel_dqkv kernel. The error originates in the kernel launch itself, not in the autotuner logic.

Practical impact

This prevents using Mamba3 with d_model values that produce an nheads that is not a multiple of 4 (e.g., d_model=320 with expand=2 and headdim=64 gives nheads=10). In practice, variable-length inputs (common in speech/audio) regularly produce seqlen % 4 != 0 after padding/subsampling, which triggers the crash.
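Since every reported crash disappears when seqlen % 4 == 0, one possible stopgap (not tested against this bug) is to right-pad the time axis to a multiple of 4 before the layer and slice the output back to the original length. This rests on the assumption that the wrapped layer is causal, so outputs at real positions never depend on the padded tail. `PadSeqLenToMultiple` and `CumSum` are hypothetical names; the demo uses a cumulative sum as a stand-in causal layer instead of Mamba3 so it runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PadSeqLenToMultiple(nn.Module):
    """Hypothetical wrapper: right-pad the time axis of a (batch, seqlen, d_model)
    input to a multiple of `multiple`, run the wrapped layer, then drop the padding.
    Only valid if the wrapped layer is causal along the time axis."""

    def __init__(self, layer: nn.Module, multiple: int = 4):
        super().__init__()
        self.layer = layer
        self.multiple = multiple

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seqlen = x.shape[1]
        pad = (-seqlen) % self.multiple
        if pad:
            # F.pad lists (left, right) pairs starting from the last dim:
            # (0, 0) leaves d_model alone, (0, pad) appends zeros at the end of time.
            x = F.pad(x, (0, 0, 0, pad))
        y = self.layer(x)
        return y[:, :seqlen]


class CumSum(nn.Module):
    """Stand-in causal layer for the demo: output at step t depends only on steps <= t."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cumsum(x, dim=1)


x = torch.randn(2, 50, 8, requires_grad=True)  # seqlen=50, as in the failing configs
layer = CumSum()
y_ref = layer(x)
y_pad = PadSeqLenToMultiple(layer, multiple=4)(x)  # runs the layer on seqlen=52
print(y_pad.shape)                   # torch.Size([2, 50, 8])
print(torch.allclose(y_ref, y_pad))  # True
y_pad.sum().backward()               # gradients flow back to the unpadded input
```

For the case in this issue the wrapped layer would be the Mamba3 module. The extra padded steps cost some compute, and any component that is not strictly causal would let the padded positions leak into the real outputs.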

Notes

  • The forward pass works for all configurations.
  • Restricting autotuner num_stages from [1, 2, 3] to [1, 2] does not fix the issue.
  • Only tested on Blackwell SM120 (RTX PRO 6000). Not confirmed whether datacenter Blackwell (SM100) or older architectures are affected.
