add BF16 output path to fused Gated RMSNorm HIP kernel#3051
Conversation
Add a `gated_rmsnorm_bf16` variant that reuses the existing `gated_rmsnorm_fp8_group_quant` kernel with a compile-time `FUSE_QUANT=false` flag to skip quantization and output BF16/FP16 directly. Motivation: ATOM's Qwen3.5 GDN layers call RMSNormGated on [num_tokens, num_heads, head_dim=128] tensors. The native PyTorch path launches multiple small kernels. This fused HIP kernel eliminates intermediate allocations and reduces launch overhead. Benchmark (MI308X, 24 heads, head_dim=128, bf16): tokens=128: HIP 4.0 us vs Triton 4.0 us (1.0x) tokens=1024: HIP 9.9 us vs Triton 17.0 us (1.7x) tokens=4096: HIP 29.3 us vs Triton 52.3 us (1.8x) tokens=16384: HIP 122 us vs Triton 196 us (1.6x) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
There was a problem hiding this comment.
Pull request overview
This PR adds a BF16/FP16 output variant for the fused HIP Gated RMSNorm kernel by reusing the existing FP8 group-quant kernel with a compile-time FUSE_QUANT=false path, exposing it through the C++/pybind layer and a Python wrapper, plus a new validation/benchmark script.
Changes:
- Extend the existing fused Gated RMSNorm + FP8 group quant kernel to optionally skip quantization and write BF16/FP16 outputs directly.
- Add C++ API + pybind export and a Python
compile_opsbinding forgated_rmsnorm_bf16. - Add an op_test script for correctness + performance comparison (HIP vs reference vs Triton).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
csrc/kernels/gated_rmsnorm_quant_kernels.cu |
Adds FUSE_QUANT compile-time switch and implements a BF16/FP16 output launcher using BLOCK_SIZE=64. |
csrc/include/gated_rmsnorm_quant.h |
Declares the new gated_rmsnorm_bf16 C++ entry point. |
csrc/include/rocm_ops.hpp |
Exposes gated_rmsnorm_bf16 via the existing pybind macro for the gated RMSNorm quant module. |
aiter/ops/gated_rmsnorm_fp8_group_quant.py |
Adds a Python JIT binding for gated_rmsnorm_bf16 in the same compiled module. |
op_tests/test_gated_rmsnorm_bf16.py |
Adds a correctness/perf validation script for the BF16/FP16 output path, with Triton comparison. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| const int head_dim = x.size(2); | ||
|
|
||
| TORCH_CHECK(z.size(0) == num_tokens && z.size(1) == num_heads && z.size(2) == head_dim, | ||
| "Gating tensor z must have same shape as x"); |
| TORCH_CHECK(x.is_cuda(), "Input x must be on CUDA device"); | ||
| TORCH_CHECK(z.is_cuda(), "Input z must be on CUDA device"); | ||
| TORCH_CHECK(weight.is_cuda(), "Weight must be on CUDA device"); | ||
| TORCH_CHECK(out.is_cuda(), "Output must be on CUDA device"); | ||
|
|
||
| if (x.scalar_type() == at::ScalarType::BFloat16) { | ||
| TORCH_CHECK(out.scalar_type() == at::ScalarType::BFloat16, | ||
| "Output must be BFloat16 when input is BFloat16"); | ||
| gated_rmsnorm_bf16_launcher<opus::bf16_t>(out, x, z, weight, epsilon); |
|
@ganyi1996ppo This PR implements the approach you suggested in ROCm/ATOM#697 — adding a BF16 output path to the existing The results look promising: 1.6x–1.9x faster than the Triton kernel for 128+ tokens on MI308X with Qwen3.5 GDN shapes (24 heads, head_dim=128). Would appreciate your review when you get a chance. — Claude (AI assistant) |
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
gated_rmsnorm_bf16variant that reuses the existinggated_rmsnorm_fp8_group_quantkernel with a compile-timeFUSE_QUANT=falsetemplate parameter to skip quantization and output BF16/FP16 directlyBLOCK_SIZE=64(single warp) for optimal GPU occupancy across all workload sizesif constexprensures zero overheadMotivation
ATOM's Qwen3.5 GDN layers call
RMSNormGatedon[num_tokens, num_heads, head_dim=128]tensors. The native PyTorch path launches multiple small GPU kernels (variance reduction, rsqrt, silu, element-wise multiply). This fused HIP kernel eliminates intermediate tensor allocations and reduces kernel launch overhead.Ref: ROCm/ATOM#697
Benchmark
MI308X, Qwen3.5 GDN layer shapes (24 heads, head_dim=128, bf16):
HIP kernel is 1.6x–1.9x faster than Triton for 128+ tokens (the dominant prefill workload). For tiny decode batches (1–64 tokens), Triton is slightly faster but absolute times are <1 us difference.
Files Changed
csrc/kernels/gated_rmsnorm_quant_kernels.cu— kernel template + bf16 launchercsrc/include/gated_rmsnorm_quant.h— declarationcsrc/include/rocm_ops.hpp— pybind macroaiter/ops/gated_rmsnorm_fp8_group_quant.py— Python interfaceop_tests/test_gated_rmsnorm_bf16.py— correctness + perf testTest plan
test_gated_rmsnorm_fp8_group_quant.py)🤖 Generated with Claude Code