
Fix Mamba-3 Triton kernels for AMD ROCm/RDNA4 #914

Open
ChrisLundquist wants to merge 5 commits into state-spaces:main from ChrisLundquist:rdna4-compat

Conversation


@ChrisLundquist ChrisLundquist commented Apr 12, 2026

Summary

Three fixes to enable Mamba-3 Triton kernels on AMD ROCm / HIP GPUs:

  1. Backend-conditional PTX/builtin dispatch in `utils.py`: detect AMD HIP at import time, keep PTX SFU inline asm on NVIDIA (zero perf regression), use portable Triton builtins (`tl.cos`, `tl.sin`, `tl.sigmoid`) on AMD
  2. Add `num_warps=1` autotune configs to all Mamba-3 SISO kernels (fwd, bwd, angle_dt, step) to reduce register pressure on GPUs with smaller VGPR files
  3. Strip `maxnreg` from autotune configs on HIP — Triton's HIP backend does not recognize the `maxnreg` keyword (added in Add maxnreg autotuning to Mamba-3 Triton kernels #905), raising `KeyError`. A `_maxnreg()` helper returns `{}` on HIP and `{maxnreg: value}` on CUDA.

Root cause

The Mamba-3 SISO kernels use `tl.inline_asm_elementwise` with PTX instructions (`cos.approx.f32`, `sin.approx.f32`, `tanh.approx.f32`) and `"=f,f"` register constraints. The `f` constraint specifies an NVIDIA floating-point register class that the AMDGCN backend does not recognize, causing:

error: couldn't allocate output register for constraint 'f'

on all AMD GPUs. The Mamba-2 SSD kernels and Mamba-3 MIMO rotary step kernel do not use PTX inline assembly and already work on AMD.

Approach

Rather than unconditionally replacing PTX asm with Triton builtins (which would regress NVIDIA — `tl.cos` compiles through libdevice `__nv_cosf` rather than the single-cycle SFU `cos.approx.f32`), we detect the backend at module load time:

```python
if _is_hip():
    @triton.jit
    def cos_approx(x):
        return tl.cos(x)  # portable builtin
else:
    @triton.jit
    def cos_approx(x):
        return tl.inline_asm_elementwise(
            "cos.approx.f32 $0, $1;",  # PTX SFU, single-cycle
            constraints="=f,f", ...)
```

NVIDIA: zero changes — same PTX SFU instructions as before.
AMD: portable Triton builtins that compile to AMDGCN instructions.

The AMD path for `tanh_approx` uses `2 * tl.sigmoid(2 * x) - 1` (mathematically equivalent), the same pattern already used by the MIMO rotary step kernel (`mamba3_mimo_rotary_step.py:76`).
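As a quick sanity check of that identity (plain Python, not kernel code):

```python
import math

def tanh_via_sigmoid(x: float) -> float:
    # tanh(x) == 2*sigmoid(2x) - 1: rewrites tanh in terms of sigmoid,
    # which Triton exposes portably as tl.sigmoid on both backends.
    sigmoid_2x = 1.0 / (1.0 + math.exp(-2.0 * x))
    return 2.0 * sigmoid_2x - 1.0

for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(tanh_via_sigmoid(x) - math.tanh(x)) < 1e-12
```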

The `num_warps=1, num_stages=1` autotune configs are only selected by the autotuner when they benchmark faster on the current GPU, so they cannot regress NVIDIA performance.
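A hedged sketch of how such configs might be enumerated — plain dicts stand in for `triton.Config`, and the block sizes and names are illustrative, not the PR's actual sweep:

```python
def build_siso_configs(block_sizes=(64, 128)):
    # num_warps=1 variants join the existing sweep; the autotuner only
    # selects one if it benchmarks fastest on the current GPU.
    return [
        {"BLOCK": bs, "num_warps": w, "num_stages": 1}
        for bs in block_sizes
        for w in (1, 2, 4)  # 1 added for small-VGPR GPUs like RDNA4
    ]
```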

Hardware-verified test results

Tested on AMD Radeon RX 9070 XT (gfx1201, RDNA4, 16GB), ROCm 7.2.1, PyTorch 2.12.0.dev+rocm7.2, Triton 3.6.0:

| Test | Result |
| --- | --- |
| `cos_approx` | PASS — max error 0.00e+00 vs PyTorch |
| `sin_approx` | PASS — max error 0.00e+00 vs PyTorch |
| `tanh_approx` | PASS — max error 1.79e-07 vs PyTorch |
| `sech2_approx` | PASS — max error 3.59e-07 vs PyTorch |
| `mamba3_siso_fwd` kernel | PASS — compiles, autotunes, no NaN/Inf |
| `angle_dt_fwd` kernel | PASS |
| `mamba3_siso_combined` pipeline | PASS — end-to-end forward |

Forward throughput (batch=2, nheads=8, hdim_qk=128, hdim_v=64):

| seqlen | tok/s |
| --- | --- |
| 128 | 1,296k |
| 256 | 2,654k |
| 512 | 3,504k |
| 1024 | 3,913k |

Files changed

| File | Change |
| --- | --- |
| `utils.py` | Backend detection, conditional trig/activation defs, `_maxnreg()` helper |
| `mamba3_siso_fwd.py` | `num_warps=1` configs, `**_maxnreg(r)` |
| `mamba3_siso_bwd.py` | `num_warps=1` configs (4 kernels), `**_maxnreg(r)` |
| `angle_dt.py` | `num_warps=1` in warp sweep |
| `mamba3_siso_step.py` | `num_warps=1` in warp sweep |
| `test_mamba3_siso.py` | `test_mamba3_portable_math_utils` |

Test plan

  • `test_mamba3_portable_math_utils` — cos/sin/tanh/sech2 compile and produce correct results
  • `mamba3_siso_fwd` kernel compiles and runs on RDNA4
  • `angle_dt_fwd` kernel compiles and runs on RDNA4
  • `mamba3_siso_combined` end-to-end forward pipeline on RDNA4
  • Existing `test_mamba3_siso_combined_batched` on NVIDIA (verify no regression)
  • Existing `test_mamba3_siso_combined_varlen` on NVIDIA (verify no regression)

Relates to #65 (ROCm support), #821 (ROCm performance)

🤖 Generated with Claude Code

ChrisLundquist and others added 5 commits April 12, 2026 11:18

Replace NVIDIA PTX inline assembly in Mamba-3 SISO math utils with
portable Triton builtins (tl.cos, tl.sin, tl.sigmoid) that compile on
both NVIDIA and AMD backends.

The PTX "=f,f" register constraints in cos_approx, sin_approx,
tanh_approx, and sech2_approx caused LLVM/AMDGCN compilation failures:
"error: couldn't allocate output register for constraint 'f'"
because AMDGCN does not recognize the 'f' register class.

Also add num_warps=1 autotune configs to all Mamba-3 kernels (fwd, bwd,
angle_dt, step) to reduce register pressure on GPUs with smaller VGPR
files like AMD RDNA4. The autotuner selects these only when they
benchmark faster, so NVIDIA performance is unaffected.

The MIMO rotary step kernel already uses tl.cos/tl.sin and the
sigmoid-based tanh pattern — this makes SISO consistent with MIMO.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Use module-level backend detection (_is_hip) to conditionally define
the trig/activation helpers: PTX inline asm on NVIDIA (preserving
single-cycle SFU cos.approx/sin.approx/tanh.approx instructions),
portable Triton builtins on AMD HIP/ROCm.

tl.cos/tl.sin compile through libdevice on NVIDIA (__nv_cosf) rather
than the SFU approximate path, so unconditionally replacing PTX asm
would regress NVIDIA throughput. This conditional approach has zero
performance impact on NVIDIA while enabling AMD support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Triton's HIP backend does not recognize the maxnreg keyword argument
(added for NVIDIA in PR state-spaces#905), raising:
  KeyError: 'Keyword argument maxnreg was specified but unrecognised'

Add _maxnreg() helper that returns {} on HIP and {maxnreg: value} on
CUDA, and use **_maxnreg(r) in all autotune Config constructors.

Verified on AMD RX 9070 XT (gfx1201):
  - All math utils: PASS (cos/sin err=0, tanh err=1.8e-7)
  - mamba3_siso_fwd kernel: PASS
  - angle_dt_fwd kernel: PASS
  - mamba3_siso_combined pipeline: PASS
  - Forward throughput: 3.9M tok/s at seqlen=1024

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- Check torch.version.hip (set at build time) before querying the
  Triton runtime, so backend detection works even if the Triton driver
  is not yet initialized at import time.
- Cache the result in _IS_HIP to avoid repeated runtime queries.
- Add MAXNREG_VALUES / MAXNREG_VALUES_SMALL constants that collapse to
  [None] on HIP (where all maxnreg values produce identical configs),
  eliminating ~3x redundant autotune benchmarks on first run.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

The bare import at line 20 crashes on ROCm where the CUDA extension
is not built. Mamba-2/3 use Triton kernels and do not need this
extension, so a graceful fallback to None is appropriate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
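The guarded import described in this commit might look like the following sketch. The extension module name `selective_scan_cuda` is an assumption for illustration; the commit text does not name the module.

```python
# Hedged sketch; "selective_scan_cuda" is an assumed extension name, not
# confirmed by the commit text.
try:
    import selective_scan_cuda  # CUDA-only extension, absent on ROCm builds
except ImportError:
    selective_scan_cuda = None  # Mamba-2/3 Triton kernel paths don't need it
```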