Fix Mamba-3 Triton kernels for AMD ROCm/RDNA4 #914
Open
ChrisLundquist wants to merge 5 commits into state-spaces:main from
Conversation
Replace NVIDIA PTX inline assembly in Mamba-3 SISO math utils with portable Triton builtins (tl.cos, tl.sin, tl.sigmoid) that compile on both NVIDIA and AMD backends. The PTX "=f,f" register constraints in cos_approx, sin_approx, tanh_approx, and sech2_approx caused LLVM/AMDGCN compilation failures ("error: couldn't allocate output register for constraint 'f'") because AMDGCN does not recognize the 'f' register class.

Also add num_warps=1 autotune configs to all Mamba-3 kernels (fwd, bwd, angle_dt, step) to reduce register pressure on GPUs with smaller VGPR files like AMD RDNA4. The autotuner selects these only when they benchmark faster, so NVIDIA performance is unaffected.

The MIMO rotary step kernel already uses tl.cos/tl.sin and the sigmoid-based tanh pattern; this change makes SISO consistent with MIMO.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Use module-level backend detection (_is_hip) to conditionally define the trig/activation helpers: PTX inline asm on NVIDIA (preserving single-cycle SFU cos.approx/sin.approx/tanh.approx instructions), portable Triton builtins on AMD HIP/ROCm.

tl.cos/tl.sin compile through libdevice on NVIDIA (__nv_cosf) rather than the SFU approximate path, so unconditionally replacing PTX asm would regress NVIDIA throughput. This conditional approach has zero performance impact on NVIDIA while enabling AMD support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Triton's HIP backend does not recognize the maxnreg keyword argument (added for NVIDIA in PR state-spaces#905), raising: KeyError: 'Keyword argument maxnreg was specified but unrecognised'

Add _maxnreg() helper that returns {} on HIP and {maxnreg: value} on CUDA, and use **_maxnreg(r) in all autotune Config constructors.

Verified on AMD RX 9070 XT (gfx1201):
- All math utils: PASS (cos/sin err=0, tanh err=1.8e-7)
- mamba3_siso_fwd kernel: PASS
- angle_dt_fwd kernel: PASS
- mamba3_siso_combined pipeline: PASS
- Forward throughput: 3.9M tok/s at seqlen=1024

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Check torch.version.hip (set at build time) before querying the Triton runtime, so backend detection works even if the Triton driver is not yet initialized at import time.
- Cache the result in _IS_HIP to avoid repeated runtime queries.
- Add MAXNREG_VALUES / MAXNREG_VALUES_SMALL constants that collapse to [None] on HIP (where all maxnreg values produce identical configs), eliminating ~3x redundant autotune benchmarks on first run.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
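The detection-and-dedup logic in this commit can be sketched as follows. This is an illustrative sketch only: the block sizes and maxnreg values are made up, not the PR's actual search space, and the `try/except` around the torch import is added so the sketch runs even where torch is absent.

```python
from itertools import product

try:
    import torch
    # torch.version.hip is set at build time on ROCm wheels, so this
    # check works even before the Triton runtime driver initializes.
    _IS_HIP = torch.version.hip is not None
except ImportError:  # sketch-only fallback; the real module requires torch
    _IS_HIP = False

# On HIP every maxnreg value yields an identical config (the keyword is
# dropped), so collapse the candidate list to [None] and skip the
# redundant autotune benchmarks on first run.
MAXNREG_VALUES = [None] if _IS_HIP else [None, 128, 256]

def build_configs(block_sizes, maxnreg_values):
    """Cross block sizes with maxnreg candidates into config dicts."""
    configs = []
    for block, reg in product(block_sizes, maxnreg_values):
        cfg = {"BLOCK": block}
        if reg is not None:
            cfg["maxnreg"] = reg
        configs.append(cfg)
    return configs
```

With three maxnreg candidates on CUDA versus one on HIP, the HIP config list is a third the size, which is where the ~3x reduction in first-run benchmarks comes from.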
The bare import at line 20 crashes on ROCm where the CUDA extension is not built. Mamba-2/3 use Triton kernels and do not need this extension, so a graceful fallback to None is appropriate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
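The graceful fallback described in this commit follows the standard optional-extension pattern. The module name below is an assumption for illustration; the actual import is the one at the file's line 20.

```python
# Fall back to None when the CUDA extension is not built (e.g. on ROCm).
# Mamba-2/3 take the Triton code paths and never call into it.
# NOTE: module name assumed here for illustration.
try:
    import selective_scan_cuda
except ImportError:
    selective_scan_cuda = None
```

Code paths that do require the extension can then check `if selective_scan_cuda is None` and raise a descriptive error, instead of the whole module failing at import time.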
Summary
Three fixes to enable Mamba-3 Triton kernels on AMD ROCm / HIP GPUs:
Root cause
The Mamba-3 SISO kernels use `tl.inline_asm_elementwise` with PTX instructions (`cos.approx.f32`, `sin.approx.f32`, `tanh.approx.f32`) and `"=f,f"` register constraints. The `f` constraint specifies an NVIDIA floating-point register class that the AMDGCN backend does not recognize, causing `error: couldn't allocate output register for constraint 'f'`
on all AMD GPUs. The Mamba-2 SSD kernels and Mamba-3 MIMO rotary step kernel do not use PTX inline assembly and already work on AMD.
Approach
Rather than unconditionally replacing PTX asm with Triton builtins (which would regress NVIDIA — `tl.cos` compiles through libdevice `__nv_cosf` rather than the single-cycle SFU `cos.approx.f32`), we detect the backend at module load time:
```python
if _is_hip():
    @triton.jit
    def cos_approx(x):
        return tl.cos(x)  # portable builtin
else:
    @triton.jit
    def cos_approx(x):
        return tl.inline_asm_elementwise(
            "cos.approx.f32 $0, $1;",  # PTX SFU, single-cycle
            constraints="=f,f", ...)
```
NVIDIA: zero changes — same PTX SFU instructions as before.
AMD: portable Triton builtins that compile to AMDGCN instructions.
The AMD path for `tanh_approx` uses `2 * tl.sigmoid(2 * x) - 1` (mathematically equivalent), the same pattern already used by the MIMO rotary step kernel (`mamba3_mimo_rotary_step.py:76`).
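The identity behind that path, `tanh(x) = 2*sigmoid(2x) - 1`, is easy to verify numerically. A plain-Python check (the kernel itself uses `tl.sigmoid` on GPU tensors, not `math.exp`); the `sech2` derivative identity used by `sech2_approx` is included as well:

```python
import math

def tanh_via_sigmoid(x):
    # sigmoid(2x) = 1 / (1 + exp(-2x)); then tanh(x) = 2*sigmoid(2x) - 1
    return 2.0 / (1.0 + math.exp(-2.0 * x)) - 1.0

def sech2_via_tanh(x):
    # sech^2(x) = 1 - tanh^2(x), the derivative of tanh
    t = tanh_via_sigmoid(x)
    return 1.0 - t * t
```

In exact arithmetic the two forms are identical; in float32 the sigmoid form matches library tanh to within a few ULPs, consistent with the err=1.8e-7 reported in the test results below.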
The `num_warps=1, num_stages=1` autotune configs are only selected by the autotuner when they benchmark faster on the current GPU, so they cannot regress NVIDIA performance.
Hardware-verified test results
Tested on AMD Radeon RX 9070 XT (gfx1201, RDNA4, 16GB), ROCm 7.2.1, PyTorch 2.12.0.dev+rocm7.2, Triton 3.6.0:
Forward throughput (batch=2, nheads=8, hdim_qk=128, hdim_v=64): 3.9M tok/s at seqlen=1024.
Files changed
Test plan
Relates to #65 (ROCm support), #821 (ROCm performance)
🤖 Generated with Claude Code