diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 6c301332..16308672 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -18,9 +18,7 @@ from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath -from flydsl.expr.arith import ArithValue -from flydsl.expr.numeric import Float32, Numeric, Uint32 -from flydsl.expr.typing import Int32, T +from flydsl.expr.typing import Vector as Vec from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -35,7 +33,80 @@ VEC_WIDTH = 8 +def _make_reduction_allocator(arch: str, red_slots: int): + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + red_slots * f32_bytes + red2_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red2_offset + red_slots * f32_bytes + return allocator, red_offset, red2_offset + + +def _load_scalar(copy_atom, elem_dtype, divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom, view, r) + return fx.memref_load_vec(r)[0] + + +def _store_scalar(copy_atom, elem_dtype, store_dtype, divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, store_dtype(val), store_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom, r, view) + + +def _load_vec(copy_atom, vec_width, elem_dtype, div_tensor, idx): + r = fx.make_rmem_tensor(vec_width, elem_dtype) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + +def _store_vec(copy_atom, vec_width, elem_dtype, val, div_tensor, idx): + r = fx.make_rmem_tensor(vec_width, elem_dtype) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + +def _to_elem_scalar(dtype_str: str, elem_dtype, y): + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + +def _to_elem_vec(dtype_str: str, elem_dtype, use_hw_cvt_bf16: bool, y): + if const_expr(dtype_str == "bf16"): + if const_expr(use_hw_cvt_bf16): + return y.to(elem_dtype) + u = y.bitcast(fx.Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(fx.Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + return packed.bitcast(elem_dtype) + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + +def _store_yscale(scale_copy_atom, yscale_div, index, val): + r = fx.make_rmem_tensor(1, fx.Float32) + ts = full(1, fx.Float32(val), fx.Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) + + def build_rmsnorm_module(M: int, N: int, dtype_str: str): + if M > 8192 and N <= 2048: + return _build_rmsnorm_large_m_small_n_module(M, N, dtype_str) + arch = get_hip_arch() USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") @@ -43,12 +114,7 @@ def build_rmsnorm_module(M: int, N: int, dtype_str: str): RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - allocator = SmemAllocator(None, arch=arch) - 
f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel def rmsnorm_kernel( @@ -135,16 +201,6 @@ def block_reduce_add2(val0, val1): copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - def _load_vec(div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_vec(val, div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f thread_dummy = c_zero_f @@ -153,7 +209,7 @@ def _store_vec(val, div_tensor, idx): # Pass 1: load + cache + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec = _load_vec(in_div, idx) + vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) in_local.append(vec) x = vec.to(fx.Float32) @@ -170,34 +226,14 @@ def _store_vec(val, div_tensor, idx): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(fx.Float32) + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) x = in_local[tile_i].to(fx.Float32) y = (x * rrms) * g - - out_e = y.to(elem_dtype) - if const_expr(dtype_str == "bf16"): - if const_expr(USE_HW_CVT_PK_BF16_F32): - out_e = y.to(elem_dtype) - else: - u = y.bitcast(fx.Uint32) - upper = u >> 16 - lsb = upper & 1 - bias = lsb + 0x7FFF - u_round = y.bitcast(fx.Uint32) + bias - bf16_bits = u_round >> 16 - even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) - odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << 16 - packed = even | odd_sh - out_e = packed.bitcast(elem_dtype) - elif const_expr(dtype_str == "f32"): - out_e = y - else: - out_e = y.to(elem_dtype) + out_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) out_idx = tid + tile_i * BLOCK_THREADS - _store_vec(out_e, out_div, out_idx) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, out_idx) else: # ============================================================== @@ -219,19 +255,6 @@ def _store_vec(val, div_tensor, idx): gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.make_rmem_tensor(1, elem_dtype) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, elem_dtype) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f @@ -239,7 +262,7 @@ def _store_scalar(divided_tensor, index, val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) x2 = x * x x2_safe = is_valid.select(x2, c_zero_f) @@ -253,19 +276,14 @@ def _store_scalar(divided_tensor, index, val): for base_idx_int in 
range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) norm = x * rrms y = norm * g - if const_expr(dtype_str == "f32"): - y_e = y - elif const_expr(dtype_str == "bf16"): - y_e = y.to(elem_dtype) - else: - y_e = y.to(elem_dtype) - _store_scalar(out_div, idx, y_e) + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div, idx, y_e) @flyc.jit def launch_rmsnorm( @@ -290,97 +308,143 @@ def launch_rmsnorm( return launch_rmsnorm -def _quant_dtype_to_elem_type(dtype_str: str): - if dtype_str in ("i8", "int8"): - return T.i8 - raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") +def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str): + BLOCK_N = 1 << (N - 1).bit_length() + BLOCK_M = max(min(16384 // BLOCK_N, 32), 8) + THREADS_PER_ROW = min(WARP_SIZE, 1024 // BLOCK_M) + BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW + elem_bits = 32 if dtype_str == "f32" else 16 + @flyc.kernel + def rmsnorm_large_m_small_n_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + _Unused: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x -def _quant_dtype_max(dtype_str: str) -> float: - if dtype_str in ("i8", "int8"): - return 127.0 - raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + lane = tid % THREADS_PER_ROW + row_local = tid // THREADS_PER_ROW + row = bid * fx.Int32(BLOCK_M) + row_local + if row < M: + elem_dtype = dtype_to_elem_type(dtype_str) + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) -def _build_rmsnorm_quant_module( - M: int, - N: int, - dtype_str: str, - *, - is_smooth: bool, - quant_dtype_str: str = "i8", -): + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + + row_in = fx.slice(Input_buf, (row, None)) + row_out = fx.slice(Output_buf, (row, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + + def group_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(THREADS_PER_ROW))): + off = THREADS_PER_ROW // (2 << _sh_exp) + peer = w.shuffle_xor(off, fx.Int32(THREADS_PER_ROW)) + w = w.addf(peer, fastmath=fm_fast) + return w + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_sq = group_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + if idx < N: + x_e = 
_load_scalar(copy_atom_s, elem_dtype, row_div, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div, idx, y_e) + + @flyc.jit + def launch_rmsnorm_large_m_small_n( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = rmsnorm_large_m_small_n_kernel(Input, Gamma, Gamma, Output) + launcher.launch( + grid=((M + BLOCK_M - 1) // BLOCK_M, 1, 1), + block=(BLOCK_THREADS_SPECIAL, 1, 1), + stream=stream, + ) + + return launch_rmsnorm_large_m_small_n + + +def build_fused_add_rmsnorm_module(M: int, N: int, dtype_str: str): arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") tile_cols = BLOCK_THREADS * VEC_WIDTH RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - quant_dtype_max = _quant_dtype_max(quant_dtype_str) - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel - def rmsnorm_quant_kernel( + def fused_add_rmsnorm_kernel( Input: fx.Tensor, + ResidualIn: fx.Tensor, Gamma: fx.Tensor, - XScale: fx.Tensor, - YScale: fx.Tensor, Output: fx.Tensor, + ResidualOut: fx.Tensor, ): bid = fx.block_idx.x tid = fx.thread_idx.x elem_dtype = dtype_to_elem_type(dtype_str) - quant_dtype = Numeric.from_ir_type(_quant_dtype_to_elem_type(quant_dtype_str)) - compute_type = T.f32 - fm_fast = arith.FastMathFlags.fast - eps_c = arith.constant(EPS, type=compute_type) - n_float = arith.constant(float(N), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_one_f = arith.constant(1.0, type=compute_type) - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_dtype_max = arith.constant(quant_dtype_max, type=compute_type) + eps_c = EPS + n_float = float(N) base_ptr = allocator.get_base() - s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,)) - s_red2 = SmemPtr(base_ptr, red2_offset, T.f32, shape=(RED_SLOTS,)) + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) s_red.get() s_red2.get() - YScale_buf = fx.rocdl.make_buffer_tensor(YScale) - yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) - scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - - def _store_yscale(index, val): - r = fx.make_rmem_tensor(1, Float32) - ts = full(1, Float32(val), Float32) - fx.memref_store_vec(ts, r) - fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) - def wave_reduce_add(x): - width_i32 = fx.Int32(WARP_SIZE) w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) w = w.addf(peer, fastmath=fm_fast) return w - def wave_reduce_max(x): - width_i32 = fx.Int32(WARP_SIZE) - w = x - for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - 
off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) - w = w.maximumf(peer) - return w - def block_reduce_add(val): dummy = fx.Float32(0.0) r0, _ = block_reduce_add2(val, dummy) @@ -396,260 +460,908 @@ def block_reduce_add2(val0, val1): w0 = wave_reduce_add(val0) w1 = wave_reduce_add(val1) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_red, w0, [wave_idx]) - SmemPtr.store(s_red2, w1, [wave_idx]) + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) gpu.barrier() - if wave == fx.Int32(0): + if wave == 0: in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v0 = SmemPtr.load(s_red, [lane_safe_idx]) - v1 = SmemPtr.load(s_red2, [lane_safe_idx]) - ww0 = in_range.select(v0, c_zero_f) - ww1 = in_range.select(v1, c_zero_f) + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) + ww0 = in_range.select(v0, 0.0) + ww1 = in_range.select(v1, 0.0) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_red, ww0, [c0_idx]) - SmemPtr.store(s_red2, ww1, [c0_idx]) - gpu.barrier() - - c0_idx = fx.Index(0) - return SmemPtr.load(s_red, [c0_idx]), SmemPtr.load(s_red2, [c0_idx]) - - def block_reduce_max(val): - if const_expr(RED_SLOTS == 1): - return wave_reduce_max(val) - - lane = tid % WARP_SIZE - wave = tid // WARP_SIZE - - w = wave_reduce_max(val) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_red, w, [wave_idx]) - gpu.barrier() - - if wave == fx.Int32(0): - in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v = SmemPtr.load(s_red, [lane_safe_idx]) - ww = in_range.select(v, c_neg_inf) - ww = wave_reduce_max(ww) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_red, ww, [c0_idx]) + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) gpu.barrier() - c0_idx = fx.Index(0) - return SmemPtr.load(s_red, [c0_idx]) + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) # ================================================================== # Fast path: N is a multiple of tile_cols # ================================================================== if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): num_tiles = N // tile_cols - quant_half_width = VEC_WIDTH // 2 - abs_mask = full(VEC_WIDTH, Uint32(0x7FFFFFFF), Uint32) - + # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) Output_buf = fx.rocdl.make_buffer_tensor(Output) - if const_expr(is_smooth): - XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) - out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, 
fx.make_layout(VEC_WIDTH, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) - if const_expr(is_smooth): - xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(VEC_WIDTH, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) - - def _load_vec(div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_q_vec(val, div_tensor, idx): - r = fx.make_rmem_tensor(quant_half_width, quant_dtype) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom_q, r, fx.slice(div_tensor, (None, idx))) + c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f thread_dummy = c_zero_f - in_local = [] + add_local = [] + # Pass 1: add + cache + sumsq (also write residual_out) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec = _load_vec(in_div, idx) - in_local.append(vec) - x = vec.to(Float32) - x2 = x * x - red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) - thread_sumsq = ArithValue(thread_sumsq) + red2 + x = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) + add_local.append(added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + + added2 = added * added + red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, added_e, residual_out_div, idx) _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) - mean_sq = ArithValue(sum_sq) / n_float + mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) - thread_row_max = c_zero_f - y_local = [] - + # Pass 2: normalize + gamma + store (reuse cached added values) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - - g = _load_vec(gamma_div, idx).to(Float32) - x = in_local[tile_i].to(Float32) - y = (x * rrms) * g - if const_expr(is_smooth): - s = _load_vec(xscale_div, idx).to(Float32) - y = y * s - - y_local.append(y) - y_abs = (y.bitcast(Uint32) & abs_mask).bitcast(Float32) - tile_max = y_abs.reduce(ReductionOp.MAX) - thread_row_max = thread_row_max.maximumf(tile_max) - - row_max = block_reduce_max(thread_row_max) - scale = ArithValue(row_max) / c_dtype_max - final_scale = (scale == c_zero_f).select(c_one_f, scale) - - if tid == fx.Int32(0): - _store_yscale(bid, final_scale) - - inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) - - for tile_i in range_constexpr(num_tiles): - q = y_local[tile_i] * inv_scale - q_i8 = q.to(quant_dtype) - q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) - q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) - out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 - _store_q_vec(q_lo, out_div_q, out_idx) - _store_q_vec(q_hi, out_div_q, out_idx + 1) + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) + y = (added * rrms) * g + y_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, y_e, out_div, idx) else: # ============================================================== # 
Generic path: scalar 2-pass for arbitrary N # ============================================================== Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) Output_buf = fx.rocdl.make_buffer_tensor(Output) - if const_expr(is_smooth): - XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) copy_atom_s = fx.make_copy_atom( fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) - row_in = fx.slice(Input_buf, (bid, None)) - row_out = fx.slice(Output_buf, (bid, None)) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - if const_expr(is_smooth): - xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) - - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.make_rmem_tensor(1, elem_dtype) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0].ir_value() - - def _store_quant_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, quant_dtype) - ts = full(1, quant_dtype(val), quant_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_qs, r, view) - - def _abs_scalar(val): - is_neg = val < c_zero_f - neg_val = c_zero_f - ArithValue(val) - return is_neg.select(neg_val, val) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f - c_N_i32 = Int32(N) - c0_i = Int32(0) for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) - x_e = _load_scalar(row_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - x2 = ArithValue(x) * ArithValue(x) - thread_sumsq = ArithValue(thread_sumsq) + is_valid.select(x2, c_zero_f) + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) + residual_e = _load_scalar(copy_atom_s, elem_dtype, residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) + if idx < N: + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, residual_out_div, idx, added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added2 = added * added + thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) sum_sq = block_reduce_add(thread_sumsq) - mean_sq = ArithValue(sum_sq) / n_float + mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) - thread_row_max = c_zero_f for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) - x_e = _load_scalar(row_div, idx_safe) - g_e = 
_load_scalar(gamma_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - y = (ArithValue(x) * ArithValue(rrms)) * ArithValue(g) - if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx_safe) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) - y_abs = _abs_scalar(y) - thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + if idx < N: + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + added_e = _load_scalar(copy_atom_s, elem_dtype, residual_out_div, idx) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + y = (added * rrms) * g + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div, idx, y_e) - row_max = block_reduce_max(thread_row_max) - scale = ArithValue(row_max) / c_dtype_max - final_scale = (scale == c_zero_f).select(c_one_f, scale) + @flyc.jit + def launch_fused_add_rmsnorm( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() - if tid == fx.Int32(0): - _store_yscale(bid, final_scale) + launcher = fused_add_rmsnorm_kernel(Input, ResidualIn, Gamma, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) - inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) + return launch_fused_add_rmsnorm - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - y = (ArithValue(x) * ArithValue(rrms)) * ArithValue(g) - if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) - q = ArithValue(y) * ArithValue(inv_scale) - q_i8 = quant_dtype(q) - _store_quant_scalar(out_div, idx, q_i8) - if is_smooth: +def _quant_dtype_to_elem_type(dtype_str: str): + if dtype_str in ("i8", "int8"): + return fx.Int8 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") - @flyc.jit - def launch_rmsnorm_smoothquant( - Input: fx.Tensor, + +def _quant_dtype_max(dtype_str: str) -> float: + if dtype_str in ("i8", "int8"): + return 127.0 + raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") + + +def _build_rmsnorm_quant_module( + M: int, + N: int, + dtype_str: str, + *, + is_smooth: bool, + quant_dtype_str: str = "i8", +): + arch = get_hip_arch() + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + quant_dtype_max = _quant_dtype_max(quant_dtype_str) + + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) + + @flyc.kernel + def rmsnorm_quant_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + XScale: fx.Tensor, + YScale: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) 
+ quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) + + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red.get() + s_red2.get() + + YScale_buf = fx.rocdl.make_buffer_tensor(YScale) + yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) + scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def wave_reduce_max(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.maximumf(peer) + return w + + def block_reduce_add(val): + dummy = fx.Float32(0.0) + r0, _ = block_reduce_add2(val, dummy) + return r0 + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) + ww0 = in_range.select(v0, c_zero_f) + ww1 = in_range.select(v1, c_zero_f) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) + + def block_reduce_max(val): + if const_expr(RED_SLOTS == 1): + return wave_reduce_max(val) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w = wave_reduce_max(val) + if lane == 0: + SmemPtr.store(s_red, w, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v = SmemPtr.load(s_red, [lane_safe]) + ww = in_range.select(v, c_neg_inf) + ww = wave_reduce_max(ww) + if lane == 0: + SmemPtr.store(s_red, ww, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]) + + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== + if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + num_tiles = N // tile_cols + quant_half_width = VEC_WIDTH // 2 + abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) + xscale_vec_width = 4 + # ── Layout API: buffer-backed tensors + tiled access ───── + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) + gamma_div = fx.logical_divide(Gamma_buf, 
fx.make_layout(VEC_WIDTH, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) + copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) + + thread_sumsq = c_zero_f + thread_dummy = c_zero_f + in_local = [] + + # Pass 1: load + cache + sumsq + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) + in_local.append(vec) + x = vec.to(fx.Float32) + x2 = x * x + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + + _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + y_local = [] + + # Pass 2: normalize + gamma (+ optional smooth scale), cache output, and accumulate row max + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + x = in_local[tile_i].to(fx.Float32) + y = (x * rrms) * g + if const_expr(is_smooth): + s_lo = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2) + s_hi = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2 + 1) + s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() + y = y * s + + y_local.append(y) + y_abs = (y.bitcast(fx.Uint32) & abs_mask).bitcast(fx.Float32) + tile_max = y_abs.reduce(ReductionOp.MAX) + thread_row_max = thread_row_max.maximumf(tile_max) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + # Pass 3: quantize + store using per-row scale + for tile_i in range_constexpr(num_tiles): + q = y_local[tile_i] * inv_scale + q_i8 = q.to(quant_dtype) + q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) + q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) + out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_lo, out_div_q, out_idx) + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_hi, out_div_q, out_idx + 1) + + else: + # ============================================================== + # Generic path: scalar 3-pass for arbitrary N + # ============================================================== + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) 
+ + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - val + return is_neg.select(neg_val, val) + + thread_sumsq = c_zero_f + + # Pass 1: accumulate sumsq + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + thread_row_max = c_zero_f + # Pass 2: normalize, apply gamma (+ optional smooth scale), and accumulate row max + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g + if const_expr(is_smooth): + s = _load_scalar(copy_atom_xs, fx.Float32, xscale_div, idx_safe) + y = y * s + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + # Pass 3: quantize + store using per-row scale + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g + if const_expr(is_smooth): + s = _load_scalar(copy_atom_xs, fx.Float32, xscale_div, idx) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) + _store_scalar(copy_atom_qs, quant_dtype, quant_dtype, out_div, idx, q_i8) + + if is_smooth: + + @flyc.jit + def launch_rmsnorm_smoothquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = rmsnorm_quant_kernel(Input, Gamma, XScale, YScale, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_rmsnorm_smoothquant + + else: + + @flyc.jit + def launch_rmsnorm_dynamicquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = rmsnorm_quant_kernel(Input, Gamma, Gamma, YScale, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_rmsnorm_dynamicquant + + +def build_rmsnorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return 
_build_rmsnorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_rmsnorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_rmsnorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + quant_dtype_str=quant_dtype_str, + ) + + +def _build_fused_add_rmsnorm_quant_module( + M: int, + N: int, + dtype_str: str, + *, + is_smooth: bool, + quant_dtype_str: str = "i8", +): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + quant_dtype_max = _quant_dtype_max(quant_dtype_str) + + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) + + @flyc.kernel + def fused_add_rmsnorm_quant_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + XScale: fx.Tensor, + YScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) + + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red.get() + s_red2.get() + + YScale_buf = fx.rocdl.make_buffer_tensor(YScale) + yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) + scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def wave_reduce_max(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.maximumf(peer) + return w + + def block_reduce_add(val): + dummy = fx.Float32(0.0) + r0, _ = block_reduce_add2(val, dummy) + return r0 + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) + ww0 = in_range.select(v0, c_zero_f) + ww1 = in_range.select(v1, c_zero_f) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) + + def block_reduce_max(val): + if const_expr(RED_SLOTS == 1): + return wave_reduce_max(val) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w = wave_reduce_max(val) + if lane == 0: + SmemPtr.store(s_red, w, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 
0) + v = SmemPtr.load(s_red, [lane_safe]) + ww = in_range.select(v, c_neg_inf) + ww = wave_reduce_max(ww) + if lane == 0: + SmemPtr.store(s_red, ww, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]) + + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== + if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + num_tiles = N // tile_cols + quant_half_width = VEC_WIDTH // 2 + abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) + xscale_vec_width = 4 + # ── Layout API: buffer-backed tensors + tiled access ───── + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) + out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) + copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) + + thread_sumsq = c_zero_f + thread_dummy = c_zero_f + add_local = [] + + # Pass 1: add + cache + sumsq (also write residual_out) + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + x = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) + add_local.append(added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added2 = added * added + red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, added_e, residual_out_div, idx) + + _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + y_local = [] + + # Pass 2: normalize + gamma (+ optional smooth scale), cache output, and accumulate row max + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s_lo = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2) + s_hi = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2 + 1) + s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 
2, 3, 4, 5, 6, 7]).ir_value() + y = y * s + + y_local.append(y) + y_abs = (y.bitcast(fx.Uint32) & abs_mask).bitcast(fx.Float32) + tile_max = y_abs.reduce(ReductionOp.MAX) + thread_row_max = thread_row_max.maximumf(tile_max) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + # Pass 3: quantize + store using per-row scale + for tile_i in range_constexpr(num_tiles): + q = y_local[tile_i] * inv_scale + q_i8 = q.to(quant_dtype) + q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) + q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) + out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_lo, out_div_q, out_idx) + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_hi, out_div_q, out_idx + 1) + + else: + # ============================================================== + # Generic path: scalar 3-pass for arbitrary N + # ============================================================== + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) + + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - val + return is_neg.select(neg_val, val) + + thread_sumsq = c_zero_f + + # Pass 1: add, write residual_out, and accumulate sumsq + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) + residual_e = _load_scalar(copy_atom_s, elem_dtype, residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) + if idx < N: + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, residual_out_div, idx, added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added2 = added * added + thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + thread_row_max = c_zero_f + # Pass 2: 
normalize, apply gamma (+ optional smooth scale), and accumulate row max + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx_safe) + added_e = _load_scalar(copy_atom_s, elem_dtype, residual_out_div, idx_safe) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s = _load_scalar(copy_atom_xs, fx.Float32, xscale_div, idx_safe) + y = y * s + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + # Pass 3: quantize + store using per-row scale + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + added_e = _load_scalar(copy_atom_s, elem_dtype, residual_out_div, idx) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s = _load_scalar(copy_atom_xs, fx.Float32, xscale_div, idx) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) + _store_scalar(copy_atom_qs, quant_dtype, quant_dtype, out_div, idx, q_i8) + + if is_smooth: + + @flyc.jit + def launch_fused_add_rmsnorm_smoothquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, Gamma: fx.Tensor, XScale: fx.Tensor, Output: fx.Tensor, + ResidualOut: fx.Tensor, YScale: fx.Tensor, m_in: fx.Int32, stream: fx.Stream = fx.Stream(None), @@ -659,46 +1371,50 @@ def launch_rmsnorm_smoothquant( with InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = rmsnorm_quant_kernel(Input, Gamma, XScale, YScale, Output) + launcher = fused_add_rmsnorm_quant_kernel(Input, ResidualIn, Gamma, XScale, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), block=(BLOCK_THREADS, 1, 1), stream=stream, ) - return launch_rmsnorm_smoothquant + return launch_fused_add_rmsnorm_smoothquant - @flyc.jit - def launch_rmsnorm_dynamicquant( - Input: fx.Tensor, - Gamma: fx.Tensor, - Output: fx.Tensor, - YScale: fx.Tensor, - m_in: fx.Int32, - stream: fx.Stream = fx.Stream(None), - ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() + else: - launcher = rmsnorm_quant_kernel(Input, Gamma, Gamma, YScale, Output) - launcher.launch( - grid=(m_in, 1, 1), - block=(BLOCK_THREADS, 1, 1), - stream=stream, - ) + @flyc.jit + def launch_fused_add_rmsnorm_dynamicquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = fused_add_rmsnorm_quant_kernel(Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) - return launch_rmsnorm_dynamicquant + return launch_fused_add_rmsnorm_dynamicquant -def 
build_rmsnorm_dynamicquant_module( +def build_fused_add_rmsnorm_dynamicquant_module( M: int, N: int, dtype_str: str, quant_dtype_str: str = "i8", ): - return _build_rmsnorm_quant_module( + return _build_fused_add_rmsnorm_quant_module( M, N, dtype_str, @@ -707,13 +1423,13 @@ def build_rmsnorm_dynamicquant_module( ) -def build_rmsnorm_smoothquant_module( +def build_fused_add_rmsnorm_smoothquant_module( M: int, N: int, dtype_str: str, quant_dtype_str: str = "i8", ): - return _build_rmsnorm_quant_module( + return _build_fused_add_rmsnorm_quant_module( M, N, dtype_str, diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 804ef007..04eae3c9 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -18,6 +18,9 @@ import pytest from kernels.rmsnorm_kernel import ( + build_fused_add_rmsnorm_dynamicquant_module, + build_fused_add_rmsnorm_module, + build_fused_add_rmsnorm_smoothquant_module, build_rmsnorm_dynamicquant_module, build_rmsnorm_module, build_rmsnorm_smoothquant_module, @@ -42,6 +45,7 @@ DTYPE_FP32 = torch.float32 DTYPE_FP16 = torch.float16 DTYPE_BF16 = torch.bfloat16 +DTYPE_INT8 = torch.int8 EPS: float = 1e-5 @@ -154,7 +158,7 @@ def test_all(): m_s, n_s, dt = [x.strip() for x in p.split(",")] configs.append((int(m_s), int(n_s), dt)) else: - # Prefer N multiples of BLOCK_THREADS*VEC_WIDTH (=2048) to exercise the fast path. + # Prefer N multiples of 2048 to exercise the fast path. configs = [ # (64, 256, "f32"), # Aligned # (128, 1024, "f32"), # Aligned @@ -236,7 +240,7 @@ def _get_rmsnorm_configs(): configs.append((int(m_s), int(n_s), dt)) return configs - # Prefer N multiples of BLOCK_THREADS*VEC_WIDTH (=2048) to exercise the fast path. + # Prefer N multiples of 2048 to exercise the fast path. return [ # (64, 256, "f32"), # Aligned # (128, 1024, "f32"), # Aligned @@ -284,7 +288,7 @@ def _bench_aiter_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") if is_smooth: - xscale = (torch.rand((N,), device="cuda", dtype=torch_dtype) + 0.5).contiguous() + xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() def run_aiter(): aiter_rmsnorm_quant(y, x, xscale, yscale, w, EPS) @@ -303,7 +307,6 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): mode = "smoothquant" if is_smooth else "dynamicquant" print(f"\nTesting RMSNorm {mode} (M={M}, N={N}, dtype={dtype})") - torch_dtype = _torch_dtype(dtype) try: if is_smooth: launch_fn = build_rmsnorm_smoothquant_module(M, N, dtype) @@ -312,7 +315,6 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): except Exception as e: print(f"[FAIL] Compile failed for {mode} (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") return False, None - torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) @@ -329,30 +331,33 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): else: raise ValueError(f"unsupported dtype: {dtype}") - output_dev = torch.empty((M, N), device="cuda", dtype=torch.int8) - yscale_dev = torch.empty((M,), device="cuda", dtype=torch.float32) - + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) xscale_dev = None if is_smooth: - xscale_dev = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + 0.5 + xscale_dev = (torch.rand((N,), device="cuda", 
dtype=DTYPE_FP32) + 0.5).contiguous() + dequant_tol = 0.25 if is_smooth else 0.2 + scale_tol = 1e-2 if is_smooth else 5e-3 + + # PyTorch Reference: + # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma + # Quant path additionally computes per-row yscale and int8 output from the fp32 reference. + expected, q_ref, yscale_ref = _reference_rmsnorm_quant( + input_dev, + gamma_dev, + xscale_dev=xscale_dev, + ) print("Launching kernel...") stream = torch.cuda.current_stream() def kernel_launch(): if is_smooth: - launch_fn( - input_dev, - gamma_dev, - xscale_dev, - output_dev, - yscale_dev, - M, - stream=stream, - ) + launch_fn(input_dev, gamma_dev, xscale_dev, output_dev, yscale_dev, M, stream=stream) else: launch_fn(input_dev, gamma_dev, output_dev, yscale_dev, M, stream=stream) + # run_perftest returns (data, avg_us) _, avg_us = run_perftest( lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, @@ -364,44 +369,51 @@ def kernel_launch(): flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) avg_ms = avg_us / 1000.0 + # Bandwidth estimate: read input + read gamma + write output elem_bytes = 4 if dtype == "f32" else 2 total_bytes = M * N * elem_bytes + N * elem_bytes + M * N + M * 4 if is_smooth: - total_bytes += N * elem_bytes + total_bytes += N * 4 bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 - print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest " f"(warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us") - expected, q_ref, yscale_ref = _reference_rmsnorm_quant( - input_dev, - gamma_dev, - xscale_dev=xscale_dev, - ) q_out = output_dev.to(torch.int16) q_expected = q_ref.to(torch.int16) yscale_out = yscale_dev.cpu() yscale_expected = yscale_ref.cpu() + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) - q_diff = (q_out - q_expected).abs().max().item() + error = (output_ref - expected).abs().max().item() scale_diff = (yscale_out - yscale_expected).abs().max().item() - recon = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) - recon_err = (recon - expected).abs().max().item() - - scale_tol = 1e-2 if is_smooth else 5e-3 - recon_tol = 0.25 if is_smooth else 0.2 + quant_diff = (q_out - q_expected).abs().max().item() - print(f"Max quant diff: {q_diff}") + print(f"Max dequant error: {error:.2e} (tol={dequant_tol})") print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") - print(f"Max recon error: {recon_err:.2e} (tol={recon_tol})") + print(f"Max quant diff: {quant_diff}") - ok = q_diff <= 1 and scale_diff < scale_tol and recon_err < recon_tol + ok = error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1 if ok: print("PASSED") + ok = True else: print("FAILED") + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_expected[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_expected[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False return ok, flydsl_gpu_us @@ -487,5 +499,467 @@ def test_rmsnorm_smoothquant(): raise SystemExit(1) +def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev): + added = input_dev + residual_in_dev + added_fp32 = 
+def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev):
+    added = input_dev + residual_in_dev
+    added_fp32 = added.to(DTYPE_FP32)
+    gamma = gamma_dev.to(DTYPE_FP32)
+    expected = (added_fp32 / torch.sqrt((added_fp32 * added_fp32).mean(dim=1, keepdim=True) + EPS)) * gamma
+    return added_fp32, expected
+
+
+def _bench_aiter_fused_add_rmsnorm(M: int, N: int, dtype: str):
+    torch_dtype = _torch_dtype(dtype)
+
+    try:
+        from aiter.ops.triton.normalization.rmsnorm import (
+            rmsnorm2d_fwd_with_add as aiter_fused_add_rmsnorm,
+        )
+    except Exception as e:
+        print(f"[Perf] AIter fused_add rmsnorm skipped: {type(e).__name__}: {e!r}")
+        return None
+
+    x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous()
+    residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous()
+    w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous()
+    out = torch.empty((M, N), device="cuda", dtype=torch_dtype)
+    residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype)
+
+    def run_aiter():
+        aiter_fused_add_rmsnorm(out, x, residual_in, residual_out, w, EPS)
+
+    aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
+    print(f"[Perf] AIter fused_add rmsnorm gpu: {aiter_us:.1f} us")
+    return aiter_us
+
+
+def run_fused_add_test(M: int, N: int, dtype: str):
+    print(f"\nTesting FusedAdd RMSNorm (M={M}, N={N}, dtype={dtype})")
+
+    try:
+        launch_fn = build_fused_add_rmsnorm_module(M, N, dtype)
+    except Exception as e:
+        print(f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}")
+        return False, None
+
+    torch.manual_seed(42)
+    input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32)
+    residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32)
+    gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32)
+
+    if dtype == "f32":
+        input_dev = input_t.contiguous()
+        residual_in_dev = residual_t.contiguous()
+        gamma_dev = gamma_t.contiguous()
+        output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32)
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32)
+        output_atol = 1e-4
+        residual_atol = 1e-4
+    elif dtype == "f16":
+        input_dev = input_t.to(DTYPE_FP16).contiguous()
+        residual_in_dev = residual_t.to(DTYPE_FP16).contiguous()
+        gamma_dev = gamma_t.to(DTYPE_FP16).contiguous()
+        output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16)
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16)
+        output_atol = 1e-2
+        residual_atol = 1e-2
+    elif dtype == "bf16":
+        input_dev = input_t.to(DTYPE_BF16).contiguous()
+        residual_in_dev = residual_t.to(DTYPE_BF16).contiguous()
+        gamma_dev = gamma_t.to(DTYPE_BF16).contiguous()
+        output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16)
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16)
+        output_atol = 2e-2
+        residual_atol = 2e-2
+    else:
+        raise ValueError(f"unsupported dtype: {dtype}")
+
+    print("Launching kernel...")
+    stream = torch.cuda.current_stream()
+
+    def kernel_launch():
+        launch_fn(
+            input_dev,
+            residual_in_dev,
+            gamma_dev,
+            output_dev,
+            residual_out_dev,
+            M,
+            stream=stream,
+        )
+
+    _, avg_us = run_perftest(
+        lambda: (kernel_launch(), torch.cuda.synchronize()),
+        num_iters=BENCH_ITERS,
+        num_warmup=WARMUP_ITERS,
+    )
+    torch.cuda.synchronize()
+    flydsl_gpu_us = None
+    if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1":
+        flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
+    avg_ms = avg_us / 1000.0
+
+    elem_bytes = 4 if dtype == "f32" else 2
+    total_bytes = (4 * M * N + N) * elem_bytes
+    bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9
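+    # The byte count above models read x + read residual_in + write out +
+    # write residual_out (4 * M * N elements) plus one gamma read (N), all in
+    # the I/O dtype. Worked example (illustrative shape): M=4096, N=2048, bf16
+    # -> (4 * 4096 * 2048 + 2048) * 2 bytes ~= 67.1 MB per launch.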
+
+    print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})")
+    print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
+    if flydsl_gpu_us is not None:
+        print(f"[Perf] FlyDSL fused_add rmsnorm gpu: {flydsl_gpu_us:.1f} us")
+
+    residual_expected, output_expected = _reference_fused_add_rmsnorm(
+        input_dev,
+        residual_in_dev,
+        gamma_dev,
+    )
+    residual_out_ref = residual_out_dev.to(DTYPE_FP32)
+    output_ref = output_dev.to(DTYPE_FP32)
+
+    residual_error = (residual_out_ref - residual_expected).abs().max().item()
+    output_error = (output_ref - output_expected).abs().max().item()
+
+    print(f"Max residual error: {residual_error:.2e} (atol={residual_atol})")
+    print(f"Max output error: {output_error:.2e} (atol={output_atol})")
+
+    ok = residual_error < residual_atol and output_error < output_atol
+    if ok:
+        print("PASSED")
+    else:
+        print("FAILED")
+    return ok, flydsl_gpu_us
+
+
+def test_rmsnorm_fused_add():
+    print("=" * 80)
+    print("Running FusedAdd RMSNorm Tests")
+    print("=" * 80)
+
+    do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1"
+    perf_rows = []
+    failures = 0
+
+    for M, N, dtype in _get_rmsnorm_configs():
+        ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype)
+        if not ok:
+            failures += 1
+
+        if do_compare:
+            aiter_us = None
+            if maybe_enable_aiter():
+                aiter_us = _bench_aiter_fused_add_rmsnorm(M, N, dtype)
+            perf_rows.append(
+                PerfRow(
+                    op="rmsnorm_add",
+                    shape=f"{M}x{N}",
+                    dtype=dtype,
+                    flydsl_gpu_us=flydsl_gpu_us,
+                    aiter_gpu_us=aiter_us,
+                )
+            )
+
+    print("\n" + "=" * 80)
+    if failures == 0:
+        print("ALL TESTS PASSED")
+    else:
+        print(f"{failures} TESTS FAILED")
+    print("=" * 80)
+    if do_compare and perf_rows:
+        print_perf_table(perf_rows)
+    # Ensure a non-zero exit code on failure for shell wrappers.
+    if failures != 0:
+        raise SystemExit(1)
+
+
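+# Quant reference for the fused variant: composes the fused add with the
+# existing _reference_rmsnorm_quant on the summed input, additionally
+# returning the fp32 sum so the residual_out buffer can be checked.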
+def _reference_fused_add_rmsnorm_quant(
+    input_dev,
+    residual_in_dev,
+    gamma_dev,
+    *,
+    xscale_dev=None,
+):
+    added = input_dev + residual_in_dev
+    residual_expected = added.to(DTYPE_FP32)
+    expected, q, yscale = _reference_rmsnorm_quant(
+        added,
+        gamma_dev,
+        xscale_dev=xscale_dev,
+    )
+    return residual_expected, expected, q, yscale
+
+
+def _bench_aiter_fused_add_rmsnorm_quant(
+    M: int,
+    N: int,
+    dtype: str,
+    *,
+    is_smooth: bool,
+):
+    mode = "smoothquant" if is_smooth else "dynamicquant"
+    torch_dtype = _torch_dtype(dtype)
+
+    try:
+        if is_smooth:
+            from aiter.ops.triton.normalization.rmsnorm import (
+                rmsnorm2d_fwd_with_add_smoothquant as aiter_fused_add_rmsnorm_quant,
+            )
+        else:
+            from aiter.ops.triton.normalization.rmsnorm import (
+                rmsnorm2d_fwd_with_add_dynamicquant as aiter_fused_add_rmsnorm_quant,
+            )
+    except Exception as e:
+        print(f"[Perf] AIter fused_add rmsnorm {mode} skipped: {type(e).__name__}: {e!r}")
+        return None
+
+    x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous()
+    residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous()
+    w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous()
+    y = torch.empty((M, N), dtype=torch.int8, device="cuda")
+    residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype)
+    yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda")
+
+    if is_smooth:
+        xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous()
+
+        def run_aiter():
+            aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, xscale, yscale, w, EPS)
+
+    else:
+
+        def run_aiter():
+            aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, yscale, w, EPS)
+
+    aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
+    print(f"[Perf] AIter fused_add rmsnorm {mode} gpu: {aiter_us:.1f} us")
+    return aiter_us
+
+
+def run_fused_add_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool):
+    mode = "smoothquant" if is_smooth else "dynamicquant"
+    print(f"\nTesting FusedAdd RMSNorm {mode} (M={M}, N={N}, dtype={dtype})")
+
+    try:
+        if is_smooth:
+            launch_fn = build_fused_add_rmsnorm_smoothquant_module(M, N, dtype)
+        else:
+            launch_fn = build_fused_add_rmsnorm_dynamicquant_module(M, N, dtype)
+    except Exception as e:
+        print(
+            f"[FAIL] Compile failed for fused_add rmsnorm {mode} "
+            f"(M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}"
+        )
+        return False, None
+
+    torch.manual_seed(42)
+    input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32)
+    residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32)
+    gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32)
+
+    if dtype == "f32":
+        input_dev = input_t.contiguous()
+        residual_in_dev = residual_t.contiguous()
+        gamma_dev = gamma_t.contiguous()
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32)
+        residual_atol = 1e-4
+    elif dtype == "f16":
+        input_dev = input_t.to(DTYPE_FP16).contiguous()
+        residual_in_dev = residual_t.to(DTYPE_FP16).contiguous()
+        gamma_dev = gamma_t.to(DTYPE_FP16).contiguous()
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16)
+        residual_atol = 1e-2
+    elif dtype == "bf16":
+        input_dev = input_t.to(DTYPE_BF16).contiguous()
+        residual_in_dev = residual_t.to(DTYPE_BF16).contiguous()
+        gamma_dev = gamma_t.to(DTYPE_BF16).contiguous()
+        residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16)
+        residual_atol = 2e-2
+    else:
+        raise ValueError(f"unsupported dtype: {dtype}")
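+    # residual_atol mirrors run_fused_add_test: the residual is round-tripped
+    # through the I/O dtype, so the 16-bit cases presumably need roughly one
+    # rounding step of f16/bf16 slack against the fp32 reference.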
+
+    output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8)
+    yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32)
+    xscale_dev = None
+    if is_smooth:
+        xscale_dev = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous()
+    dequant_tol = 0.25 if is_smooth else 0.2
+    scale_tol = 1e-2 if is_smooth else 5e-3
+
+    residual_expected, expected, q_ref, yscale_ref = _reference_fused_add_rmsnorm_quant(
+        input_dev,
+        residual_in_dev,
+        gamma_dev,
+        xscale_dev=xscale_dev,
+    )
+
+    print("Launching kernel...")
+    stream = torch.cuda.current_stream()
+
+    def kernel_launch():
+        if is_smooth:
+            launch_fn(
+                input_dev,
+                residual_in_dev,
+                gamma_dev,
+                xscale_dev,
+                output_dev,
+                residual_out_dev,
+                yscale_dev,
+                M,
+                stream=stream,
+            )
+        else:
+            launch_fn(
+                input_dev,
+                residual_in_dev,
+                gamma_dev,
+                output_dev,
+                residual_out_dev,
+                yscale_dev,
+                M,
+                stream=stream,
+            )
+
+    _, avg_us = run_perftest(
+        lambda: (kernel_launch(), torch.cuda.synchronize()),
+        num_iters=BENCH_ITERS,
+        num_warmup=WARMUP_ITERS,
+    )
+    torch.cuda.synchronize()
+    flydsl_gpu_us = None
+    if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1":
+        flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
+    avg_ms = avg_us / 1000.0
+
+    elem_bytes = 4 if dtype == "f32" else 2
+    total_bytes = 3 * M * N * elem_bytes + N * elem_bytes + M * N + M * 4
+    if is_smooth:
+        total_bytes += N * 4
+    bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9
+
+    print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})")
+    print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
+    if flydsl_gpu_us is not None:
+        print(f"[Perf] FlyDSL fused_add rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us")
+
+    residual_out_ref = residual_out_dev.to(DTYPE_FP32)
+    output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1)
+    q_out = output_dev.to(torch.int16)
+    q_expected = q_ref.to(torch.int16)
+    yscale_out = yscale_dev.cpu()
+    yscale_expected = yscale_ref.cpu()
+
+    residual_error = (residual_out_ref - residual_expected).abs().max().item()
+    dequant_error = (output_ref - expected).abs().max().item()
+    scale_diff = (yscale_out - yscale_expected).abs().max().item()
+    quant_diff = (q_out - q_expected).abs().max().item()
+
+    print(f"Max residual error: {residual_error:.2e} (tol={residual_atol})")
+    print(f"Max dequant error: {dequant_error:.2e} (tol={dequant_tol})")
+    print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})")
+    print(f"Max quant diff: {quant_diff}")
+
+    ok = residual_error < residual_atol and dequant_error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1
+    if ok:
+        print("PASSED")
+    else:
+        print("FAILED")
+        print("First row Residual Expected:")
+        print(residual_expected[0, :5])
+        print("First row Residual Actual:")
+        print(residual_out_ref[0, :5])
+        print("First row Expected:")
+        print(expected[0, :5])
+        print("First row Actual:")
+        print(output_ref[0, :5])
+        print("First row Quant Expected:")
+        print(q_expected[0, :8])
+        print("First row Quant Actual:")
+        print(q_out[0, :8])
+        print("First few YScale Expected:")
+        print(yscale_expected[:5])
+        print("First few YScale Actual:")
+        print(yscale_out[:5])
+    return ok, flydsl_gpu_us
+
+
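+# The two drivers below follow the same pattern as test_rmsnorm_fused_add:
+# set ROCDSL_COMPARE_AITER=1 to also benchmark the AIter Triton kernels and
+# print a perf table, e.g. (illustrative invocation):
+#   ROCDSL_COMPARE_AITER=1 pytest tests/kernels/test_rmsnorm.py -k fused_add -s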
+def test_rmsnorm_fused_add_dynamicquant():
+    print("=" * 80)
+    print("Running FusedAdd RMSNorm DynamicQuant Tests")
+    print("=" * 80)
+
+    do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1"
+    perf_rows = []
+    failures = 0
+
+    for M, N, dtype in _get_rmsnorm_configs():
+        ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=False)
+        if not ok:
+            failures += 1
+
+        if do_compare:
+            aiter_us = None
+            if maybe_enable_aiter():
+                aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=False)
+            perf_rows.append(
+                PerfRow(
+                    op="rmsnorm_add_dq",
+                    shape=f"{M}x{N}",
+                    dtype=dtype,
+                    flydsl_gpu_us=flydsl_gpu_us,
+                    aiter_gpu_us=aiter_us,
+                )
+            )
+
+    print("\n" + "=" * 80)
+    if failures == 0:
+        print("ALL TESTS PASSED")
+    else:
+        print(f"{failures} TESTS FAILED")
+    print("=" * 80)
+    if do_compare and perf_rows:
+        print_perf_table(perf_rows)
+    if failures != 0:
+        raise SystemExit(1)
+
+
+def test_rmsnorm_fused_add_smoothquant():
+    print("=" * 80)
+    print("Running FusedAdd RMSNorm SmoothQuant Tests")
+    print("=" * 80)
+
+    do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1"
+    perf_rows = []
+    failures = 0
+
+    for M, N, dtype in _get_rmsnorm_configs():
+        ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=True)
+        if not ok:
+            failures += 1
+
+        if do_compare:
+            aiter_us = None
+            if maybe_enable_aiter():
+                aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=True)
+            perf_rows.append(
+                PerfRow(
+                    op="rmsnorm_add_sq",
+                    shape=f"{M}x{N}",
+                    dtype=dtype,
+                    flydsl_gpu_us=flydsl_gpu_us,
+                    aiter_gpu_us=aiter_us,
+                )
+            )
+
+    print("\n" + "=" * 80)
+    if failures == 0:
+        print("ALL TESTS PASSED")
+    else:
+        print(f"{failures} TESTS FAILED")
+    print("=" * 80)
+    if do_compare and perf_rows:
+        print_perf_table(perf_rows)
+    if failures != 0:
+        raise SystemExit(1)
+
+
 if __name__ == "__main__":
     test_all()