diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py
index 4e8aeeaf8..089ecf256 100644
--- a/atom/model_ops/layernorm.py
+++ b/atom/model_ops/layernorm.py
@@ -5,6 +5,8 @@
 import aiter
 import torch
+import triton
+import triton.language as tl
 from aiter import (
     QuantType,
     layernorm2d_fwd,
@@ -288,6 +290,61 @@ def forward(
         return x, residual
 
 
+# decode
+@triton.jit
+def _rmsnorm_gated_contiguous_128_kernel(
+    x_ptr,
+    z_ptr,
+    weight_ptr,
+    out_ptr,
+    num_heads: tl.constexpr,
+    eps: tl.constexpr,
+):
+    token_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offsets = tl.arange(0, 128)
+    row_offset = (token_id * num_heads + head_id) * 128
+
+    x = tl.load(x_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
+    z = tl.load(z_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
+    weight = tl.load(weight_ptr + offsets, cache_modifier=".ca").to(tl.float32)
+
+    variance = tl.sum(x * x, axis=0) * 0.0078125
+    inv_rms = tl.rsqrt(variance + eps)
+    gate = z * tl.sigmoid(z)
+    out = x * inv_rms * weight * gate
+
+    tl.store(out_ptr + row_offset + offsets, out)
+
+
+# prefill
+@triton.jit
+def _rmsnorm_gated_contiguous_128_tiled_rows_kernel(
+    x_ptr,
+    z_ptr,
+    weight_ptr,
+    out_ptr,
+    num_rows: tl.constexpr,
+    eps: tl.constexpr,
+    block_rows: tl.constexpr,
+):
+    row_offsets = tl.program_id(0) * block_rows + tl.arange(0, block_rows)
+    dim_offsets = tl.arange(0, 128)
+    mask_rows = row_offsets < num_rows
+    offsets = row_offsets[:, None] * 128 + dim_offsets[None, :]
+
+    x = tl.load(x_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
+    z = tl.load(z_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
+    weight = tl.load(weight_ptr + dim_offsets, cache_modifier=".ca").to(tl.float32)
+
+    variance = tl.sum(x * x, axis=1) * 0.0078125
+    inv_rms = tl.rsqrt(variance + eps)
+    gate = z * tl.sigmoid(z)
+    out = x * inv_rms[:, None] * weight[None, :] * gate
+
+    tl.store(out_ptr + offsets, out, mask=mask_rows[:, None])
+
+
 class RMSNormGated(nn.Module):
     """RMS Normalization with optional gating.
@@ -360,6 +417,55 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
+    def forward_triton(self, x: torch.Tensor, z: torch.Tensor):
+        if (
+            z is None
+            or x.ndim != 3
+            or self.group_size is not None
+            or not self.norm_before_gate
+            or x.shape[-1] != 128
+            or not x.is_contiguous()
+            or not z.is_contiguous()
+        ):
+            return self.forward_native(x, z)
+
+        num_tokens, num_heads, head_dim = x.shape
+        out = torch.empty(
+            (num_tokens, num_heads * head_dim),
+            dtype=x.dtype,
+            device=x.device,
+        )
+
+        num_rows = num_tokens * num_heads
+        if num_rows >= 65536:
+            block_rows = 32
+            _rmsnorm_gated_contiguous_128_tiled_rows_kernel[
+                (triton.cdiv(num_rows, block_rows),)
+            ](
+                x,
+                z,
+                self.weight,
+                out,
+                num_rows,
+                self.eps,
+                block_rows,
+                num_warps=4,
+                num_stages=1,
+            )
+        else:
+            _rmsnorm_gated_contiguous_128_kernel[(num_tokens, num_heads)](
+                x,
+                z,
+                self.weight,
+                out,
+                num_heads,
+                self.eps,
+                num_warps=1,
+                num_stages=1,
+            )
+
+        return (out, None)
+
     def forward_native(
         self, x: torch.Tensor, z: torch.Tensor
     ) -> tuple[torch.Tensor, None]:
@@ -479,7 +585,7 @@ def forward(
         if self.use_fused_fp8_quant:
             return self.forward_fused_fp8(x, z)
 
-        return self.forward_native(x, z)
+        return self.forward_triton(x, z)
 
 
 class GemmaRMSNorm(nn.Module):
@@ -547,13 +653,11 @@ def forward_cuda(
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if torch.compiler.is_compiling():
-            return self.forward_native(x, residual)
+        from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
 
-        if not getattr(self, "_is_compiled", False):
-            self.forward_static = torch.compile(self.forward_static)  # type: ignore
-            self._is_compiled = True
-        return self.forward_native(x, residual)
+        return gemma_rmsnorm_triton(
+            x, self.weight.data, self.variance_epsilon, residual
+        )
 
     def _forward_fused_fp8(self, x, residual=None):
         from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant
@@ -605,10 +709,6 @@ def forward(
 # ---------------------------------------------------------------------------
 # Fused Q/K RMSNorm Triton kernel
 # ---------------------------------------------------------------------------
-import triton  # noqa: E402
-import triton.language as tl  # noqa: E402
-
-
 @triton.jit
 def _fused_qk_norm_single_kernel(
     q_ptr,
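Note on the new kernels (not part of the patch): both kernels RMS-normalize each 128-wide head row in fp32, scale by `weight`, and gate with SiLU(z); the constant 0.0078125 is 1/128, i.e. the mean over the head dimension, and `num_rows >= 65536` simply switches from the one-row-per-program decode kernel to the row-tiled prefill kernel. A minimal reference sketch for checking numerics, assuming a hypothetical `rmsnorm_gated_ref` helper that mirrors the norm-before-gate path of `forward_native`:

# Hypothetical reference check, not part of the diff.
import torch
import torch.nn.functional as F

def rmsnorm_gated_ref(x: torch.Tensor, z: torch.Tensor,
                      weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x, z: (num_tokens, num_heads, 128); weight: (128,)
    xf, zf = x.float(), z.float()
    # mean(x^2) over the head dim; 1/128 == 0.0078125 in the kernels
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    # normalize, scale, then gate with SiLU(z) (norm-before-gate)
    out = xf * inv_rms * weight.float() * F.silu(zf)
    # match the flattened (num_tokens, num_heads * 128) output buffer
    return out.to(x.dtype).reshape(x.shape[0], -1)

# Usage sketch: compare against the Triton path on contiguous inputs, e.g.
# torch.testing.assert_close(norm.forward_triton(x, z)[0],
#                            rmsnorm_gated_ref(x, z, norm.weight, norm.eps))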