From 6f171eb62e50c679d2e3b09ab76615a7fcd5b3c6 Mon Sep 17 00:00:00 2001 From: zovonoir Date: Wed, 6 May 2026 16:22:37 +0800 Subject: [PATCH 1/4] add layernorm triton kernel for qwen3.5 --- atom/model_ops/layernorm.py | 120 ++++++++++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 4e8aeeaf8..78128e3bb 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -5,6 +5,8 @@ import aiter import torch +import triton +import triton.language as tl from aiter import ( QuantType, layernorm2d_fwd, @@ -287,6 +289,59 @@ def forward( ) return x, residual +# decode +@triton.jit +def _rmsnorm_gated_contiguous_128_kernel( + x_ptr, + z_ptr, + weight_ptr, + out_ptr, + num_heads: tl.constexpr, + eps: tl.constexpr, +): + token_id = tl.program_id(0) + head_id = tl.program_id(1) + offsets = tl.arange(0, 128) + row_offset = (token_id * num_heads + head_id) * 128 + + x = tl.load(x_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32) + z = tl.load(z_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32) + weight = tl.load(weight_ptr + offsets, cache_modifier=".ca").to(tl.float32) + + variance = tl.sum(x * x, axis=0) * 0.0078125 + inv_rms = tl.rsqrt(variance + eps) + gate = z * tl.sigmoid(z) + out = x * inv_rms * weight * gate + + tl.store(out_ptr + row_offset + offsets, out) + +# prefill +@triton.jit +def _rmsnorm_gated_contiguous_128_tiled_rows_kernel( + x_ptr, + z_ptr, + weight_ptr, + out_ptr, + num_rows: tl.constexpr, + eps: tl.constexpr, + block_rows: tl.constexpr, +): + row_offsets = tl.program_id(0) * block_rows + tl.arange(0, block_rows) + dim_offsets = tl.arange(0, 128) + mask_rows = row_offsets < num_rows + offsets = row_offsets[:, None] * 128 + dim_offsets[None, :] + + x = tl.load(x_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32) + z = tl.load(z_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32) + weight = tl.load(weight_ptr + dim_offsets, cache_modifier=".ca").to(tl.float32) + + variance = tl.sum(x * x, axis=1) * 0.0078125 + inv_rms = tl.rsqrt(variance + eps) + gate = z * tl.sigmoid(z) + out = x * inv_rms[:, None] * weight[None, :] * gate + + tl.store(out_ptr + offsets, out, mask=mask_rows[:, None]) + class RMSNormGated(nn.Module): """RMS Normalization with optional gating. @@ -360,6 +415,55 @@ def __init__( def reset_parameters(self): torch.nn.init.ones_(self.weight) + def forward_triton(self, x: torch.Tensor, z: torch.Tensor): + if ( + z is None + or x.ndim != 3 + or self.group_size is not None + or not self.norm_before_gate + or x.shape[-1] != 128 + or not x.is_contiguous() + or not z.is_contiguous() + ): + return self.forward_native(x, z) + + num_tokens, num_heads, head_dim = x.shape + out = torch.empty( + (num_tokens, num_heads * head_dim), + dtype=x.dtype, + device=x.device, + ) + + num_rows = num_tokens * num_heads + if num_rows >= 65536: + block_rows = 32 + _rmsnorm_gated_contiguous_128_tiled_rows_kernel[ + (triton.cdiv(num_rows, block_rows),) + ]( + x, + z, + self.weight, + out, + num_rows, + self.eps, + block_rows, + num_warps=4, + num_stages=1, + ) + else: + _rmsnorm_gated_contiguous_128_kernel[(num_tokens, num_heads)]( + x, + z, + self.weight, + out, + num_heads, + self.eps, + num_warps=1, + num_stages=1, + ) + + return (out, None) + def forward_native( self, x: torch.Tensor, z: torch.Tensor ) -> tuple[torch.Tensor, None]: @@ -479,7 +583,7 @@ def forward( if self.use_fused_fp8_quant: return self.forward_fused_fp8(x, z) - return self.forward_native(x, z) + return self.forward_triton(x, z) class GemmaRMSNorm(nn.Module): @@ -547,13 +651,11 @@ def forward_cuda( x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if torch.compiler.is_compiling(): - return self.forward_native(x, residual) + from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton - if not getattr(self, "_is_compiled", False): - self.forward_static = torch.compile(self.forward_static) # type: ignore - self._is_compiled = True - return self.forward_native(x, residual) + return gemma_rmsnorm_triton( + x, self.weight.data, self.variance_epsilon, residual + ) def _forward_fused_fp8(self, x, residual=None): from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant @@ -605,10 +707,6 @@ def forward( # --------------------------------------------------------------------------- # Fused Q/K RMSNorm Triton kernel # --------------------------------------------------------------------------- -import triton # noqa: E402 -import triton.language as tl # noqa: E402 - - @triton.jit def _fused_qk_norm_single_kernel( q_ptr, From 11c129328ceb7eee58af87fa5300a49ab09bae59 Mon Sep 17 00:00:00 2001 From: zovonoir Date: Wed, 6 May 2026 16:29:46 +0800 Subject: [PATCH 2/4] =?UTF-8?q?style:=20fix=20Black=20formatting=20?= =?UTF-8?q?=E2=80=94=20add=20missing=20blank=20lines=20between=20top-level?= =?UTF-8?q?=20definitions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- atom/model_ops/layernorm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 78128e3bb..089ecf256 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -289,6 +289,7 @@ def forward( ) return x, residual + # decode @triton.jit def _rmsnorm_gated_contiguous_128_kernel( @@ -315,6 +316,7 @@ def _rmsnorm_gated_contiguous_128_kernel( tl.store(out_ptr + row_offset + offsets, out) + # prefill @triton.jit def _rmsnorm_gated_contiguous_128_tiled_rows_kernel( From 637c57fefacc031203bfb5925e152f9bcd99b051 Mon Sep 17 00:00:00 2001 From: zovonoir Date: Wed, 6 May 2026 17:02:39 +0800 Subject: [PATCH 3/4] test: temporarily disable Triton RMSNormGated path to isolate CI failures Keep the two new Triton kernels but route forward() back to forward_native() so we can verify whether CI failures are pre-existing. Co-Authored-By: Claude Opus 4.6 (1M context) --- atom/model_ops/layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 089ecf256..952b838cf 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -585,7 +585,7 @@ def forward( if self.use_fused_fp8_quant: return self.forward_fused_fp8(x, z) - return self.forward_triton(x, z) + return self.forward_native(x, z) class GemmaRMSNorm(nn.Module): From 82a3415fc7f94485006ef7df77352c494b2ded89 Mon Sep 17 00:00:00 2001 From: zovonoir Date: Wed, 6 May 2026 17:16:29 +0800 Subject: [PATCH 4/4] =?UTF-8?q?revert:=20re-enable=20Triton=20RMSNormGated?= =?UTF-8?q?=20path=20=E2=80=94=20CI=20failures=20confirmed=20pre-existing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- atom/model_ops/layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 952b838cf..089ecf256 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -585,7 +585,7 @@ def forward( if self.use_fused_fp8_quant: return self.forward_fused_fp8(x, z) - return self.forward_native(x, z) + return self.forward_triton(x, z) class GemmaRMSNorm(nn.Module):