Disclaimer: This issue was prepared with the help of AI and may be incorrect. If it is not relevant or accurate, please feel free to close it.
Summary
The Mamba3 backward pass crashes with CUDA error: misaligned address in mamba3_siso_bwd_kernel_dqkv when both the number of heads and the sequence length are not divisible by 4. The forward pass works fine in all tested configurations.
Environment
- mamba-ssm: 2.3.1 (installed from main branch, commit 126bbf2)
- torch: 2.10.0+cu130
- triton: 3.6.0
- GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (SM120)
- CUDA: 13.0
- OS: Linux (Ubuntu 24.04)
Minimal reproduction
```python
import subprocess, sys, os

configs = [
    (128, 64, 50),   # nheads=4, seqlen=50 (not %4)
    (192, 64, 50),   # nheads=6, seqlen=50
    (256, 64, 50),   # nheads=8, seqlen=50
    (320, 64, 50),   # nheads=10, seqlen=50
    (320, 64, 64),   # nheads=10, seqlen=64 (divisible by 4)
    (320, 64, 100),  # nheads=10, seqlen=100 (divisible by 4)
]

for d_model, headdim, seqlen in configs:
    nheads = (d_model * 2) // headdim
    label = f"d_model={d_model} nheads={nheads} seqlen={seqlen}"
    code = f"""
import torch
from mamba_ssm.modules.mamba3 import Mamba3
m = Mamba3(d_model={d_model}, d_state=64, expand=2, headdim={headdim}).cuda().bfloat16()
x = torch.randn(2, {seqlen}, {d_model}, device="cuda", dtype=torch.bfloat16)
y = m(x)
y.sum().backward()
print("OK")
"""
    # Each test in a separate process to avoid GPU error state contamination
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True, text=True, timeout=120,
        env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"},
    )
    status = "OK" if "OK" in result.stdout else "CRASH"
    print(f" {status} {label}")
```
Output observed on our setup:

```text
OK d_model=128 nheads=4 seqlen=50
CRASH d_model=192 nheads=6 seqlen=50
OK d_model=256 nheads=8 seqlen=50
CRASH d_model=320 nheads=10 seqlen=50
OK d_model=320 nheads=10 seqlen=64
OK d_model=320 nheads=10 seqlen=100
```
Crash pattern
In every configuration we tested, the crash occurs if and only if both conditions are true:
- nheads % 4 != 0 (e.g., nheads = 6, 10)
- seqlen % 4 != 0 (e.g., seqlen = 10, 25, 30, 49, 50, 51, 63, 65, 95, ...)

If either nheads or seqlen is divisible by 4, the backward pass works fine.
nheads sweep (seqlen=50, headdim=64):

| d_model | nheads | nheads % 4 | Result |
|---|---|---|---|
| 128 | 4 | 0 | OK |
| 192 | 6 | 2 | CRASH |
| 256 | 8 | 0 | OK |
| 320 | 10 | 2 | CRASH |
| 384 | 12 | 0 | OK |
| 512 | 16 | 0 | OK |
seqlen sweep (d_model=320, nheads=10):

| seqlen | seqlen % 4 | Result |
|---|---|---|
| 4 | 0 | OK |
| 8 | 0 | OK |
| 10 | 2 | CRASH |
| 16 | 0 | OK |
| 20 | 0 | OK |
| 25 | 1 | CRASH |
| 30 | 2 | CRASH |
| 32 | 0 | OK |
| 48 | 0 | OK |
| 50 | 2 | CRASH |
| 60 | 0 | OK |
| 63 | 3 | CRASH |
| 64 | 0 | OK |
| 65 | 1 | CRASH |
| 96 | 0 | OK |
| 100 | 0 | OK |
| 128 | 0 | OK |
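As a consistency check, the crash condition stated above can be encoded as a predicate and compared against every row of both sweeps. This is a sketch of the empirical rule observed on SM120 only. `crash_expected` and `mismatches` are hypothetical helpers, not part of mamba-ssm, and they say nothing about the root cause inside the kernel.

```python
# Hypothetical helper (not a mamba-ssm API) encoding the empirical rule observed
# on SM120: backward crashes only when BOTH nheads and seqlen are not multiples of 4.
def crash_expected(nheads: int, seqlen: int) -> bool:
    return nheads % 4 != 0 and seqlen % 4 != 0

# nheads sweep (seqlen=50, headdim=64): (nheads, observed_crash)
NHEADS_SWEEP = [(4, False), (6, True), (8, False), (10, True), (12, False), (16, False)]

# seqlen sweep (d_model=320, nheads=10): (seqlen, observed_crash)
SEQLEN_SWEEP = [
    (4, False), (8, False), (10, True), (16, False), (20, False), (25, True),
    (30, True), (32, False), (48, False), (50, True), (60, False), (63, True),
    (64, False), (65, True), (96, False), (100, False), (128, False),
]

def mismatches():
    """Return every sweep row where the rule disagrees with the observed result."""
    bad = []
    for nheads, crashed in NHEADS_SWEEP:
        if crash_expected(nheads, 50) != crashed:
            bad.append(("nheads", nheads))
    for seqlen, crashed in SEQLEN_SWEEP:
        if crash_expected(10, seqlen) != crashed:
            bad.append(("seqlen", seqlen))
    return bad

if __name__ == "__main__":
    print(mismatches())  # [] -> the rule matches every tabulated observation
```

An empty list means the rule agrees with all 23 tabulated runs. The seqlen sweep only covers nheads=10, so other nheads values were not swept independently.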
Error traceback
```text
File "mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py", line 726, in compute_dqkv
  mamba3_siso_bwd_kernel_dqkv[grid](
...
File "triton/runtime/autotuner.py", line 164, in _bench
  return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
...
RuntimeError: Triton Error [CUDA]: misaligned address
```
The crash happens during Triton autotuner benchmarking of the mamba3_siso_bwd_kernel_dqkv kernel. The error originates in the kernel launch itself, not in the autotuner logic.
Practical impact
This prevents training Mamba3 with d_model values that produce an nheads that is not a multiple of 4 (e.g., d_model=320 with headdim=64 gives nheads=10). In practice, variable-length inputs (common in speech/audio) routinely produce seqlen % 4 != 0 after padding or subsampling, which triggers the crash.
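To gauge how often variable-length batches would hit this, the sketch below derives nheads the same way as the reproduction script (d_model * expand // headdim, with expand=2). It then lists which sequence lengths in a range fall into the crashing region. The helper names are hypothetical, and the rule they apply is the empirical one from the sweeps, not a statement about the kernel's actual alignment requirements.

```python
# Hypothetical helpers; the crash rule is the empirical one observed on SM120.
def num_heads(d_model: int, headdim: int, expand: int = 2) -> int:
    # Same derivation as the reproduction script: nheads = d_model * expand // headdim
    return (d_model * expand) // headdim

def crashing_lengths(d_model: int, headdim: int, lengths, expand: int = 2):
    """Return the sequence lengths expected to crash the backward pass."""
    nheads = num_heads(d_model, headdim, expand)
    if nheads % 4 == 0:
        return []  # rule predicts no crash for any seqlen
    return [n for n in lengths if n % 4 != 0]

if __name__ == "__main__":
    lengths = range(1, 101)
    bad = crashing_lengths(320, 64, lengths)  # nheads = 10
    print(f"{len(bad)} of {len(lengths)} lengths would crash")  # 75 of 100
    print(crashing_lengths(256, 64, lengths))  # nheads = 8 -> []
```

Under this rule, roughly three out of every four sequence lengths are affected whenever nheads is not a multiple of 4, so a batch of variable-length utterances is very likely to contain at least one crashing length.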
Notes
- The forward pass works for all configurations.
- Restricting autotuner num_stages from [1, 2, 3] to [1, 2] does not fix the issue.
- Only tested on Blackwell SM120 (RTX PRO 6000). Not confirmed whether datacenter Blackwell (SM100) or older architectures are affected.