From 60366466ca802b4d3c992cd7cae448fc0a6a5833 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 6 May 2026 20:57:58 +0800 Subject: [PATCH 01/11] add fused_add_rmsnorm kernel --- kernels/rmsnorm_kernel.py | 269 ++++++++++++++++++++++++++++++++++ tests/kernels/test_rmsnorm.py | 174 ++++++++++++++++++++++ 2 files changed, 443 insertions(+) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 99920cbf..1ab081c7 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -297,6 +297,275 @@ def launch_rmsnorm( return launch_rmsnorm +def build_fused_add_rmsnorm_module(M: int, N: int, dtype_str: str): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + RED_SLOTS * f32_bytes + red2_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def fused_add_rmsnorm_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + elem_type = elem_dtype.ir_type + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red.get() + s_red2.get() + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce_add(val): + dummy = fx.Float32(0.0) + r0, _ = block_reduce_add2(val, dummy) + return r0 + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) + ww0 = in_range.select(v0, 0.0) + ww1 = in_range.select(v1, 0.0) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) + + if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + num_tiles = N // tile_cols + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + + def _load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + def _to_elem_vec(y): + if const_expr(dtype_str == "bf16"): + if const_expr(USE_HW_CVT_PK_BF16_F32): + return y.to(elem_dtype) + u = y.bitcast(fx.Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(fx.Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + return packed.bitcast(elem_dtype) + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + thread_dummy = c_zero_f + add_local = [] + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + x = _load_vec(in_div, idx).to(fx.Float32) + residual = _load_vec(residual_in_div, idx).to(fx.Float32) + added = x + residual + add_local.append(added) + + added2 = added * added + red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + + _store_vec(_to_elem_vec(added), residual_out_div, idx) + + _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + g = _load_vec(gamma_div, idx).to(fx.Float32) + y = (add_local[tile_i] * rrms) * g + _store_vec(_to_elem_vec(y), out_div, idx) + + else: + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scalar_reg_lay = fx.make_layout(1, 1) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + def _to_elem_scalar(y): + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(row_div, idx_safe) + residual_e = _load_scalar(residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added = x + residual + added2 = added * added + thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x_e = _load_scalar(row_div, idx) + residual_e = _load_scalar(residual_in_div, idx) + g_e = _load_scalar(gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added = x + residual + y = (added * rrms) * g + _store_scalar(residual_out_div, idx, _to_elem_scalar(added)) + _store_scalar(out_div, idx, _to_elem_scalar(y)) + + @flyc.jit + def launch_fused_add_rmsnorm( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = fused_add_rmsnorm_kernel(Input, ResidualIn, Gamma, Output, ResidualOut) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_rmsnorm + + def _quant_dtype_to_elem_type(dtype_str: str): if dtype_str in ("i8", "int8"): return T.i8 diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 0f1f8447..6209cc6c 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -40,6 +40,7 @@ EPS: float = 1e-5 from kernels.rmsnorm_kernel import ( build_rmsnorm_module, + build_fused_add_rmsnorm_module, build_rmsnorm_dynamicquant_module, build_rmsnorm_smoothquant_module, KERNEL_NAME as RMSNORM_KERNEL_NAME, @@ -480,5 +481,178 @@ def test_rmsnorm_smoothquant(): raise SystemExit(1) +def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev): + added = input_dev.to(DTYPE_FP32) + residual_in_dev.to(DTYPE_FP32) + gamma = gamma_dev.to(DTYPE_FP32) + expected = (added / torch.sqrt((added * added).mean(dim=1, keepdim=True) + EPS)) * gamma + return added, expected + + +def _bench_aiter_fused_add_rmsnorm(M: int, N: int, dtype: str): + torch_dtype = _torch_dtype(dtype) + + try: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add as aiter_fused_add_rmsnorm, + ) + except Exception as e: + print(f"[Perf] AIter fused_add rmsnorm skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + + def run_aiter(): + aiter_fused_add_rmsnorm(out, x, residual_in, residual_out, w, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add rmsnorm gpu: {aiter_us:.1f} us") + return aiter_us + + +def run_fused_add_test(M: int, N: int, dtype: str): + print(f"\nTesting FusedAdd RMSNorm (M={M}, N={N}, dtype={dtype})") + + try: + launch_fn = build_fused_add_rmsnorm_module(M, N, dtype) + except Exception as e: + print( + f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): " + f"{type(e).__name__}: {e}" + ) + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + + if dtype == "f32": + input_dev = input_t.contiguous() + residual_in_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_in_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_in_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + atol = 2e-2 + else: + raise ValueError(f"unsupported dtype: {dtype}") + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + launch_fn( + input_dev, + residual_in_dev, + gamma_dev, + output_dev, + residual_out_dev, + M, + stream=stream, + ) + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), + num_iters=BENCH_ITERS, + num_warmup=WARMUP_ITERS, + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = (4 * M * N + N) * elem_bytes + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + + print( + f"Kernel avg time: {avg_ms:.4f} ms via run_perftest " + f"(warmup={WARMUP_ITERS}, iters={BENCH_ITERS})" + ) + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL fused_add rmsnorm gpu: {flydsl_gpu_us:.1f} us") + + residual_expected, output_expected = _reference_fused_add_rmsnorm( + input_dev, + residual_in_dev, + gamma_dev, + ) + residual_out_ref = residual_out_dev.to(DTYPE_FP32) + output_ref = output_dev.to(DTYPE_FP32) + + residual_error = (residual_out_ref - residual_expected).abs().max().item() + output_error = (output_ref - output_expected).abs().max().item() + + print(f"Max residual error: {residual_error:.2e} (atol={atol})") + print(f"Max output error: {output_error:.2e} (atol={atol})") + + ok = residual_error < atol and output_error < atol + if ok: + print("PASSED") + else: + print("FAILED") + return ok, flydsl_gpu_us + + +def test_rmsnorm_fused_add(): + print("="*80) + print("Running FusedAdd RMSNorm Tests") + print("="*80) + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in _get_rmsnorm_configs(): + ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_fused_add_rmsnorm(M, N, dtype) + perf_rows.append( + PerfRow( + op="rmsnorm_add", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "="*80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("="*80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + # Ensure a non-zero exit code on failure for shell wrappers. + if failures != 0: + raise SystemExit(1) + + if __name__ == "__main__": test_all() From 03bb5eefad56338278b9ded7b8b6c3bfc1d5c7e8 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 6 May 2026 21:04:15 +0800 Subject: [PATCH 02/11] refactor: align rmsnorm quant kernels with base kernel style --- kernels/rmsnorm_kernel.py | 166 +++++++++++++++++--------------------- 1 file changed, 76 insertions(+), 90 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 1ab081c7..c07a3034 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -17,10 +17,7 @@ from flydsl._mlir.ir import InsertionPoint from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr -from flydsl.expr.arith import ArithValue from flydsl.expr import math as fmath -from flydsl.expr.numeric import Numeric, Float32, Uint32 -from flydsl.expr.typing import T, Int32 from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -568,7 +565,7 @@ def launch_fused_add_rmsnorm( def _quant_dtype_to_elem_type(dtype_str: str): if dtype_str in ("i8", "int8"): - return T.i8 + return fx.Int8 raise ValueError(f"unsupported quant dtype: {dtype_str!r} (expected 'i8' or 'int8')") @@ -613,51 +610,50 @@ def rmsnorm_quant_kernel( elem_dtype = dtype_to_elem_type(dtype_str) elem_type = elem_dtype.ir_type - quant_elem_type = _quant_dtype_to_elem_type(quant_dtype_str) - quant_dtype = Numeric.from_ir_type(quant_elem_type) - compute_type = T.f32 + quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) + quant_elem_type = quant_dtype.ir_type fm_fast = arith.FastMathFlags.fast - eps_c = arith.constant(EPS, type=compute_type) - n_float = arith.constant(float(N), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_one_f = arith.constant(1.0, type=compute_type) - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_dtype_max = arith.constant(quant_dtype_max, type=compute_type) + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) base_ptr = allocator.get_base() - s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,)) - s_red2 = SmemPtr(base_ptr, red2_offset, T.f32, shape=(RED_SLOTS,)) + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) s_red.get() s_red2.get() YScale_buf = fx.rocdl.make_buffer_tensor(YScale) yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - scale_reg_ty = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scale_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) scale_reg_lay = fx.make_layout(1, 1) def _store_yscale(index, val): r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) - ts = full(1, Float32(val), Float32) + ts = full(1, fx.Float32(val), fx.Float32) fx.memref_store_vec(ts, r) fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) def wave_reduce_add(x): - width_i32 = fx.Int32(WARP_SIZE) w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) w = w.addf(peer, fastmath=fm_fast) return w def wave_reduce_max(x): - width_i32 = fx.Int32(WARP_SIZE) w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): - off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) - peer = w.shuffle_xor(off, width_i32) + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) w = w.maximumf(peer) return w @@ -676,31 +672,27 @@ def block_reduce_add2(val0, val1): w0 = wave_reduce_add(val0) w1 = wave_reduce_add(val1) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_red, w0, [wave_idx]) - SmemPtr.store(s_red2, w1, [wave_idx]) + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) gpu.barrier() - if wave == fx.Int32(0): + if wave == 0: in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v0 = SmemPtr.load(s_red, [lane_safe_idx]) - v1 = SmemPtr.load(s_red2, [lane_safe_idx]) + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) ww0 = in_range.select(v0, c_zero_f) ww1 = in_range.select(v1, c_zero_f) ww0 = wave_reduce_add(ww0) ww1 = wave_reduce_add(ww1) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_red, ww0, [c0_idx]) - SmemPtr.store(s_red2, ww1, [c0_idx]) + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) gpu.barrier() - c0_idx = fx.Index(0) - return SmemPtr.load(s_red, [c0_idx]), SmemPtr.load(s_red2, [c0_idx]) + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) def block_reduce_max(val): if const_expr(RED_SLOTS == 1): @@ -710,25 +702,21 @@ def block_reduce_max(val): wave = tid // WARP_SIZE w = wave_reduce_max(val) - if lane == fx.Int32(0): - wave_idx = ArithValue(wave).index_cast(T.index) - SmemPtr.store(s_red, w, [wave_idx]) + if lane == 0: + SmemPtr.store(s_red, w, [wave]) gpu.barrier() - if wave == fx.Int32(0): + if wave == 0: in_range = lane < RED_SLOTS - lane_safe = in_range.select(lane, fx.Int32(0)) - lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v = SmemPtr.load(s_red, [lane_safe_idx]) + lane_safe = in_range.select(lane, 0) + v = SmemPtr.load(s_red, [lane_safe]) ww = in_range.select(v, c_neg_inf) ww = wave_reduce_max(ww) - if lane == fx.Int32(0): - c0_idx = fx.Index(0) - SmemPtr.store(s_red, ww, [c0_idx]) + if lane == 0: + SmemPtr.store(s_red, ww, [0]) gpu.barrier() - c0_idx = fx.Index(0) - return SmemPtr.load(s_red, [c0_idx]) + return SmemPtr.load(s_red, [0]) # ================================================================== # Fast path: N is a multiple of tile_cols @@ -736,7 +724,7 @@ def block_reduce_max(val): if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): num_tiles = N // tile_cols quant_half_width = VEC_WIDTH // 2 - abs_mask = full(VEC_WIDTH, Uint32(0x7FFFFFFF), Uint32) + abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -782,13 +770,13 @@ def _store_q_vec(val, div_tensor, idx): idx = tid + tile_i * BLOCK_THREADS vec = _load_vec(in_div, idx) in_local.append(vec) - x = vec.to(Float32) + x = vec.to(fx.Float32) x2 = x * x red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) - thread_sumsq = ArithValue(thread_sumsq) + red2 + thread_sumsq = thread_sumsq + red2 _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) - mean_sq = ArithValue(sum_sq) / n_float + mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) @@ -798,26 +786,26 @@ def _store_q_vec(val, div_tensor, idx): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(Float32) - x = in_local[tile_i].to(Float32) + g = _load_vec(gamma_div, idx).to(fx.Float32) + x = in_local[tile_i].to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s = _load_vec(xscale_div, idx).to(Float32) + s = _load_vec(xscale_div, idx).to(fx.Float32) y = y * s y_local.append(y) - y_abs = (y.bitcast(Uint32) & abs_mask).bitcast(Float32) + y_abs = (y.bitcast(fx.Uint32) & abs_mask).bitcast(fx.Float32) tile_max = y_abs.reduce(ReductionOp.MAX) thread_row_max = thread_row_max.maximumf(tile_max) row_max = block_reduce_max(thread_row_max) - scale = ArithValue(row_max) / c_dtype_max + scale = row_max / c_dtype_max final_scale = (scale == c_zero_f).select(c_one_f, scale) - if tid == fx.Int32(0): + if tid == 0: _store_yscale(bid, final_scale) - inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) + inv_scale = c_one_f / final_scale for tile_i in range_constexpr(num_tiles): q = y_local[tile_i] * inv_scale @@ -862,7 +850,7 @@ def _load_scalar(divided_tensor, index): view = fx.slice(divided_tensor, (None, index)) r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0].ir_value() + return fx.memref_load_vec(r)[0] def _store_quant_scalar(divided_tensor, index, val): r = fx.memref_alloca(scalar_reg_ty_q, scalar_reg_lay_q) @@ -873,67 +861,65 @@ def _store_quant_scalar(divided_tensor, index, val): def _abs_scalar(val): is_neg = val < c_zero_f - neg_val = c_zero_f - ArithValue(val) + neg_val = c_zero_f - val return is_neg.select(neg_val, val) thread_sumsq = c_zero_f - c_N_i32 = Int32(N) - c0_i = Int32(0) for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) x_e = _load_scalar(row_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - x2 = ArithValue(x) * ArithValue(x) - thread_sumsq = ArithValue(thread_sumsq) + is_valid.select(x2, c_zero_f) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) sum_sq = block_reduce_add(thread_sumsq) - mean_sq = ArithValue(sum_sq) / n_float + mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - is_valid = idx < c_N_i32 - idx_safe = is_valid.select(idx, c0_i) + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) x_e = _load_scalar(row_div, idx_safe) g_e = _load_scalar(gamma_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - y = (ArithValue(x) * ArithValue(rrms)) * ArithValue(g) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g if const_expr(is_smooth): s_e = _load_scalar(xscale_div, idx_safe) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s y_abs = _abs_scalar(y) thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) row_max = block_reduce_max(thread_row_max) - scale = ArithValue(row_max) / c_dtype_max + scale = row_max / c_dtype_max final_scale = (scale == c_zero_f).select(c_one_f, scale) - if tid == fx.Int32(0): + if tid == 0: _store_yscale(bid, final_scale) - inv_scale = ArithValue(c_one_f) / ArithValue(final_scale) + inv_scale = c_one_f / final_scale for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int - if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + if idx < N: x_e = _load_scalar(row_div, idx) g_e = _load_scalar(gamma_div, idx) - x = x_e if dtype_str == "f32" else x_e.extf(compute_type) - g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - y = (ArithValue(x) * ArithValue(rrms)) * ArithValue(g) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g if const_expr(is_smooth): s_e = _load_scalar(xscale_div, idx) - s = s_e if dtype_str == "f32" else s_e.extf(compute_type) - y = ArithValue(y) * ArithValue(s) - q = ArithValue(y) * ArithValue(inv_scale) - q_i8 = quant_dtype(q) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) _store_quant_scalar(out_div, idx, q_i8) if is_smooth: From 18924e6014e4051495ac11062d44f0bc474a6698 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 6 May 2026 23:16:13 +0800 Subject: [PATCH 03/11] test: separate rmsnorm variant configs and refine fused_add checks --- tests/kernels/test_rmsnorm.py | 63 ++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 6209cc6c..10490833 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -238,6 +238,48 @@ def _get_rmsnorm_configs(): ] +def _get_rmsnorm_quant_configs(): + shapes_env = os.environ.get("ROCDSL_RMSNORM_QUANT_SHAPES", "").strip() + if not shapes_env: + shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + return configs + + return [ + (128, 4096, "f16"), # Aligned + (173, 409, "f16"), # Unaligned (tail handling) + (256, 4096, "bf16"), # BF16 + ] + + +def _get_rmsnorm_fused_add_configs(): + shapes_env = os.environ.get("ROCDSL_RMSNORM_FUSED_ADD_SHAPES", "").strip() + if not shapes_env: + shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + return configs + + return [ + (128, 4096, "f16"), # Aligned + (173, 409, "f16"), # Unaligned (tail handling) + (1024, 8192, "bf16"), # BF16 + ] + + def _reference_rmsnorm_quant(input_dev, gamma_dev, *, xscale_dev=None): x = input_dev.to(DTYPE_FP32) gamma = gamma_dev.to(DTYPE_FP32) @@ -408,7 +450,7 @@ def test_rmsnorm_dynamicquant(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_configs(): + for M, N, dtype in _get_rmsnorm_quant_configs(): ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=False) if not ok: failures += 1 @@ -449,7 +491,7 @@ def test_rmsnorm_smoothquant(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_configs(): + for M, N, dtype in _get_rmsnorm_quant_configs(): ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=True) if not ok: failures += 1 @@ -536,21 +578,24 @@ def run_fused_add_test(M: int, N: int, dtype: str): gamma_dev = gamma_t.contiguous() output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) - atol = 1e-4 + output_atol = 1e-4 + residual_atol = 1e-4 elif dtype == "f16": input_dev = input_t.to(DTYPE_FP16).contiguous() residual_in_dev = residual_t.to(DTYPE_FP16).contiguous() gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) - atol = 1e-2 + output_atol = 1e-2 + residual_atol = 1e-2 elif dtype == "bf16": input_dev = input_t.to(DTYPE_BF16).contiguous() residual_in_dev = residual_t.to(DTYPE_BF16).contiguous() gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) - atol = 2e-2 + output_atol = 2e-2 + residual_atol = 4e-2 else: raise ValueError(f"unsupported dtype: {dtype}") @@ -602,10 +647,10 @@ def kernel_launch(): residual_error = (residual_out_ref - residual_expected).abs().max().item() output_error = (output_ref - output_expected).abs().max().item() - print(f"Max residual error: {residual_error:.2e} (atol={atol})") - print(f"Max output error: {output_error:.2e} (atol={atol})") + print(f"Max residual error: {residual_error:.2e} (atol={residual_atol})") + print(f"Max output error: {output_error:.2e} (atol={output_atol})") - ok = residual_error < atol and output_error < atol + ok = residual_error < residual_atol and output_error < output_atol if ok: print("PASSED") else: @@ -622,7 +667,7 @@ def test_rmsnorm_fused_add(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_configs(): + for M, N, dtype in _get_rmsnorm_fused_add_configs(): ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) if not ok: failures += 1 From 4c1a1dadcb1635774d8c4b128d7d3a52862746ae Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Fri, 8 May 2026 15:05:18 +0800 Subject: [PATCH 04/11] align fused_add rmsnorm semantics and refine rmsnorm variant tests --- kernels/rmsnorm_kernel.py | 18 +++-- tests/kernels/test_rmsnorm.py | 126 +++++++++++++--------------------- 2 files changed, 57 insertions(+), 87 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index c07a3034..bb019a2a 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -441,14 +441,15 @@ def _to_elem_vec(y): idx = tid + tile_i * BLOCK_THREADS x = _load_vec(in_div, idx).to(fx.Float32) residual = _load_vec(residual_in_div, idx).to(fx.Float32) - added = x + residual - add_local.append(added) + added_e = _to_elem_vec(x + residual) + add_local.append(added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sumsq = thread_sumsq + red2 - _store_vec(_to_elem_vec(added), residual_out_div, idx) + _store_vec(added_e, residual_out_div, idx) _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float @@ -458,7 +459,8 @@ def _to_elem_vec(y): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS g = _load_vec(gamma_div, idx).to(fx.Float32) - y = (add_local[tile_i] * rrms) * g + added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) + y = (added * rrms) * g _store_vec(_to_elem_vec(y), out_div, idx) else: @@ -515,7 +517,8 @@ def _to_elem_scalar(y): residual_e = _load_scalar(residual_in_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) - added = x + residual + added_e = _to_elem_scalar(x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -533,9 +536,10 @@ def _to_elem_scalar(y): x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - added = x + residual + added_e = _to_elem_scalar(x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g - _store_scalar(residual_out_div, idx, _to_elem_scalar(added)) + _store_scalar(residual_out_div, idx, added_e) _store_scalar(out_div, idx, _to_elem_scalar(y)) @flyc.jit diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 10490833..0256cb2b 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -36,6 +36,7 @@ DTYPE_FP32 = torch.float32 DTYPE_FP16 = torch.float16 DTYPE_BF16 = torch.bfloat16 +DTYPE_INT8 = torch.int8 EPS: float = 1e-5 from kernels.rmsnorm_kernel import ( @@ -238,48 +239,6 @@ def _get_rmsnorm_configs(): ] -def _get_rmsnorm_quant_configs(): - shapes_env = os.environ.get("ROCDSL_RMSNORM_QUANT_SHAPES", "").strip() - if not shapes_env: - shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() - if shapes_env: - configs = [] - for part in shapes_env.split(";"): - p = part.strip() - if not p: - continue - m_s, n_s, dt = [x.strip() for x in p.split(",")] - configs.append((int(m_s), int(n_s), dt)) - return configs - - return [ - (128, 4096, "f16"), # Aligned - (173, 409, "f16"), # Unaligned (tail handling) - (256, 4096, "bf16"), # BF16 - ] - - -def _get_rmsnorm_fused_add_configs(): - shapes_env = os.environ.get("ROCDSL_RMSNORM_FUSED_ADD_SHAPES", "").strip() - if not shapes_env: - shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() - if shapes_env: - configs = [] - for part in shapes_env.split(";"): - p = part.strip() - if not p: - continue - m_s, n_s, dt = [x.strip() for x in p.split(",")] - configs.append((int(m_s), int(n_s), dt)) - return configs - - return [ - (128, 4096, "f16"), # Aligned - (173, 409, "f16"), # Unaligned (tail handling) - (1024, 8192, "bf16"), # BF16 - ] - - def _reference_rmsnorm_quant(input_dev, gamma_dev, *, xscale_dev=None): x = input_dev.to(DTYPE_FP32) gamma = gamma_dev.to(DTYPE_FP32) @@ -345,7 +304,6 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): f"{type(e).__name__}: {e}" ) return False, None - torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) @@ -362,30 +320,33 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): else: raise ValueError(f"unsupported dtype: {dtype}") - output_dev = torch.empty((M, N), device="cuda", dtype=torch.int8) - yscale_dev = torch.empty((M,), device="cuda", dtype=torch.float32) - + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) xscale_dev = None if is_smooth: xscale_dev = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + 0.5 + dequant_tol = 0.25 if is_smooth else 0.2 + scale_tol = 1e-2 if is_smooth else 5e-3 + + # PyTorch Reference: + # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma + # Quant path additionally computes per-row yscale and int8 output from the fp32 reference. + expected, q_ref, yscale_ref = _reference_rmsnorm_quant( + input_dev, + gamma_dev, + xscale_dev=xscale_dev, + ) print("Launching kernel...") stream = torch.cuda.current_stream() def kernel_launch(): if is_smooth: - launch_fn( - input_dev, - gamma_dev, - xscale_dev, - output_dev, - yscale_dev, - M, - stream=stream, - ) + launch_fn(input_dev, gamma_dev, xscale_dev, output_dev, yscale_dev, M, stream=stream) else: launch_fn(input_dev, gamma_dev, output_dev, yscale_dev, M, stream=stream) + # run_perftest returns (data, avg_us) _, avg_us = run_perftest( lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, @@ -397,47 +358,51 @@ def kernel_launch(): flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) avg_ms = avg_us / 1000.0 + # Bandwidth estimate: read input + read gamma + write output elem_bytes = 4 if dtype == "f32" else 2 total_bytes = M * N * elem_bytes + N * elem_bytes + M * N + M * 4 if is_smooth: total_bytes += N * elem_bytes bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 - print( - f"Kernel avg time: {avg_ms:.4f} ms via run_perftest " - f"(warmup={WARMUP_ITERS}, iters={BENCH_ITERS})" - ) + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us") - expected, q_ref, yscale_ref = _reference_rmsnorm_quant( - input_dev, - gamma_dev, - xscale_dev=xscale_dev, - ) q_out = output_dev.to(torch.int16) q_expected = q_ref.to(torch.int16) yscale_out = yscale_dev.cpu() yscale_expected = yscale_ref.cpu() + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) - q_diff = (q_out - q_expected).abs().max().item() + error = (output_ref - expected).abs().max().item() scale_diff = (yscale_out - yscale_expected).abs().max().item() - recon = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) - recon_err = (recon - expected).abs().max().item() + quant_diff = (q_out - q_expected).abs().max().item() - scale_tol = 1e-2 if is_smooth else 5e-3 - recon_tol = 0.25 if is_smooth else 0.2 - - print(f"Max quant diff: {q_diff}") + print(f"Max dequant error: {error:.2e} (tol={dequant_tol})") print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") - print(f"Max recon error: {recon_err:.2e} (tol={recon_tol})") + print(f"Max quant diff: {quant_diff}") - ok = q_diff <= 1 and scale_diff < scale_tol and recon_err < recon_tol + ok = error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1 if ok: print("PASSED") + ok = True else: print("FAILED") + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_expected[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_expected[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + ok = False return ok, flydsl_gpu_us @@ -450,7 +415,7 @@ def test_rmsnorm_dynamicquant(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_quant_configs(): + for M, N, dtype in _get_rmsnorm_configs(): ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=False) if not ok: failures += 1 @@ -491,7 +456,7 @@ def test_rmsnorm_smoothquant(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_quant_configs(): + for M, N, dtype in _get_rmsnorm_configs(): ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=True) if not ok: failures += 1 @@ -524,10 +489,11 @@ def test_rmsnorm_smoothquant(): def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev): - added = input_dev.to(DTYPE_FP32) + residual_in_dev.to(DTYPE_FP32) + added = input_dev + residual_in_dev + added_fp32 = added.to(DTYPE_FP32) gamma = gamma_dev.to(DTYPE_FP32) - expected = (added / torch.sqrt((added * added).mean(dim=1, keepdim=True) + EPS)) * gamma - return added, expected + expected = (added_fp32 / torch.sqrt((added_fp32 * added_fp32).mean(dim=1, keepdim=True) + EPS)) * gamma + return added_fp32, expected def _bench_aiter_fused_add_rmsnorm(M: int, N: int, dtype: str): @@ -595,7 +561,7 @@ def run_fused_add_test(M: int, N: int, dtype: str): output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) output_atol = 2e-2 - residual_atol = 4e-2 + residual_atol = 2e-2 else: raise ValueError(f"unsupported dtype: {dtype}") @@ -667,7 +633,7 @@ def test_rmsnorm_fused_add(): perf_rows = [] failures = 0 - for M, N, dtype in _get_rmsnorm_fused_add_configs(): + for M, N, dtype in _get_rmsnorm_configs(): ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) if not ok: failures += 1 From e5ce156f2b7f5a667689fbde5ff080670b688617 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Fri, 8 May 2026 15:39:15 +0800 Subject: [PATCH 05/11] align smoothquant xscale to fp32 path in kernel and tests --- kernels/rmsnorm_kernel.py | 41 ++++++++++++++++++++++++++++++----- tests/kernels/test_rmsnorm.py | 6 ++--- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index bb019a2a..f2a5aa0b 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -18,6 +18,7 @@ from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath +from flydsl.expr.typing import Vector as Vec from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -729,6 +730,7 @@ def block_reduce_max(val): num_tiles = N // tile_cols quant_half_width = VEC_WIDTH // 2 abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) + xscale_vec_width = 4 Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -743,13 +745,21 @@ def block_reduce_max(val): out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) if const_expr(is_smooth): - xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(VEC_WIDTH, 1)) + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) vec_reg_ty = fx.MemRefType.get( elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register ) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) + xscale_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, + fx.LayoutType.get(xscale_vec_width, 1), + fx.AddressSpace.Register, + ) + xscale_reg_lay = fx.make_layout(xscale_vec_width, 1) copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) vec_reg_ty_q = fx.MemRefType.get( quant_elem_type, fx.LayoutType.get(quant_half_width, 1), fx.AddressSpace.Register @@ -766,6 +776,12 @@ def _store_q_vec(val, div_tensor, idx): fx.memref_store_vec(val, r) fx.copy_atom_call(copy_atom_q, r, fx.slice(div_tensor, (None, idx))) + if const_expr(is_smooth): + def _load_xscale_vec(div_tensor, idx): + r = fx.memref_alloca(xscale_reg_ty, xscale_reg_lay) + fx.copy_atom_call(copy_atom_xs, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + thread_sumsq = c_zero_f thread_dummy = c_zero_f in_local = [] @@ -794,7 +810,9 @@ def _store_q_vec(val, div_tensor, idx): x = in_local[tile_i].to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s = _load_vec(xscale_div, idx).to(fx.Float32) + s_lo = _load_xscale_vec(xscale_div, idx * 2) + s_hi = _load_xscale_vec(xscale_div, idx * 2 + 1) + s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() y = y * s y_local.append(y) @@ -837,6 +855,12 @@ def _store_q_vec(val, div_tensor, idx): copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay = fx.make_layout(1, 1) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + xscale_scalar_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + xscale_scalar_reg_lay = fx.make_layout(1, 1) scalar_reg_ty_q = fx.MemRefType.get( quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register ) @@ -863,6 +887,13 @@ def _store_quant_scalar(divided_tensor, index, val): view = fx.slice(divided_tensor, (None, index)) fx.copy_atom_call(copy_atom_qs, r, view) + if const_expr(is_smooth): + def _load_xscale_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(xscale_scalar_reg_ty, xscale_scalar_reg_lay) + fx.copy_atom_call(copy_atom_xs, view, r) + return fx.memref_load_vec(r)[0] + def _abs_scalar(val): is_neg = val < c_zero_f neg_val = c_zero_f - val @@ -895,8 +926,7 @@ def _abs_scalar(val): g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx_safe) - s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + s = _load_xscale_scalar(xscale_div, idx_safe) y = y * s y_abs = _abs_scalar(y) thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) @@ -919,8 +949,7 @@ def _abs_scalar(val): g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx) - s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + s = _load_xscale_scalar(xscale_div, idx) y = y * s q = y * inv_scale q_i8 = q.to(quant_dtype) diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 0256cb2b..1c16d5c7 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -275,7 +275,7 @@ def _bench_aiter_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") if is_smooth: - xscale = (torch.rand((N,), device="cuda", dtype=torch_dtype) + 0.5).contiguous() + xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() def run_aiter(): aiter_rmsnorm_quant(y, x, xscale, yscale, w, EPS) @@ -324,7 +324,7 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) xscale_dev = None if is_smooth: - xscale_dev = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + 0.5 + xscale_dev = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() dequant_tol = 0.25 if is_smooth else 0.2 scale_tol = 1e-2 if is_smooth else 5e-3 @@ -362,7 +362,7 @@ def kernel_launch(): elem_bytes = 4 if dtype == "f32" else 2 total_bytes = M * N * elem_bytes + N * elem_bytes + M * N + M * 4 if is_smooth: - total_bytes += N * elem_bytes + total_bytes += N * 4 bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") From 1d09d318aef7d0636ec8c7b1132eda227570074c Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Fri, 8 May 2026 16:10:16 +0800 Subject: [PATCH 06/11] add fused_add rmsnorm quant kernels and tests --- kernels/rmsnorm_kernel.py | 531 ++++++++++++++++++++++++++++++++++ tests/kernels/test_rmsnorm.py | 300 +++++++++++++++++++ 2 files changed, 831 insertions(+) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index f2a5aa0b..4dbc9e25 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -1032,3 +1032,534 @@ def build_rmsnorm_smoothquant_module( is_smooth=True, quant_dtype_str=quant_dtype_str, ) + + +def _build_fused_add_rmsnorm_quant_module( + M: int, + N: int, + dtype_str: str, + *, + is_smooth: bool, + quant_dtype_str: str = "i8", +): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + quant_dtype_max = _quant_dtype_max(quant_dtype_str) + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + RED_SLOTS * f32_bytes + red2_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def fused_add_rmsnorm_quant_kernel( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + XScale: fx.Tensor, + YScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + elem_type = elem_dtype.ir_type + quant_dtype = _quant_dtype_to_elem_type(quant_dtype_str) + quant_elem_type = quant_dtype.ir_type + + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + c_zero_f = fx.Float32(0.0) + c_one_f = fx.Float32(1.0) + c_neg_inf = fx.Float32(float("-inf")) + c_dtype_max = fx.Float32(quant_dtype_max) + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, fx.Float32.ir_type, shape=(RED_SLOTS,)) + s_red.get() + s_red2.get() + + YScale_buf = fx.rocdl.make_buffer_tensor(YScale) + yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) + scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + scale_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scale_reg_lay = fx.make_layout(1, 1) + + def _store_yscale(index, val): + r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) + ts = full(1, fx.Float32(val), fx.Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) + + def wave_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.addf(peer, fastmath=fm_fast) + return w + + def wave_reduce_max(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = WARP_SIZE // (2 << _sh_exp) + peer = w.shuffle_xor(off, WARP_SIZE) + w = w.maximumf(peer) + return w + + def block_reduce_add(val): + dummy = fx.Float32(0.0) + r0, _ = block_reduce_add2(val, dummy) + return r0 + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == 0: + SmemPtr.store(s_red, w0, [wave]) + SmemPtr.store(s_red2, w1, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v0 = SmemPtr.load(s_red, [lane_safe]) + v1 = SmemPtr.load(s_red2, [lane_safe]) + ww0 = in_range.select(v0, c_zero_f) + ww1 = in_range.select(v1, c_zero_f) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == 0: + SmemPtr.store(s_red, ww0, [0]) + SmemPtr.store(s_red2, ww1, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) + + def block_reduce_max(val): + if const_expr(RED_SLOTS == 1): + return wave_reduce_max(val) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w = wave_reduce_max(val) + if lane == 0: + SmemPtr.store(s_red, w, [wave]) + gpu.barrier() + + if wave == 0: + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, 0) + v = SmemPtr.load(s_red, [lane_safe]) + ww = in_range.select(v, c_neg_inf) + ww = wave_reduce_max(ww) + if lane == 0: + SmemPtr.store(s_red, ww, [0]) + gpu.barrier() + + return SmemPtr.load(s_red, [0]) + + if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + num_tiles = N // tile_cols + quant_half_width = VEC_WIDTH // 2 + abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) + xscale_vec_width = 4 + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + residual_in_div = fx.logical_divide( + row_residual_in, fx.make_layout(VEC_WIDTH, 1) + ) + out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) + residual_out_div = fx.logical_divide( + row_residual_out, fx.make_layout(VEC_WIDTH, 1) + ) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide( + XScale_buf, fx.make_layout(xscale_vec_width, 1) + ) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) + xscale_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, + fx.LayoutType.get(xscale_vec_width, 1), + fx.AddressSpace.Register, + ) + xscale_reg_lay = fx.make_layout(xscale_vec_width, 1) + copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) + vec_reg_ty_q = fx.MemRefType.get( + quant_elem_type, fx.LayoutType.get(quant_half_width, 1), fx.AddressSpace.Register + ) + vec_reg_lay_q = fx.make_layout(quant_half_width, 1) + + def _load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + def _store_q_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty_q, vec_reg_lay_q) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom_q, r, fx.slice(div_tensor, (None, idx))) + + if const_expr(is_smooth): + def _load_xscale_vec(div_tensor, idx): + r = fx.memref_alloca(xscale_reg_ty, xscale_reg_lay) + fx.copy_atom_call(copy_atom_xs, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _to_elem_vec(y): + if const_expr(dtype_str == "bf16"): + if const_expr(USE_HW_CVT_PK_BF16_F32): + return y.to(elem_dtype) + u = y.bitcast(fx.Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(fx.Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + return packed.bitcast(elem_dtype) + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + thread_sumsq = c_zero_f + thread_dummy = c_zero_f + add_local = [] + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + x = _load_vec(in_div, idx).to(fx.Float32) + residual = _load_vec(residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(x + residual) + add_local.append(added_e) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added2 = added * added + red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + red2 + _store_vec(added_e, residual_out_div, idx) + + _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + y_local = [] + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + g = _load_vec(gamma_div, idx).to(fx.Float32) + added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s_lo = _load_xscale_vec(xscale_div, idx * 2) + s_hi = _load_xscale_vec(xscale_div, idx * 2 + 1) + s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() + y = y * s + + y_local.append(y) + y_abs = (y.bitcast(fx.Uint32) & abs_mask).bitcast(fx.Float32) + tile_max = y_abs.reduce(ReductionOp.MAX) + thread_row_max = thread_row_max.maximumf(tile_max) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(bid, final_scale) + + inv_scale = c_one_f / final_scale + + for tile_i in range_constexpr(num_tiles): + q = y_local[tile_i] * inv_scale + q_i8 = q.to(quant_dtype) + q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) + q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) + out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 + _store_q_vec(q_lo, out_div_q, out_idx) + _store_q_vec(q_hi, out_div_q, out_idx + 1) + + else: + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scalar_reg_lay = fx.make_layout(1, 1) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + xscale_scalar_reg_ty = fx.MemRefType.get( + fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + xscale_scalar_reg_lay = fx.make_layout(1, 1) + scalar_reg_ty_q = fx.MemRefType.get( + quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scalar_reg_lay_q = fx.make_layout(1, 1) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + def _store_quant_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty_q, scalar_reg_lay_q) + ts = full(1, quant_dtype(val), quant_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_qs, r, view) + + if const_expr(is_smooth): + def _load_xscale_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(xscale_scalar_reg_ty, xscale_scalar_reg_lay) + fx.copy_atom_call(copy_atom_xs, view, r) + return fx.memref_load_vec(r)[0] + + def _to_elem_scalar(y): + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - val + return is_neg.select(neg_val, val) + + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(row_div, idx_safe) + residual_e = _load_scalar(residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _to_elem_scalar(x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added2 = added * added + thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + thread_row_max = c_zero_f + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(row_div, idx_safe) + residual_e = _load_scalar(residual_in_div, idx_safe) + g_e = _load_scalar(gamma_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added_e = _to_elem_scalar(x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s = _load_xscale_scalar(xscale_div, idx_safe) + y = y * s + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(bid, final_scale) + + inv_scale = c_one_f / final_scale + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x_e = _load_scalar(row_div, idx) + residual_e = _load_scalar(residual_in_div, idx) + g_e = _load_scalar(gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + added_e = _to_elem_scalar(x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + y = (added * rrms) * g + if const_expr(is_smooth): + s = _load_xscale_scalar(xscale_div, idx) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) + _store_scalar(residual_out_div, idx, added_e) + _store_quant_scalar(out_div, idx, q_i8) + + if is_smooth: + @flyc.jit + def launch_fused_add_rmsnorm_smoothquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + XScale: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = fused_add_rmsnorm_quant_kernel( + Input, ResidualIn, Gamma, XScale, YScale, Output, ResidualOut + ) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_rmsnorm_smoothquant + + @flyc.jit + def launch_fused_add_rmsnorm_dynamicquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = fused_add_rmsnorm_quant_kernel( + Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut + ) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_add_rmsnorm_dynamicquant + + +def build_fused_add_rmsnorm_dynamicquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_fused_add_rmsnorm_quant_module( + M, + N, + dtype_str, + is_smooth=False, + quant_dtype_str=quant_dtype_str, + ) + + +def build_fused_add_rmsnorm_smoothquant_module( + M: int, + N: int, + dtype_str: str, + quant_dtype_str: str = "i8", +): + return _build_fused_add_rmsnorm_quant_module( + M, + N, + dtype_str, + is_smooth=True, + quant_dtype_str=quant_dtype_str, + ) diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 1c16d5c7..25cee901 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -44,6 +44,8 @@ build_fused_add_rmsnorm_module, build_rmsnorm_dynamicquant_module, build_rmsnorm_smoothquant_module, + build_fused_add_rmsnorm_dynamicquant_module, + build_fused_add_rmsnorm_smoothquant_module, KERNEL_NAME as RMSNORM_KERNEL_NAME, BLOCK_THREADS, ) @@ -665,5 +667,303 @@ def test_rmsnorm_fused_add(): raise SystemExit(1) +def _reference_fused_add_rmsnorm_quant( + input_dev, + residual_in_dev, + gamma_dev, + *, + xscale_dev=None, +): + added = input_dev + residual_in_dev + residual_expected = added.to(DTYPE_FP32) + expected, q, yscale = _reference_rmsnorm_quant( + added, + gamma_dev, + xscale_dev=xscale_dev, + ) + return residual_expected, expected, q, yscale + + +def _bench_aiter_fused_add_rmsnorm_quant( + M: int, + N: int, + dtype: str, + *, + is_smooth: bool, +): + mode = "smoothquant" if is_smooth else "dynamicquant" + torch_dtype = _torch_dtype(dtype) + + try: + if is_smooth: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add_smoothquant as aiter_fused_add_rmsnorm_quant, + ) + else: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add_dynamicquant as aiter_fused_add_rmsnorm_quant, + ) + except Exception as e: + print(f"[Perf] AIter fused_add rmsnorm {mode} skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + y = torch.empty((M, N), dtype=torch.int8, device="cuda") + residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") + + if is_smooth: + xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() + + def run_aiter(): + aiter_fused_add_rmsnorm_quant( + y, x, residual_in, residual_out, xscale, yscale, w, EPS + ) + else: + def run_aiter(): + aiter_fused_add_rmsnorm_quant( + y, x, residual_in, residual_out, yscale, w, EPS + ) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add rmsnorm {mode} gpu: {aiter_us:.1f} us") + return aiter_us + + +def run_fused_add_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): + mode = "smoothquant" if is_smooth else "dynamicquant" + print(f"\nTesting FusedAdd RMSNorm {mode} (M={M}, N={N}, dtype={dtype})") + + try: + if is_smooth: + launch_fn = build_fused_add_rmsnorm_smoothquant_module(M, N, dtype) + else: + launch_fn = build_fused_add_rmsnorm_dynamicquant_module(M, N, dtype) + except Exception as e: + print( + f"[FAIL] Compile failed for fused_add rmsnorm {mode} " + f"(M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}" + ) + return False, None + + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + + if dtype == "f32": + input_dev = input_t.contiguous() + residual_in_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + residual_atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_in_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + residual_atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_in_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + residual_atol = 2e-2 + else: + raise ValueError(f"unsupported dtype: {dtype}") + + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_INT8) + yscale_dev = torch.empty((M,), device="cuda", dtype=DTYPE_FP32) + xscale_dev = None + if is_smooth: + xscale_dev = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() + dequant_tol = 0.25 if is_smooth else 0.2 + scale_tol = 1e-2 if is_smooth else 5e-3 + + residual_expected, expected, q_ref, yscale_ref = _reference_fused_add_rmsnorm_quant( + input_dev, + residual_in_dev, + gamma_dev, + xscale_dev=xscale_dev, + ) + + print("Launching kernel...") + stream = torch.cuda.current_stream() + + def kernel_launch(): + if is_smooth: + launch_fn( + input_dev, + residual_in_dev, + gamma_dev, + xscale_dev, + output_dev, + residual_out_dev, + yscale_dev, + M, + stream=stream, + ) + else: + launch_fn( + input_dev, + residual_in_dev, + gamma_dev, + output_dev, + residual_out_dev, + yscale_dev, + M, + stream=stream, + ) + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), + num_iters=BENCH_ITERS, + num_warmup=WARMUP_ITERS, + ) + torch.cuda.synchronize() + flydsl_gpu_us = None + if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": + flydsl_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = avg_us / 1000.0 + + elem_bytes = 4 if dtype == "f32" else 2 + total_bytes = 3 * M * N * elem_bytes + N * elem_bytes + M * N + M * 4 + if is_smooth: + total_bytes += N * 4 + bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 + + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + if flydsl_gpu_us is not None: + print(f"[Perf] FlyDSL fused_add rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us") + + residual_out_ref = residual_out_dev.to(DTYPE_FP32) + output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) + q_out = output_dev.to(torch.int16) + q_expected = q_ref.to(torch.int16) + yscale_out = yscale_dev.cpu() + yscale_expected = yscale_ref.cpu() + + residual_error = (residual_out_ref - residual_expected).abs().max().item() + dequant_error = (output_ref - expected).abs().max().item() + scale_diff = (yscale_out - yscale_expected).abs().max().item() + quant_diff = (q_out - q_expected).abs().max().item() + + print(f"Max residual error: {residual_error:.2e} (tol={residual_atol})") + print(f"Max dequant error: {dequant_error:.2e} (tol={dequant_tol})") + print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") + print(f"Max quant diff: {quant_diff}") + + ok = ( + residual_error < residual_atol + and dequant_error < dequant_tol + and scale_diff < scale_tol + and quant_diff <= 1 + ) + if ok: + print("PASSED") + else: + print("FAILED") + print("First row Residual Expected:") + print(residual_expected[0, :5]) + print("First row Residual Actual:") + print(residual_out_ref[0, :5]) + print("First row Expected:") + print(expected[0, :5]) + print("First row Actual:") + print(output_ref[0, :5]) + print("First row Quant Expected:") + print(q_expected[0, :8]) + print("First row Quant Actual:") + print(q_out[0, :8]) + print("First few YScale Expected:") + print(yscale_expected[:5]) + print("First few YScale Actual:") + print(yscale_out[:5]) + return ok, flydsl_gpu_us + + +def test_rmsnorm_fused_add_dynamicquant(): + print("="*80) + print("Running FusedAdd RMSNorm DynamicQuant Tests") + print("="*80) + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in _get_rmsnorm_configs(): + ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=False) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=False) + perf_rows.append( + PerfRow( + op="rmsnorm_add_dq", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "="*80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("="*80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def test_rmsnorm_fused_add_smoothquant(): + print("="*80) + print("Running FusedAdd RMSNorm SmoothQuant Tests") + print("="*80) + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + failures = 0 + + for M, N, dtype in _get_rmsnorm_configs(): + ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=True) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=True) + perf_rows.append( + PerfRow( + op="rmsnorm_add_sq", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "="*80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("="*80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + if __name__ == "__main__": test_all() From 5e676acaa8de58702c0f369bedfdfe5fdaf46baa Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Tue, 12 May 2026 16:30:28 +0800 Subject: [PATCH 07/11] add large_m_small_n rmsnorm path and refine generic fused_add flow --- kernels/rmsnorm_kernel.py | 149 +++++++++++++++++++++++++++++++++----- 1 file changed, 131 insertions(+), 18 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 4dbc9e25..e07dd77f 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -5,7 +5,8 @@ RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma -Two paths: +Three paths: + - Large-M/small-N specialized path. - Fast path (N % tile_cols == 0): buffer_load/store vectorised access. - Generic path (arbitrary N): scalar copy_atom_call. """ @@ -33,7 +34,129 @@ VEC_WIDTH = 8 +def _should_use_large_m_small_n(M: int, N: int) -> bool: + return M > 8192 and N <= 2048 + + +def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str): + BLOCK_N = 1 << (N - 1).bit_length() + BLOCK_M = max(min(16384 // BLOCK_N, 32), 8) + THREADS_PER_ROW = min(WARP_SIZE, 1024 // BLOCK_M) + BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW + elem_bits = 32 if dtype_str == "f32" else 16 + + @flyc.kernel + def rmsnorm_large_m_small_n_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + _Unused: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + lane = tid % THREADS_PER_ROW + row_local = tid // THREADS_PER_ROW + row = bid * fx.Int32(BLOCK_M) + row_local + + if row < M: + elem_dtype = dtype_to_elem_type(dtype_str) + elem_type = elem_dtype.ir_type + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + + row_in = fx.slice(Input_buf, (row, None)) + row_out = fx.slice(Output_buf, (row, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scalar_reg_lay = fx.make_layout(1, 1) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0] + + def _store_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + def group_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(THREADS_PER_ROW))): + off = THREADS_PER_ROW // (2 << _sh_exp) + peer = w.shuffle_xor(off, fx.Int32(THREADS_PER_ROW)) + w = w.addf(peer, fastmath=fm_fast) + return w + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(row_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_sq = group_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + if idx < N: + x_e = _load_scalar(row_div, idx) + g_e = _load_scalar(gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g + y_e = y if dtype_str == "f32" else y.to(elem_dtype) + _store_scalar(out_div, idx, y_e) + + @flyc.jit + def launch_rmsnorm_large_m_small_n( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = rmsnorm_large_m_small_n_kernel(Input, Gamma, Gamma, Output) + launcher.launch( + grid=((M + BLOCK_M - 1) // BLOCK_M, 1, 1), + block=(BLOCK_THREADS_SPECIAL, 1, 1), + stream=stream, + ) + + return launch_rmsnorm_large_m_small_n + + def build_rmsnorm_module(M: int, N: int, dtype_str: str): + if _should_use_large_m_small_n(M, N): + return _build_rmsnorm_large_m_small_n_module(M, N, dtype_str) + arch = get_hip_arch() USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") @@ -519,6 +642,8 @@ def _to_elem_scalar(y): x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) added_e = _to_elem_scalar(x + residual) + if idx < N: + _store_scalar(residual_out_div, idx, added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -531,16 +656,11 @@ def _to_elem_scalar(y): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - residual_e = _load_scalar(residual_in_div, idx) g_e = _load_scalar(gamma_div, idx) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _load_scalar(residual_out_div, idx) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - added_e = _to_elem_scalar(x + residual) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g - _store_scalar(residual_out_div, idx, added_e) _store_scalar(out_div, idx, _to_elem_scalar(y)) @flyc.jit @@ -1420,6 +1540,8 @@ def _abs_scalar(val): x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) added_e = _to_elem_scalar(x + residual) + if idx < N: + _store_scalar(residual_out_div, idx, added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -1434,13 +1556,9 @@ def _abs_scalar(val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) - residual_e = _load_scalar(residual_in_div, idx_safe) g_e = _load_scalar(gamma_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _load_scalar(residual_out_div, idx_safe) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - added_e = _to_elem_scalar(x + residual) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g if const_expr(is_smooth): @@ -1461,13 +1579,9 @@ def _abs_scalar(val): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - residual_e = _load_scalar(residual_in_div, idx) g_e = _load_scalar(gamma_div, idx) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) + added_e = _load_scalar(residual_out_div, idx) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - added_e = _to_elem_scalar(x + residual) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g if const_expr(is_smooth): @@ -1475,7 +1589,6 @@ def _abs_scalar(val): y = y * s q = y * inv_scale q_i8 = q.to(quant_dtype) - _store_scalar(residual_out_div, idx, added_e) _store_quant_scalar(out_div, idx, q_i8) if is_smooth: From 4e409cdb6fd48d9758ef7d655303ffe4252dcf06 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Tue, 12 May 2026 22:21:12 +0800 Subject: [PATCH 08/11] add some annotations for the variants --- kernels/rmsnorm_kernel.py | 124 +++++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 48 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index e07dd77f..1c6f0ced 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -501,9 +501,12 @@ def block_reduce_add2(val0, val1): return SmemPtr.load(s_red, [0]), SmemPtr.load(s_red2, [0]) + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): num_tiles = N // tile_cols - + # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -561,6 +564,7 @@ def _to_elem_vec(y): thread_dummy = c_zero_f add_local = [] + # Pass 1: add + cache + sumsq (also write residual_out) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS x = _load_vec(in_div, idx).to(fx.Float32) @@ -580,6 +584,7 @@ def _to_elem_vec(y): ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) + # Pass 2: normalize + gamma + store (reuse cached added values) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS g = _load_vec(gamma_div, idx).to(fx.Float32) @@ -588,6 +593,9 @@ def _to_elem_vec(y): _store_vec(_to_elem_vec(y), out_div, idx) else: + # ============================================================== + # Generic path: scalar 2-pass for arbitrary N + # ============================================================== Input_buf = fx.rocdl.make_buffer_tensor(Input) ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -851,7 +859,7 @@ def block_reduce_max(val): quant_half_width = VEC_WIDTH // 2 abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) xscale_vec_width = 4 - + # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) Output_buf = fx.rocdl.make_buffer_tensor(Output) @@ -906,6 +914,7 @@ def _load_xscale_vec(div_tensor, idx): thread_dummy = c_zero_f in_local = [] + # Pass 1: load + cache + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS vec = _load_vec(in_div, idx) @@ -923,6 +932,7 @@ def _load_xscale_vec(div_tensor, idx): thread_row_max = c_zero_f y_local = [] + # Pass 2: normalize + gamma (+ optional smooth scale), cache output, and accumulate row max for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS @@ -949,6 +959,7 @@ def _load_xscale_vec(div_tensor, idx): inv_scale = c_one_f / final_scale + # Pass 3: quantize + store using per-row scale for tile_i in range_constexpr(num_tiles): q = y_local[tile_i] * inv_scale q_i8 = q.to(quant_dtype) @@ -960,7 +971,7 @@ def _load_xscale_vec(div_tensor, idx): else: # ============================================================== - # Generic path: scalar 2-pass for arbitrary N + # Generic path: scalar 3-pass for arbitrary N # ============================================================== Input_buf = fx.rocdl.make_buffer_tensor(Input) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -1021,6 +1032,7 @@ def _abs_scalar(val): thread_sumsq = c_zero_f + # Pass 1: accumulate sumsq for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int is_valid = idx < N @@ -1036,6 +1048,7 @@ def _abs_scalar(val): rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f + # Pass 2: normalize, apply gamma (+ optional smooth scale), and accumulate row max for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int is_valid = idx < N @@ -1060,6 +1073,7 @@ def _abs_scalar(val): inv_scale = c_one_f / final_scale + # Pass 3: quantize + store using per-row scale for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: @@ -1100,28 +1114,29 @@ def launch_rmsnorm_smoothquant( return launch_rmsnorm_smoothquant - @flyc.jit - def launch_rmsnorm_dynamicquant( - Input: fx.Tensor, - Gamma: fx.Tensor, - Output: fx.Tensor, - YScale: fx.Tensor, - m_in: fx.Int32, - stream: fx.Stream = fx.Stream(None), - ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() + else: + @flyc.jit + def launch_rmsnorm_dynamicquant( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() - launcher = rmsnorm_quant_kernel(Input, Gamma, Gamma, YScale, Output) - launcher.launch( - grid=(m_in, 1, 1), - block=(BLOCK_THREADS, 1, 1), - stream=stream, - ) + launcher = rmsnorm_quant_kernel(Input, Gamma, Gamma, YScale, Output) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) - return launch_rmsnorm_dynamicquant + return launch_rmsnorm_dynamicquant def build_rmsnorm_dynamicquant_module( @@ -1300,12 +1315,15 @@ def block_reduce_max(val): return SmemPtr.load(s_red, [0]) + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): num_tiles = N // tile_cols quant_half_width = VEC_WIDTH // 2 abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) xscale_vec_width = 4 - + # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -1396,6 +1414,7 @@ def _to_elem_vec(y): thread_dummy = c_zero_f add_local = [] + # Pass 1: add + cache + sumsq (also write residual_out) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS x = _load_vec(in_div, idx).to(fx.Float32) @@ -1416,6 +1435,7 @@ def _to_elem_vec(y): thread_row_max = c_zero_f y_local = [] + # Pass 2: normalize + gamma (+ optional smooth scale), cache output, and accumulate row max for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS g = _load_vec(gamma_div, idx).to(fx.Float32) @@ -1441,6 +1461,7 @@ def _to_elem_vec(y): inv_scale = c_one_f / final_scale + # Pass 3: quantize + store using per-row scale for tile_i in range_constexpr(num_tiles): q = y_local[tile_i] * inv_scale q_i8 = q.to(quant_dtype) @@ -1451,6 +1472,9 @@ def _to_elem_vec(y): _store_q_vec(q_hi, out_div_q, out_idx + 1) else: + # ============================================================== + # Generic path: scalar 3-pass for arbitrary N + # ============================================================== Input_buf = fx.rocdl.make_buffer_tensor(Input) ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) @@ -1531,6 +1555,7 @@ def _abs_scalar(val): thread_sumsq = c_zero_f + # Pass 1: add, write residual_out, and accumulate sumsq for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int is_valid = idx < N @@ -1552,6 +1577,7 @@ def _abs_scalar(val): rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f + # Pass 2: normalize, apply gamma (+ optional smooth scale), and accumulate row max for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int is_valid = idx < N @@ -1576,6 +1602,7 @@ def _abs_scalar(val): inv_scale = c_one_f / final_scale + # Pass 3: quantize + store using per-row scale for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: @@ -1620,32 +1647,33 @@ def launch_fused_add_rmsnorm_smoothquant( return launch_fused_add_rmsnorm_smoothquant - @flyc.jit - def launch_fused_add_rmsnorm_dynamicquant( - Input: fx.Tensor, - ResidualIn: fx.Tensor, - Gamma: fx.Tensor, - Output: fx.Tensor, - ResidualOut: fx.Tensor, - YScale: fx.Tensor, - m_in: fx.Int32, - stream: fx.Stream = fx.Stream(None), - ): - allocator.finalized = False - ctx = CompilationContext.get_current() - with InsertionPoint(ctx.gpu_module_body): - allocator.finalize() + else: + @flyc.jit + def launch_fused_add_rmsnorm_dynamicquant( + Input: fx.Tensor, + ResidualIn: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + ResidualOut: fx.Tensor, + YScale: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with InsertionPoint(ctx.gpu_module_body): + allocator.finalize() - launcher = fused_add_rmsnorm_quant_kernel( - Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut - ) - launcher.launch( - grid=(m_in, 1, 1), - block=(BLOCK_THREADS, 1, 1), - stream=stream, - ) + launcher = fused_add_rmsnorm_quant_kernel( + Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut + ) + launcher.launch( + grid=(m_in, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) - return launch_fused_add_rmsnorm_dynamicquant + return launch_fused_add_rmsnorm_dynamicquant def build_fused_add_rmsnorm_dynamicquant_module( From 057b70e5c862171b61d94432dac289264ec99ae3 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 13 May 2026 00:07:17 +0800 Subject: [PATCH 09/11] refactor: add common functions in rmsnorm varients to reduce duplicated codes --- kernels/rmsnorm_kernel.py | 637 ++++++++++++++------------------------ 1 file changed, 230 insertions(+), 407 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 1c6f0ced..ebdb55b9 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -5,8 +5,7 @@ RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma -Three paths: - - Large-M/small-N specialized path. +Two paths: - Fast path (N % tile_cols == 0): buffer_load/store vectorised access. - Generic path (arbitrary N): scalar copy_atom_call. """ @@ -34,127 +33,78 @@ VEC_WIDTH = 8 -def _should_use_large_m_small_n(M: int, N: int) -> bool: - return M > 8192 and N <= 2048 +def _make_reduction_allocator(arch: str, red_slots: int): + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + red_slots * f32_bytes + red2_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red2_offset + red_slots * f32_bytes + return allocator, red_offset, red2_offset -def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str): - BLOCK_N = 1 << (N - 1).bit_length() - BLOCK_M = max(min(16384 // BLOCK_N, 32), 8) - THREADS_PER_ROW = min(WARP_SIZE, 1024 // BLOCK_M) - BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW - elem_bits = 32 if dtype_str == "f32" else 16 +def _load_scalar(copy_atom, scalar_reg_ty, scalar_reg_lay, divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom, view, r) + return fx.memref_load_vec(r)[0] - @flyc.kernel - def rmsnorm_large_m_small_n_kernel( - Input: fx.Tensor, - Gamma: fx.Tensor, - _Unused: fx.Tensor, - Output: fx.Tensor, - ): - bid = fx.block_idx.x - tid = fx.thread_idx.x - lane = tid % THREADS_PER_ROW - row_local = tid // THREADS_PER_ROW - row = bid * fx.Int32(BLOCK_M) + row_local +def _store_scalar(copy_atom, scalar_reg_ty, scalar_reg_lay, store_dtype, divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, store_dtype(val), store_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom, r, view) - if row < M: - elem_dtype = dtype_to_elem_type(dtype_str) - elem_type = elem_dtype.ir_type - fm_fast = arith.FastMathFlags.fast - eps_c = EPS - n_float = float(N) - Input_buf = fx.rocdl.make_buffer_tensor(Input) - Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) - Output_buf = fx.rocdl.make_buffer_tensor(Output) +def _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) - row_in = fx.slice(Input_buf, (row, None)) - row_out = fx.slice(Output_buf, (row, None)) - copy_atom_s = fx.make_copy_atom( - fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), - elem_bits, - ) - scalar_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) - scalar_reg_lay = fx.make_layout(1, 1) +def _store_vec(copy_atom, vec_reg_ty, vec_reg_lay, val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) - gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) - out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] +def _to_elem_scalar(dtype_str: str, elem_dtype, y): + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) - def _store_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - def group_reduce_add(x): - w = x - for _sh_exp in range_constexpr(int(math.log2(THREADS_PER_ROW))): - off = THREADS_PER_ROW // (2 << _sh_exp) - peer = w.shuffle_xor(off, fx.Int32(THREADS_PER_ROW)) - w = w.addf(peer, fastmath=fm_fast) - return w +def _to_elem_vec(dtype_str: str, elem_dtype, use_hw_cvt_bf16: bool, y): + if const_expr(dtype_str == "bf16"): + if const_expr(use_hw_cvt_bf16): + return y.to(elem_dtype) + u = y.bitcast(fx.Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(fx.Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + return packed.bitcast(elem_dtype) + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) - c_zero_f = fx.Float32(0.0) - thread_sumsq = c_zero_f - for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): - idx = lane + base_idx_int - is_valid = idx < N - idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - x2 = x * x - thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) - - sum_sq = group_reduce_add(thread_sumsq) - mean_sq = sum_sq / n_float - ms_eps = mean_sq + eps_c - rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) - - for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): - idx = lane + base_idx_int - if idx < N: - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - y = (x * rrms) * g - y_e = y if dtype_str == "f32" else y.to(elem_dtype) - _store_scalar(out_div, idx, y_e) - - @flyc.jit - def launch_rmsnorm_large_m_small_n( - Input: fx.Tensor, - Gamma: fx.Tensor, - Output: fx.Tensor, - m_in: fx.Int32, - stream: fx.Stream = fx.Stream(None), - ): - launcher = rmsnorm_large_m_small_n_kernel(Input, Gamma, Gamma, Output) - launcher.launch( - grid=((M + BLOCK_M - 1) // BLOCK_M, 1, 1), - block=(BLOCK_THREADS_SPECIAL, 1, 1), - stream=stream, - ) - - return launch_rmsnorm_large_m_small_n +def _store_yscale(scale_copy_atom, scale_reg_ty, scale_reg_lay, yscale_div, index, val): + r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) + ts = full(1, fx.Float32(val), fx.Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) def build_rmsnorm_module(M: int, N: int, dtype_str: str): - if _should_use_large_m_small_n(M, N): + if M > 8192 and N <= 2048: return _build_rmsnorm_large_m_small_n_module(M, N, dtype_str) arch = get_hip_arch() @@ -164,12 +114,7 @@ def build_rmsnorm_module(M: int, N: int, dtype_str: str): RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel def rmsnorm_kernel( @@ -261,16 +206,6 @@ def block_reduce_add2(val0, val1): ) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - def _load_vec(div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f thread_dummy = c_zero_f @@ -279,7 +214,7 @@ def _store_vec(val, div_tensor, idx): # Pass 1: load + cache + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec = _load_vec(in_div, idx) + vec = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, in_div, idx) in_local.append(vec) x = vec.to(fx.Float32) @@ -296,34 +231,14 @@ def _store_vec(val, div_tensor, idx): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(fx.Float32) + g = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, gamma_div, idx).to(fx.Float32) x = in_local[tile_i].to(fx.Float32) y = (x * rrms) * g - - out_e = y.to(elem_dtype) - if const_expr(dtype_str == "bf16"): - if const_expr(USE_HW_CVT_PK_BF16_F32): - out_e = y.to(elem_dtype) - else: - u = y.bitcast(fx.Uint32) - upper = u >> 16 - lsb = upper & 1 - bias = lsb + 0x7FFF - u_round = y.bitcast(fx.Uint32) + bias - bf16_bits = u_round >> 16 - even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) - odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << 16 - packed = even | odd_sh - out_e = packed.bitcast(elem_dtype) - elif const_expr(dtype_str == "f32"): - out_e = y - else: - out_e = y.to(elem_dtype) + out_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) out_idx = tid + tile_i * BLOCK_THREADS - _store_vec(out_e, out_div, out_idx) + _store_vec(copy_atom, vec_reg_ty, vec_reg_lay, out_e, out_div, out_idx) else: # ============================================================== @@ -340,26 +255,15 @@ def _store_vec(val, div_tensor, idx): fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) scalar_reg_lay = fx.make_layout(1, 1) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f @@ -367,7 +271,7 @@ def _store_scalar(divided_tensor, index, val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) x2 = x * x x2_safe = is_valid.select(x2, c_zero_f) @@ -381,19 +285,14 @@ def _store_scalar(divided_tensor, index, val): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) norm = x * rrms y = norm * g - if const_expr(dtype_str == "f32"): - y_e = y - elif const_expr(dtype_str == "bf16"): - y_e = y.to(elem_dtype) - else: - y_e = y.to(elem_dtype) - _store_scalar(out_div, idx, y_e) + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, out_div, idx, y_e) @flyc.jit def launch_rmsnorm( @@ -418,6 +317,108 @@ def launch_rmsnorm( return launch_rmsnorm +def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str): + BLOCK_N = 1 << (N - 1).bit_length() + BLOCK_M = max(min(16384 // BLOCK_N, 32), 8) + THREADS_PER_ROW = min(WARP_SIZE, 1024 // BLOCK_M) + BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW + elem_bits = 32 if dtype_str == "f32" else 16 + + @flyc.kernel + def rmsnorm_large_m_small_n_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + _Unused: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + lane = tid % THREADS_PER_ROW + row_local = tid // THREADS_PER_ROW + row = bid * fx.Int32(BLOCK_M) + row_local + + if row < M: + elem_dtype = dtype_to_elem_type(dtype_str) + elem_type = elem_dtype.ir_type + fm_fast = arith.FastMathFlags.fast + eps_c = EPS + n_float = float(N) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + + row_in = fx.slice(Input_buf, (row, None)) + row_out = fx.slice(Output_buf, (row, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scalar_reg_lay = fx.make_layout(1, 1) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + + def group_reduce_add(x): + w = x + for _sh_exp in range_constexpr(int(math.log2(THREADS_PER_ROW))): + off = THREADS_PER_ROW // (2 << _sh_exp) + peer = w.shuffle_xor(off, fx.Int32(THREADS_PER_ROW)) + w = w.addf(peer, fastmath=fm_fast) + return w + + c_zero_f = fx.Float32(0.0) + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + x2 = x * x + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_sq = group_reduce_add(thread_sumsq) + mean_sq = sum_sq / n_float + ms_eps = mean_sq + eps_c + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) + + for base_idx_int in range_constexpr(0, BLOCK_N, THREADS_PER_ROW): + idx = lane + base_idx_int + if idx < N: + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + y = (x * rrms) * g + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, out_div, idx, y_e) + + @flyc.jit + def launch_rmsnorm_large_m_small_n( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + launcher = rmsnorm_large_m_small_n_kernel(Input, Gamma, Gamma, Output) + launcher.launch( + grid=((M + BLOCK_M - 1) // BLOCK_M, 1, 1), + block=(BLOCK_THREADS_SPECIAL, 1, 1), + stream=stream, + ) + + return launch_rmsnorm_large_m_small_n + + def build_fused_add_rmsnorm_module(M: int, N: int, dtype_str: str): arch = get_hip_arch() USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") @@ -426,12 +427,7 @@ def build_fused_add_rmsnorm_module(M: int, N: int, dtype_str: str): RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel def fused_add_rmsnorm_kernel( @@ -530,35 +526,6 @@ def block_reduce_add2(val0, val1): ) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - def _load_vec(div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - - def _to_elem_vec(y): - if const_expr(dtype_str == "bf16"): - if const_expr(USE_HW_CVT_PK_BF16_F32): - return y.to(elem_dtype) - u = y.bitcast(fx.Uint32) - upper = u >> 16 - lsb = upper & 1 - bias = lsb + 0x7FFF - u_round = y.bitcast(fx.Uint32) + bias - bf16_bits = u_round >> 16 - even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) - odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << 16 - packed = even | odd_sh - return packed.bitcast(elem_dtype) - if const_expr(dtype_str == "f32"): - return y - return y.to(elem_dtype) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f thread_dummy = c_zero_f @@ -567,9 +534,9 @@ def _to_elem_vec(y): # Pass 1: add + cache + sumsq (also write residual_out) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - x = _load_vec(in_div, idx).to(fx.Float32) - residual = _load_vec(residual_in_div, idx).to(fx.Float32) - added_e = _to_elem_vec(x + residual) + x = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) add_local.append(added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) @@ -577,7 +544,7 @@ def _to_elem_vec(y): red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sumsq = thread_sumsq + red2 - _store_vec(added_e, residual_out_div, idx) + _store_vec(copy_atom, vec_reg_ty, vec_reg_lay, added_e, residual_out_div, idx) _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float @@ -587,10 +554,11 @@ def _to_elem_vec(y): # Pass 2: normalize + gamma + store (reuse cached added values) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(fx.Float32) + g = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, gamma_div, idx).to(fx.Float32) added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) y = (added * rrms) * g - _store_vec(_to_elem_vec(y), out_div, idx) + y_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) + _store_vec(copy_atom, vec_reg_ty, vec_reg_lay, y_e, out_div, idx) else: # ============================================================== @@ -611,7 +579,9 @@ def _to_elem_vec(y): fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) scalar_reg_lay = fx.make_layout(1, 1) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) @@ -620,24 +590,6 @@ def _to_elem_vec(y): out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - - def _to_elem_scalar(y): - if const_expr(dtype_str == "f32"): - return y - return y.to(elem_dtype) - c_zero_f = fx.Float32(0.0) thread_sumsq = c_zero_f @@ -645,13 +597,13 @@ def _to_elem_scalar(y): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) - residual_e = _load_scalar(residual_in_div, idx_safe) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) + residual_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, residual_in_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) - added_e = _to_elem_scalar(x + residual) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) if idx < N: - _store_scalar(residual_out_div, idx, added_e) + _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -664,12 +616,13 @@ def _to_elem_scalar(y): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - g_e = _load_scalar(gamma_div, idx) - added_e = _load_scalar(residual_out_div, idx) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx) + added_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, residual_out_div, idx) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g - _store_scalar(out_div, idx, _to_elem_scalar(y)) + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, out_div, idx, y_e) @flyc.jit def launch_fused_add_rmsnorm( @@ -723,12 +676,7 @@ def _build_rmsnorm_quant_module( elem_bits = 32 if dtype_str == "f32" else 16 quant_dtype_max = _quant_dtype_max(quant_dtype_str) - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel def rmsnorm_quant_kernel( @@ -768,12 +716,6 @@ def rmsnorm_quant_kernel( ) scale_reg_lay = fx.make_layout(1, 1) - def _store_yscale(index, val): - r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) - ts = full(1, fx.Float32(val), fx.Float32) - fx.memref_store_vec(ts, r) - fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) - def wave_reduce_add(x): w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): @@ -894,22 +836,6 @@ def block_reduce_max(val): ) vec_reg_lay_q = fx.make_layout(quant_half_width, 1) - def _load_vec(div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_q_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty_q, vec_reg_lay_q) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom_q, r, fx.slice(div_tensor, (None, idx))) - - if const_expr(is_smooth): - def _load_xscale_vec(div_tensor, idx): - r = fx.memref_alloca(xscale_reg_ty, xscale_reg_lay) - fx.copy_atom_call(copy_atom_xs, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - thread_sumsq = c_zero_f thread_dummy = c_zero_f in_local = [] @@ -917,7 +843,7 @@ def _load_xscale_vec(div_tensor, idx): # Pass 1: load + cache + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec = _load_vec(in_div, idx) + vec = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, in_div, idx) in_local.append(vec) x = vec.to(fx.Float32) x2 = x * x @@ -936,12 +862,12 @@ def _load_xscale_vec(div_tensor, idx): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(fx.Float32) + g = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, gamma_div, idx).to(fx.Float32) x = in_local[tile_i].to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s_lo = _load_xscale_vec(xscale_div, idx * 2) - s_hi = _load_xscale_vec(xscale_div, idx * 2 + 1) + s_lo = _load_vec(copy_atom_xs, xscale_reg_ty, xscale_reg_lay, xscale_div, idx * 2) + s_hi = _load_vec(copy_atom_xs, xscale_reg_ty, xscale_reg_lay, xscale_div, idx * 2 + 1) s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() y = y * s @@ -955,7 +881,7 @@ def _load_xscale_vec(div_tensor, idx): final_scale = (scale == c_zero_f).select(c_one_f, scale) if tid == 0: - _store_yscale(bid, final_scale) + _store_yscale(scale_copy_atom, scale_reg_ty, scale_reg_lay, yscale_div, bid, final_scale) inv_scale = c_one_f / final_scale @@ -966,8 +892,8 @@ def _load_xscale_vec(div_tensor, idx): q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 - _store_q_vec(q_lo, out_div_q, out_idx) - _store_q_vec(q_hi, out_div_q, out_idx + 1) + _store_vec(copy_atom_q, vec_reg_ty_q, vec_reg_lay_q, q_lo, out_div_q, out_idx) + _store_vec(copy_atom_q, vec_reg_ty_q, vec_reg_lay_q, q_hi, out_div_q, out_idx + 1) else: # ============================================================== @@ -1005,26 +931,6 @@ def _load_xscale_vec(div_tensor, idx): if const_expr(is_smooth): xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_quant_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty_q, scalar_reg_lay_q) - ts = full(1, quant_dtype(val), quant_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_qs, r, view) - - if const_expr(is_smooth): - def _load_xscale_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(xscale_scalar_reg_ty, xscale_scalar_reg_lay) - fx.copy_atom_call(copy_atom_xs, view, r) - return fx.memref_load_vec(r)[0] - def _abs_scalar(val): is_neg = val < c_zero_f neg_val = c_zero_f - val @@ -1037,7 +943,7 @@ def _abs_scalar(val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) x2 = x * x thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) @@ -1053,13 +959,13 @@ def _abs_scalar(val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) - g_e = _load_scalar(gamma_div, idx_safe) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s = _load_xscale_scalar(xscale_div, idx_safe) + s = _load_scalar(copy_atom_xs, xscale_scalar_reg_ty, xscale_scalar_reg_lay, xscale_div, idx_safe) y = y * s y_abs = _abs_scalar(y) thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) @@ -1069,7 +975,7 @@ def _abs_scalar(val): final_scale = (scale == c_zero_f).select(c_one_f, scale) if tid == 0: - _store_yscale(bid, final_scale) + _store_yscale(scale_copy_atom, scale_reg_ty, scale_reg_lay, yscale_div, bid, final_scale) inv_scale = c_one_f / final_scale @@ -1077,17 +983,17 @@ def _abs_scalar(val): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) y = (x * rrms) * g if const_expr(is_smooth): - s = _load_xscale_scalar(xscale_div, idx) + s = _load_scalar(copy_atom_xs, xscale_scalar_reg_ty, xscale_scalar_reg_lay, xscale_div, idx) y = y * s q = y * inv_scale q_i8 = q.to(quant_dtype) - _store_quant_scalar(out_div, idx, q_i8) + _store_scalar(copy_atom_qs, scalar_reg_ty_q, scalar_reg_lay_q, quant_dtype, out_div, idx, q_i8) if is_smooth: @flyc.jit @@ -1185,12 +1091,7 @@ def _build_fused_add_rmsnorm_quant_module( elem_bits = 32 if dtype_str == "f32" else 16 quant_dtype_max = _quant_dtype_max(quant_dtype_str) - allocator = SmemAllocator(None, arch=arch) - f32_bytes = 4 - red_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red_offset + RED_SLOTS * f32_bytes - red2_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + allocator, red_offset, red2_offset = _make_reduction_allocator(arch, RED_SLOTS) @flyc.kernel def fused_add_rmsnorm_quant_kernel( @@ -1232,12 +1133,6 @@ def fused_add_rmsnorm_quant_kernel( ) scale_reg_lay = fx.make_layout(1, 1) - def _store_yscale(index, val): - r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) - ts = full(1, fx.Float32(val), fx.Float32) - fx.memref_store_vec(ts, r) - fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) - def wave_reduce_add(x): w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): @@ -1370,46 +1265,6 @@ def block_reduce_max(val): ) vec_reg_lay_q = fx.make_layout(quant_half_width, 1) - def _load_vec(div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - - def _store_q_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty_q, vec_reg_lay_q) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom_q, r, fx.slice(div_tensor, (None, idx))) - - if const_expr(is_smooth): - def _load_xscale_vec(div_tensor, idx): - r = fx.memref_alloca(xscale_reg_ty, xscale_reg_lay) - fx.copy_atom_call(copy_atom_xs, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _to_elem_vec(y): - if const_expr(dtype_str == "bf16"): - if const_expr(USE_HW_CVT_PK_BF16_F32): - return y.to(elem_dtype) - u = y.bitcast(fx.Uint32) - upper = u >> 16 - lsb = upper & 1 - bias = lsb + 0x7FFF - u_round = y.bitcast(fx.Uint32) + bias - bf16_bits = u_round >> 16 - even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) - odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << 16 - packed = even | odd_sh - return packed.bitcast(elem_dtype) - if const_expr(dtype_str == "f32"): - return y - return y.to(elem_dtype) - thread_sumsq = c_zero_f thread_dummy = c_zero_f add_local = [] @@ -1417,15 +1272,15 @@ def _to_elem_vec(y): # Pass 1: add + cache + sumsq (also write residual_out) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - x = _load_vec(in_div, idx).to(fx.Float32) - residual = _load_vec(residual_in_div, idx).to(fx.Float32) - added_e = _to_elem_vec(x + residual) + x = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) add_local.append(added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added red2 = added2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sumsq = thread_sumsq + red2 - _store_vec(added_e, residual_out_div, idx) + _store_vec(copy_atom, vec_reg_ty, vec_reg_lay, added_e, residual_out_div, idx) _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float @@ -1438,12 +1293,12 @@ def _to_elem_vec(y): # Pass 2: normalize + gamma (+ optional smooth scale), cache output, and accumulate row max for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g = _load_vec(gamma_div, idx).to(fx.Float32) + g = _load_vec(copy_atom, vec_reg_ty, vec_reg_lay, gamma_div, idx).to(fx.Float32) added = add_local[tile_i] if dtype_str == "f32" else add_local[tile_i].to(fx.Float32) y = (added * rrms) * g if const_expr(is_smooth): - s_lo = _load_xscale_vec(xscale_div, idx * 2) - s_hi = _load_xscale_vec(xscale_div, idx * 2 + 1) + s_lo = _load_vec(copy_atom_xs, xscale_reg_ty, xscale_reg_lay, xscale_div, idx * 2) + s_hi = _load_vec(copy_atom_xs, xscale_reg_ty, xscale_reg_lay, xscale_div, idx * 2 + 1) s = Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() y = y * s @@ -1457,7 +1312,7 @@ def _to_elem_vec(y): final_scale = (scale == c_zero_f).select(c_one_f, scale) if tid == 0: - _store_yscale(bid, final_scale) + _store_yscale(scale_copy_atom, scale_reg_ty, scale_reg_lay, yscale_div, bid, final_scale) inv_scale = c_one_f / final_scale @@ -1468,8 +1323,8 @@ def _to_elem_vec(y): q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 - _store_q_vec(q_lo, out_div_q, out_idx) - _store_q_vec(q_hi, out_div_q, out_idx + 1) + _store_vec(copy_atom_q, vec_reg_ty_q, vec_reg_lay_q, q_lo, out_div_q, out_idx) + _store_vec(copy_atom_q, vec_reg_ty_q, vec_reg_lay_q, q_hi, out_div_q, out_idx + 1) else: # ============================================================== @@ -1516,38 +1371,6 @@ def _to_elem_vec(y): if const_expr(is_smooth): xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - - def _store_quant_scalar(divided_tensor, index, val): - r = fx.memref_alloca(scalar_reg_ty_q, scalar_reg_lay_q) - ts = full(1, quant_dtype(val), quant_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_qs, r, view) - - if const_expr(is_smooth): - def _load_xscale_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.memref_alloca(xscale_scalar_reg_ty, xscale_scalar_reg_lay) - fx.copy_atom_call(copy_atom_xs, view, r) - return fx.memref_load_vec(r)[0] - - def _to_elem_scalar(y): - if const_expr(dtype_str == "f32"): - return y - return y.to(elem_dtype) - def _abs_scalar(val): is_neg = val < c_zero_f neg_val = c_zero_f - val @@ -1560,13 +1383,13 @@ def _abs_scalar(val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) - residual_e = _load_scalar(residual_in_div, idx_safe) + x_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, row_div, idx_safe) + residual_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, residual_in_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) - added_e = _to_elem_scalar(x + residual) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) if idx < N: - _store_scalar(residual_out_div, idx, added_e) + _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -1582,13 +1405,13 @@ def _abs_scalar(val): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - g_e = _load_scalar(gamma_div, idx_safe) - added_e = _load_scalar(residual_out_div, idx_safe) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx_safe) + added_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, residual_out_div, idx_safe) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g if const_expr(is_smooth): - s = _load_xscale_scalar(xscale_div, idx_safe) + s = _load_scalar(copy_atom_xs, xscale_scalar_reg_ty, xscale_scalar_reg_lay, xscale_div, idx_safe) y = y * s y_abs = _abs_scalar(y) thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) @@ -1598,7 +1421,7 @@ def _abs_scalar(val): final_scale = (scale == c_zero_f).select(c_one_f, scale) if tid == 0: - _store_yscale(bid, final_scale) + _store_yscale(scale_copy_atom, scale_reg_ty, scale_reg_lay, yscale_div, bid, final_scale) inv_scale = c_one_f / final_scale @@ -1606,17 +1429,17 @@ def _abs_scalar(val): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - g_e = _load_scalar(gamma_div, idx) - added_e = _load_scalar(residual_out_div, idx) + g_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, gamma_div, idx) + added_e = _load_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, residual_out_div, idx) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) y = (added * rrms) * g if const_expr(is_smooth): - s = _load_xscale_scalar(xscale_div, idx) + s = _load_scalar(copy_atom_xs, xscale_scalar_reg_ty, xscale_scalar_reg_lay, xscale_div, idx) y = y * s q = y * inv_scale q_i8 = q.to(quant_dtype) - _store_quant_scalar(out_div, idx, q_i8) + _store_scalar(copy_atom_qs, scalar_reg_ty_q, scalar_reg_lay_q, quant_dtype, out_div, idx, q_i8) if is_smooth: @flyc.jit From e6fa2f9262dead8ad50d33e69abc7667a500cb2b Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 13 May 2026 01:11:14 +0800 Subject: [PATCH 10/11] fix python style check --- kernels/rmsnorm_kernel.py | 80 +++++++++++++-------------------------- 1 file changed, 27 insertions(+), 53 deletions(-) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index ebdb55b9..ad357c50 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -201,9 +201,7 @@ def block_reduce_add2(val0, val1): gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) + vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) c_zero_f = fx.Float32(0.0) @@ -255,9 +253,7 @@ def block_reduce_add2(val0, val1): fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - scalar_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay = fx.make_layout(1, 1) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) @@ -356,9 +352,7 @@ def rmsnorm_large_m_small_n_kernel( fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - scalar_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay = fx.make_layout(1, 1) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) @@ -521,9 +515,7 @@ def block_reduce_add2(val0, val1): gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) + vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) c_zero_f = fx.Float32(0.0) @@ -579,9 +571,7 @@ def block_reduce_add2(val0, val1): fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), elem_bits, ) - scalar_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay = fx.make_layout(1, 1) row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) @@ -603,7 +593,9 @@ def block_reduce_add2(val0, val1): residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) if idx < N: - _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e) + _store_scalar( + copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e + ) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -711,9 +703,7 @@ def rmsnorm_quant_kernel( YScale_buf = fx.rocdl.make_buffer_tensor(YScale) yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - scale_reg_ty = fx.MemRefType.get( - fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scale_reg_ty = fx.MemRefType.get(fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scale_reg_lay = fx.make_layout(1, 1) def wave_reduce_add(x): @@ -818,9 +808,7 @@ def block_reduce_max(val): xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) + vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) if const_expr(is_smooth): copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) @@ -918,9 +906,7 @@ def block_reduce_max(val): fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register ) xscale_scalar_reg_lay = fx.make_layout(1, 1) - scalar_reg_ty_q = fx.MemRefType.get( - quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty_q = fx.MemRefType.get(quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay_q = fx.make_layout(1, 1) row_in = fx.slice(Input_buf, (bid, None)) @@ -996,6 +982,7 @@ def _abs_scalar(val): _store_scalar(copy_atom_qs, scalar_reg_ty_q, scalar_reg_lay_q, quant_dtype, out_div, idx, q_i8) if is_smooth: + @flyc.jit def launch_rmsnorm_smoothquant( Input: fx.Tensor, @@ -1021,6 +1008,7 @@ def launch_rmsnorm_smoothquant( return launch_rmsnorm_smoothquant else: + @flyc.jit def launch_rmsnorm_dynamicquant( Input: fx.Tensor, @@ -1128,9 +1116,7 @@ def fused_add_rmsnorm_quant_kernel( YScale_buf = fx.rocdl.make_buffer_tensor(YScale) yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - scale_reg_ty = fx.MemRefType.get( - fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scale_reg_ty = fx.MemRefType.get(fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scale_reg_lay = fx.make_layout(1, 1) def wave_reduce_add(x): @@ -1233,23 +1219,15 @@ def block_reduce_max(val): row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) - residual_in_div = fx.logical_divide( - row_residual_in, fx.make_layout(VEC_WIDTH, 1) - ) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) - residual_out_div = fx.logical_divide( - row_residual_out, fx.make_layout(VEC_WIDTH, 1) - ) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) if const_expr(is_smooth): - xscale_div = fx.logical_divide( - XScale_buf, fx.make_layout(xscale_vec_width, 1) - ) + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) + vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) if const_expr(is_smooth): copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) @@ -1343,9 +1321,7 @@ def block_reduce_max(val): elem_bits, ) copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) - scalar_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay = fx.make_layout(1, 1) if const_expr(is_smooth): copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) @@ -1353,9 +1329,7 @@ def block_reduce_max(val): fx.Float32.ir_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register ) xscale_scalar_reg_lay = fx.make_layout(1, 1) - scalar_reg_ty_q = fx.MemRefType.get( - quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) + scalar_reg_ty_q = fx.MemRefType.get(quant_elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) scalar_reg_lay_q = fx.make_layout(1, 1) row_in = fx.slice(Input_buf, (bid, None)) @@ -1389,7 +1363,9 @@ def _abs_scalar(val): residual = residual_e if dtype_str == "f32" else residual_e.to(fx.Float32) added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) if idx < N: - _store_scalar(copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e) + _store_scalar( + copy_atom_s, scalar_reg_ty, scalar_reg_lay, elem_dtype, residual_out_div, idx, added_e + ) added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) added2 = added * added thread_sumsq = thread_sumsq + is_valid.select(added2, c_zero_f) @@ -1442,6 +1418,7 @@ def _abs_scalar(val): _store_scalar(copy_atom_qs, scalar_reg_ty_q, scalar_reg_lay_q, quant_dtype, out_div, idx, q_i8) if is_smooth: + @flyc.jit def launch_fused_add_rmsnorm_smoothquant( Input: fx.Tensor, @@ -1459,9 +1436,7 @@ def launch_fused_add_rmsnorm_smoothquant( with InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = fused_add_rmsnorm_quant_kernel( - Input, ResidualIn, Gamma, XScale, YScale, Output, ResidualOut - ) + launcher = fused_add_rmsnorm_quant_kernel(Input, ResidualIn, Gamma, XScale, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), block=(BLOCK_THREADS, 1, 1), @@ -1471,6 +1446,7 @@ def launch_fused_add_rmsnorm_smoothquant( return launch_fused_add_rmsnorm_smoothquant else: + @flyc.jit def launch_fused_add_rmsnorm_dynamicquant( Input: fx.Tensor, @@ -1487,9 +1463,7 @@ def launch_fused_add_rmsnorm_dynamicquant( with InsertionPoint(ctx.gpu_module_body): allocator.finalize() - launcher = fused_add_rmsnorm_quant_kernel( - Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut - ) + launcher = fused_add_rmsnorm_quant_kernel(Input, ResidualIn, Gamma, Gamma, YScale, Output, ResidualOut) launcher.launch( grid=(m_in, 1, 1), block=(BLOCK_THREADS, 1, 1), From a8e7cde9cdb0fde2280c5a15453eb36ba2134c58 Mon Sep 17 00:00:00 2001 From: Junlin Chen Date: Wed, 13 May 2026 02:20:07 +0800 Subject: [PATCH 11/11] fix python style check of the tests --- tests/kernels/test_rmsnorm.py | 128 ++++++++++++++++------------------ 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 25cee901..04eae3c9 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -15,14 +15,23 @@ import os -from tests.test_common import run_perftest +import pytest + +from kernels.rmsnorm_kernel import ( + build_fused_add_rmsnorm_dynamicquant_module, + build_fused_add_rmsnorm_module, + build_fused_add_rmsnorm_smoothquant_module, + build_rmsnorm_dynamicquant_module, + build_rmsnorm_module, + build_rmsnorm_smoothquant_module, +) from tests.kernels.benchmark_common import ( PerfRow, bench_gpu_us_torch, maybe_enable_aiter, print_perf_table, ) -import pytest +from tests.test_common import run_perftest pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] @@ -39,20 +48,11 @@ DTYPE_INT8 = torch.int8 EPS: float = 1e-5 -from kernels.rmsnorm_kernel import ( - build_rmsnorm_module, - build_fused_add_rmsnorm_module, - build_rmsnorm_dynamicquant_module, - build_rmsnorm_smoothquant_module, - build_fused_add_rmsnorm_dynamicquant_module, - build_fused_add_rmsnorm_smoothquant_module, - KERNEL_NAME as RMSNORM_KERNEL_NAME, - BLOCK_THREADS, -) WARMUP_ITERS = 10 BENCH_ITERS = 100 + def run_test(M: int, N: int, dtype: str = "f32"): print(f"\nTesting RMSNorm (M={M}, N={N}, dtype={dtype})") @@ -105,7 +105,9 @@ def kernel_launch(): launch_fn(input_dev, gamma_dev, output_dev, M, stream=stream) # run_perftest returns (data, avg_us) - _, avg_us = run_perftest(lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS) + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), num_iters=BENCH_ITERS, num_warmup=WARMUP_ITERS + ) torch.cuda.synchronize() flydsl_gpu_us = None if os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1": @@ -140,10 +142,11 @@ def kernel_launch(): ok = False return ok, flydsl_gpu_us + def test_all(): - print("="*80) + print("=" * 80) print("Running RMSNorm Tests") - print("="*80) + print("=" * 80) shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() if shapes_env: @@ -155,7 +158,7 @@ def test_all(): m_s, n_s, dt = [x.strip() for x in p.split(",")] configs.append((int(m_s), int(n_s), dt)) else: - # Prefer N multiples of BLOCK_THREADS*VEC_WIDTH (=2048) to exercise the fast path. + # Prefer N multiples of 2048 to exercise the fast path. configs = [ # (64, 256, "f32"), # Aligned # (128, 1024, "f32"), # Aligned @@ -177,11 +180,17 @@ def test_all(): if do_compare: import torch + aiter_us = None if maybe_enable_aiter(): try: from aiter.ops.triton.rmsnorm import rms_norm as aiter_rms_norm - x = torch.randn((M, N), device="cuda", dtype=DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32)) + + x = torch.randn( + (M, N), + device="cuda", + dtype=DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32), + ) w = torch.rand((N,), device="cuda", dtype=x.dtype) def run_aiter(): @@ -192,14 +201,16 @@ def run_aiter(): except Exception as e: print(f"[Perf] AIter rmsnorm skipped: {type(e).__name__}: {e!r}") - perf_rows.append(PerfRow(op="rmsnorm", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us)) + perf_rows.append( + PerfRow(op="rmsnorm", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us) + ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) # Ensure a non-zero exit code on failure for shell wrappers. @@ -229,7 +240,7 @@ def _get_rmsnorm_configs(): configs.append((int(m_s), int(n_s), dt)) return configs - # Prefer N multiples of BLOCK_THREADS*VEC_WIDTH (=2048) to exercise the fast path. + # Prefer N multiples of 2048 to exercise the fast path. return [ # (64, 256, "f32"), # Aligned # (128, 1024, "f32"), # Aligned @@ -281,7 +292,9 @@ def _bench_aiter_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): def run_aiter(): aiter_rmsnorm_quant(y, x, xscale, yscale, w, EPS) + else: + def run_aiter(): aiter_rmsnorm_quant(y, x, yscale, w, EPS) @@ -294,17 +307,13 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): mode = "smoothquant" if is_smooth else "dynamicquant" print(f"\nTesting RMSNorm {mode} (M={M}, N={N}, dtype={dtype})") - torch_dtype = _torch_dtype(dtype) try: if is_smooth: launch_fn = build_rmsnorm_smoothquant_module(M, N, dtype) else: launch_fn = build_rmsnorm_dynamicquant_module(M, N, dtype) except Exception as e: - print( - f"[FAIL] Compile failed for {mode} (M={M}, N={N}, dtype={dtype}): " - f"{type(e).__name__}: {e}" - ) + print(f"[FAIL] Compile failed for {mode} (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") return False, None torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) @@ -409,9 +418,9 @@ def kernel_launch(): def test_rmsnorm_dynamicquant(): - print("="*80) + print("=" * 80) print("Running RMSNorm DynamicQuant Tests") - print("="*80) + print("=" * 80) do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] @@ -436,12 +445,12 @@ def test_rmsnorm_dynamicquant(): ) ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) # Ensure a non-zero exit code on failure for shell wrappers. @@ -450,9 +459,9 @@ def test_rmsnorm_dynamicquant(): def test_rmsnorm_smoothquant(): - print("="*80) + print("=" * 80) print("Running RMSNorm SmoothQuant Tests") - print("="*80) + print("=" * 80) do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] @@ -477,12 +486,12 @@ def test_rmsnorm_smoothquant(): ) ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) # Ensure a non-zero exit code on failure for shell wrappers. @@ -529,10 +538,7 @@ def run_fused_add_test(M: int, N: int, dtype: str): try: launch_fn = build_fused_add_rmsnorm_module(M, N, dtype) except Exception as e: - print( - f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): " - f"{type(e).__name__}: {e}" - ) + print(f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") return False, None torch.manual_seed(42) @@ -596,10 +602,7 @@ def kernel_launch(): total_bytes = (4 * M * N + N) * elem_bytes bandwidth_gbs = total_bytes / (avg_us / 1e6) / 1e9 - print( - f"Kernel avg time: {avg_ms:.4f} ms via run_perftest " - f"(warmup={WARMUP_ITERS}, iters={BENCH_ITERS})" - ) + print(f"Kernel avg time: {avg_ms:.4f} ms via run_perftest " f"(warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL fused_add rmsnorm gpu: {flydsl_gpu_us:.1f} us") @@ -627,9 +630,9 @@ def kernel_launch(): def test_rmsnorm_fused_add(): - print("="*80) + print("=" * 80) print("Running FusedAdd RMSNorm Tests") - print("="*80) + print("=" * 80) do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] @@ -654,12 +657,12 @@ def test_rmsnorm_fused_add(): ) ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) # Ensure a non-zero exit code on failure for shell wrappers. @@ -718,14 +721,12 @@ def _bench_aiter_fused_add_rmsnorm_quant( xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() def run_aiter(): - aiter_fused_add_rmsnorm_quant( - y, x, residual_in, residual_out, xscale, yscale, w, EPS - ) + aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, xscale, yscale, w, EPS) + else: + def run_aiter(): - aiter_fused_add_rmsnorm_quant( - y, x, residual_in, residual_out, yscale, w, EPS - ) + aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, yscale, w, EPS) aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) print(f"[Perf] AIter fused_add rmsnorm {mode} gpu: {aiter_us:.1f} us") @@ -856,12 +857,7 @@ def kernel_launch(): print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") print(f"Max quant diff: {quant_diff}") - ok = ( - residual_error < residual_atol - and dequant_error < dequant_tol - and scale_diff < scale_tol - and quant_diff <= 1 - ) + ok = residual_error < residual_atol and dequant_error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1 if ok: print("PASSED") else: @@ -886,9 +882,9 @@ def kernel_launch(): def test_rmsnorm_fused_add_dynamicquant(): - print("="*80) + print("=" * 80) print("Running FusedAdd RMSNorm DynamicQuant Tests") - print("="*80) + print("=" * 80) do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] @@ -913,12 +909,12 @@ def test_rmsnorm_fused_add_dynamicquant(): ) ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) if failures != 0: @@ -926,9 +922,9 @@ def test_rmsnorm_fused_add_dynamicquant(): def test_rmsnorm_fused_add_smoothquant(): - print("="*80) + print("=" * 80) print("Running FusedAdd RMSNorm SmoothQuant Tests") - print("="*80) + print("=" * 80) do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] @@ -953,12 +949,12 @@ def test_rmsnorm_fused_add_smoothquant(): ) ) - print("\n" + "="*80) + print("\n" + "=" * 80) if failures == 0: print("ALL TESTS PASSED") else: print(f"{failures} TESTS FAILED") - print("="*80) + print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) if failures != 0: