Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tests/pytorch/triton_kernels/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,13 @@ def test_norm_triton(
zero_centered_gamma=zero_centered_gamma,

)
triton_bwd_outs = triton_bwd_func(*args["triton"])
# te_rmsnorm_bwd_triton accepts an `autotune` kwarg; te_layernorm_bwd_triton does not.
# Honor the same NVTE_TEST_TRITON_AUTOTUNE env toggle as the fwd path so
# default test runs avoid the autotune compile/sweep cost.
if norm == "rms":
triton_bwd_outs = triton_bwd_func(*args["triton"], autotune=autotune)
else:
triton_bwd_outs = triton_bwd_func(*args["triton"])

if norm == "layer":
dx_triton, dgamma_triton, dbeta_triton = triton_bwd_outs
Expand Down
132 changes: 118 additions & 14 deletions transformer_engine/pytorch/triton_kernels/norms_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

import os
import torch
import triton
import warnings
Expand All @@ -20,8 +21,22 @@
_rmsnorm_fwd_triton,
_rmsnorm_fwd_triton_impl,
_rmsnorm_bwd_triton,
_rmsnorm_bwd_triton_impl,
_rmsnorm_bwd_dg_reduce_triton,
_rmsnorm_bwd_dg_reduce_triton_impl,
_fp8_transpose_2d_triton,
_fp8_transpose_2d_impl,
)

# Use the external LDS-tiled byte transpose instead of the in-kernel strided
# stores. Default on -- the in-kernel path is uncoalesced and bottlenecks
# every fp8_t shape. Set NVTE_RMS_EXTERNAL_TRANSPOSE=0 to fall back.
_USE_EXTERNAL_TRANSPOSE = os.environ.get("NVTE_RMS_EXTERNAL_TRANSPOSE", "1") == "1"

_fp8_transpose_kernels = {
True: _fp8_transpose_2d_triton,
False: _fp8_transpose_2d_impl,
}
from .layernorm import (
_layernorm_fwd_triton,
_layernorm_fwd_triton_impl,
Expand All @@ -41,6 +56,16 @@
False: _layernorm_fwd_triton_impl,
}
}

_rmsnorm_bwd_kernels = {
True: _rmsnorm_bwd_triton,
False: _rmsnorm_bwd_triton_impl,
}

_rmsnorm_bwd_dg_reduce_kernels = {
True: _rmsnorm_bwd_dg_reduce_triton,
False: _rmsnorm_bwd_dg_reduce_triton_impl,
}
# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd
def te_rmsnorm_fwd_triton(
input: torch.Tensor,
Expand Down Expand Up @@ -152,6 +177,10 @@ def _te_norm_fwd_triton(
out_transpose_ptr = None
out_transpose_stride = None
FP8_MAX = None
# When True, skip in-kernel strided transpose stores and dispatch a
# separate LDS-tiled transpose kernel after the main fwd. Only applies
# to the rms path for now.
use_external_transpose = False
if IS_FP8:
MAKE_TRANSPOSE = quantizer.columnwise_usage
amax = (
Expand All @@ -170,8 +199,11 @@ def _te_norm_fwd_triton(
dtype=out._data.dtype, device=device
)
out._transpose_invalid = False
out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype)
out_transpose_stride = out._transpose.stride(0)
use_external_transpose = _USE_EXTERNAL_TRANSPOSE and kernel == 'rms'
if not use_external_transpose:
# In-kernel strided transpose path; main kernel does the writes.
out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype)
out_transpose_stride = out._transpose.stride(0)

grid_fwd = lambda meta: (NUM_PRGMS,)
kernel_func = _norm_kernels[kernel][autotune]
Expand All @@ -195,7 +227,9 @@ def _te_norm_fwd_triton(
BLOCK_SIZE=BLOCK_SIZE,
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX,
MAKE_TRANSPOSE=MAKE_TRANSPOSE,
# Gate the in-kernel strided transpose stores off when we'll do the
# transpose externally via the LDS-tiled kernel.
MAKE_TRANSPOSE=(MAKE_TRANSPOSE and not use_external_transpose),
)
if kernel == 'layer':
kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC
Expand All @@ -216,6 +250,29 @@ def _te_norm_fwd_triton(

kernel_func[grid_fwd](**kwargs)

if use_external_transpose:
# out._data: (N rows, H cols) row-major uint8; out._transpose: (H, N).
transpose_kernel = _fp8_transpose_kernels[autotune]
if autotune:
grid_t = lambda meta: (
triton.cdiv(N, meta['BLOCK_M']),
triton.cdiv(H, meta['BLOCK_N']),
)
transpose_kernel[grid_t](
out._data, out._transpose,
N, H,
out._data.stride(0), out._transpose.stride(0),
)
else:
BLOCK_M, BLOCK_N = 64, 64
grid_t = (triton.cdiv(N, BLOCK_M), triton.cdiv(H, BLOCK_N))
transpose_kernel[grid_t](
out._data, out._transpose,
N, H,
out._data.stride(0), out._transpose.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
)

# Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm.
if IS_FP8 and not APPLY_ATOMIC:
_layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)](
Expand All @@ -234,7 +291,7 @@ def _te_norm_fwd_triton(


# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd
def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, autotune: bool = True):
# may take non-contiguous inputs
dz_ = dz.contiguous()
x_ = x.contiguous()
Expand All @@ -248,25 +305,72 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
blk_size = block_size(x_)
USE_BLOCKED = use_blocked(x_)
NUM_PRGMS = num_programs(x_, sm_margin)
need_reduction = N > 1
dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin)
dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None
# dg accumulation strategy:
# * Large M (rows_per_program > 1): per-program partial buffer of shape
# (NUM_PRGMS, N) accumulated via HBM RMW. Buffer is small, L2-resident,
# RMW near-free; reduce kernel then sums NUM_PRGMS rows.
# * Small M (rows_per_program == 1, i.e. NUM_PRGMS == M): RMW would just
# be load+add+store of a slot only written once. Fall back to pure
# per-row writes into (M, N) and skip the zero-init.
# * Non-blocked path always writes via in-register accumulator (no RMW).
rows_per_program_gt_1 = NUM_PRGMS < M
DG_RMW = USE_BLOCKED and rows_per_program_gt_1
need_reduction = NUM_PRGMS > 1
if need_reduction:
if DG_RMW:
# RMW requires zero-init.
dg_tmp = torch.zeros(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False)
elif USE_BLOCKED:
# Pure per-row writes; rows are M.
dg_tmp = torch.empty(M, N, device=x.device, dtype=torch.float32, requires_grad=False)
else:
# Non-blocked: each program writes its slot unconditionally.
dg_tmp = torch.empty(NUM_PRGMS, N, device=x.device, dtype=torch.float32, requires_grad=False)
else:
dg_tmp = None

input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0)
grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0)
dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0)
dg_target = dg_tmp if need_reduction else dgamma
dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0)

grid_bwd = lambda meta: (NUM_PRGMS, )
_rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size,
USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16,
dx_aligned_16, dg_aligned_16, num_warps=8)
bwd_kernel = _rmsnorm_bwd_kernels[autotune]
bwd_kwargs = dict(
n_rows=M, n_cols=N,
ZERO_CENTERED_GAMMA=zero_centered_gamma,
BLOCK_SIZE=blk_size,
USE_BLOCKED=USE_BLOCKED, NUM_PRGMS=NUM_PRGMS,
INPUT_ALIGNED_16=input_aligned_16,
GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16,
DX_ALIGNED_16=dx_aligned_16,
DG_ALIGNED_16=dg_aligned_16,
DG_RMW=DG_RMW,
)
if not autotune:
bwd_kwargs["num_warps"] = 8
bwd_kernel[grid_bwd](
dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
x_.stride(0), dz_.stride(0),
**bwd_kwargs,
)

if need_reduction:
grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
_rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
reduce_kernel = _rmsnorm_bwd_dg_reduce_kernels[autotune]
if autotune:
grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
reduce_kernel[grid_reduce](
dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
)
else:
# Match the previously-hardcoded tile when autotune is disabled.
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64
grid_reduce = (triton.cdiv(N, BLOCK_SIZE_N),)
reduce_kernel[grid_reduce](
dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
)

return dx, dgamma

Expand Down
Loading
Loading