From 8aeaf2aadc6070e6f0e75159ce444d84a1c405d0 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 2 Apr 2026 11:50:51 +0000 Subject: [PATCH 01/13] Add act_mul without quant (DO_QUANT), model configs, benchmarks - Gate DO_QUANT in mxfp4 and fp8-group activation Triton kernels - Add act_mul wrapper and tests in test_activation.py - Wire bench_moe -bench_act_mul and GLM-4.7 / Kimi-K2.5 TP4/EP4 in model_configs.json Made-with: Cursor --- .../ops/triton/_triton_kernels/activation.py | 184 ++++++++++-------- aiter/ops/triton/activation.py | 64 +++++- op_tests/op_benchmarks/triton/bench_moe.py | 86 ++++++++ .../triton/utils/model_configs.json | 51 ++++- op_tests/triton_tests/test_activation.py | 38 +++- 5 files changed, 343 insertions(+), 80 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 84b4f4fe76..d1c7deba2d 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -100,6 +100,7 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( scaleM_pad: tl.constexpr, scaleN_pad: tl.constexpr, SHUFFLE: tl.constexpr, + DO_QUANT: tl.constexpr, ): pid_m = tl.program_id(0) start_n = tl.program_id(1) * NUM_ITER @@ -120,75 +121,95 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( if EVEN_M_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( - tl.float32 - ) else: x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to( tl.float32 ) # a and b can share the same mask - b = tl.load( - x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" - ).to(tl.float32) - x = _apply_activation_from_str(a, ACTIVATION) * b - - out_tensor, bs_e8m0 = _mxfp4_quant_op( - x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE - ) - - out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - out_offs_n = pid_n * BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2) - out_offs = ( - out_offs_m[:, None] * stride_x_fp4_m + out_offs_n[None, :] * stride_x_fp4_n - ) + activated_a = _apply_activation_from_str(a, ACTIVATION) if EVEN_M_N: - tl.store(x_fp4_ptr + out_offs, out_tensor) + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( + tl.float32 + ) else: - out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] - tl.store(x_fp4_ptr + out_offs, out_tensor, mask=out_mask) + x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] + # a and b can share the same mask + b = tl.load( + x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" + ).to(tl.float32) - bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) - if SHUFFLE: - bs_offs_0 = bs_offs_m[:, None] // 32 - bs_offs_1 = bs_offs_m[:, None] % 32 - bs_offs_2 = bs_offs_1 % 16 - bs_offs_1 = bs_offs_1 // 16 - bs_offs_3 = bs_offs_n[None, :] // 8 - bs_offs_4 = bs_offs_n[None, :] % 8 - bs_offs_5 = bs_offs_4 % 4 - bs_offs_4 = bs_offs_4 // 4 - bs_offs = ( - bs_offs_1 - + bs_offs_4 * 2 - + bs_offs_2 * 2 * 2 - + bs_offs_5 * 2 * 2 * 16 - + bs_offs_3 * 2 * 2 * 16 * 4 - + bs_offs_0 * 2 * 16 * scaleN + x = activated_a * b + if DO_QUANT: + out_tensor, bs_e8m0 = _mxfp4_quant_op( + x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE ) - bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] - bs_mask = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[ - None, : - ] - bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127) - else: - bs_offs = ( - bs_offs_m[:, None] * stride_bs_m + bs_offs_n[None, :] * stride_bs_n + + out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + out_offs_n = pid_n * BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2) + out_offs = ( + out_offs_m[:, None] * stride_x_fp4_m + + out_offs_n[None, :] * stride_x_fp4_n ) - bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] - if EVEN_M_N: - tl.store(bs_ptr + bs_offs, bs_e8m0) - else: - tl.store( - bs_ptr + bs_offs, - bs_e8m0, - mask=bs_mask, + if EVEN_M_N: + tl.store(x_fp4_ptr + out_offs, out_tensor) + else: + out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] + tl.store(x_fp4_ptr + out_offs, out_tensor, mask=out_mask) + + bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * scaleN + ) + bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] + bs_mask = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[ + None, : + ] + bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * stride_bs_m + bs_offs_n[None, :] * stride_bs_n + ) + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] + if EVEN_M_N: + tl.store(bs_ptr + bs_offs, bs_e8m0) + else: + tl.store( + bs_ptr + bs_offs, + bs_e8m0, + mask=bs_mask, + ) + else: + out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_offs = ( + out_offs_m[:, None] * stride_x_fp4_m + + out_offs_n[None, :] * stride_x_fp4_n ) + x_out = x.to(x_fp4_ptr.dtype.element_ty) + if EVEN_M_N: + tl.store(x_fp4_ptr + out_offs, x_out) + else: + out_mask = (out_offs_m < M)[:, None] & (out_offs_n < N)[None, :] + tl.store(x_fp4_ptr + out_offs, x_out, mask=out_mask) @triton.heuristics( @@ -215,6 +236,7 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( DTYPE_MAX: tl.constexpr, DTYPE_MIN: tl.constexpr, EVEN_N: tl.constexpr, + DO_QUANT: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -245,31 +267,41 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( x = _apply_activation_from_str(a, ACTIVATION) * b - x_fp8, x_bs = _fp8_quant_op( - x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - x_fp8 = tl.ravel(x_fp8) - x_bs = tl.ravel(x_bs) - out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n - if EVEN_N: - tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty)) - else: - out_mask = out_offs_n < N - tl.store( - x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty), mask=out_mask + if DO_QUANT: + x_fp8, x_bs = _fp8_quant_op( + x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN ) + x_fp8 = tl.ravel(x_fp8) + x_bs = tl.ravel(x_bs) - bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) - bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n - if EVEN_N: - tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty)) + if EVEN_N: + tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty)) + else: + out_mask = out_offs_n < N + tl.store( + x_fp8_ptr + out_offs, + x_fp8.to(x_fp8_ptr.dtype.element_ty), + mask=out_mask, + ) + + bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) + bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n + if EVEN_N: + tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty)) + else: + bs_mask = bs_offs_n < scaleN + tl.store( + x_bs_ptr + bs_offs, + x_bs.to(x_bs_ptr.dtype.element_ty), + mask=bs_mask, + ) else: - bs_mask = bs_offs_n < scaleN - tl.store( - x_bs_ptr + bs_offs, - x_bs.to(x_bs_ptr.dtype.element_ty), - mask=bs_mask, - ) + x_out = x.to(x_fp8_ptr.dtype.element_ty) + if EVEN_N: + tl.store(x_fp8_ptr + out_offs, x_out) + else: + out_mask = out_offs_n < N + tl.store(x_fp8_ptr + out_offs, x_out, mask=out_mask) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index b52cf465eb..3545c3429b 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional import triton import triton.language as tl import torch @@ -123,11 +123,72 @@ def act_mul_and_mxfp4_quant( num_warps=NUM_WARPS, waves_per_eu=0, num_stages=1, + DO_QUANT=True, ) return x_fp4, blockscale_e8m0 +def act_mul( + x: torch.Tensor, + activation: Literal["silu", "gelu", "gelu_tanh"], + out: Optional[torch.Tensor] = None, + group_size: Optional[int] = None, +) -> torch.Tensor: + """ + Gated activation along the last dimension only (no quantization): ``act(x0) * x1`` + where ``x`` is ``[..., 2 * d]`` split into two ``[..., d]`` halves. + + Uses the same Triton path as ``act_mul_and_fp8_group_quant`` with quantization disabled. + """ + _LOGGER.info(f"ACT_MUL: x={tuple(x.shape)} activation={activation}") + assert x.is_cuda and x.is_contiguous() + M, N = x.shape + assert N % 2 == 0 + N_half = N // 2 + + if out is None: + out = torch.empty((M, N_half), dtype=x.dtype, device=x.device) + else: + assert out.shape == (M, N_half) + assert out.dtype == x.dtype and out.is_contiguous() + + if group_size is None: + group_size = min(256, triton.next_power_of_2(N_half)) + group_size = max(32, group_size) + + scaleN = triton.cdiv(N, group_size) + dummy_bs = torch.empty( + (M, triton.cdiv(N_half, group_size)), + dtype=torch.float32, + device=x.device, + ) + DTYPE_MAX = 1.0 + BLOCK_SIZE_N = group_size + + grid = ( + M, + triton.cdiv(N_half, BLOCK_SIZE_N), + ) + _act_mul_and_dynamic_fp8_group_quant_kernel[grid]( + x, + out, + dummy_bs, + *x.stride(), + *out.stride(), + *dummy_bs.stride(), + N=N_half, + ACTIVATION=activation, + scaleN=scaleN, + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + DO_QUANT=False, + ) + return out + + def act_mul_and_fp8_group_quant( x: torch.Tensor, activation: Literal["silu", "gelu", "gelu_tanh"], @@ -195,6 +256,7 @@ def act_mul_and_fp8_group_quant( # num_warps=NUM_WARPS, # waves_per_eu=0, # num_stages=1, + DO_QUANT=True, ) return x_fp8, out_bs diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index 3594278162..83b843c7f5 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -5,6 +5,7 @@ from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu +from aiter.ops.triton.activation import act_mul from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16 from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( get_model_configs, @@ -262,6 +263,67 @@ def bench_moe_gemm(M, N, K, E, top_k, metric, model=None): bench_moe_gemm.run(save_path="." if args.o else None, print_data=True) +def run_act_mul_benchmark(args): + """Benchmark ``act_mul`` (SiLU/GELU * gate) on tensors shaped like MoE activations.""" + print_time = args.print_time + dtype = str_to_torch_dtype[args.dtype] + activation = args.act_mul_activation + + if print_time: + line_names = ["Time_(ms)"] + line_vals = ["time"] + else: + line_names = ["Time_(ms)", "GFLOPS", "Bandwidth_(GB/s)"] + line_vals = ["time", "gflops", "bandwidth"] + + x_vals_list = model_benchmark_configs(args) + x_names = ["model", "M", "N", "K", "E", "top_k"] + + benchmark = triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg="metric", + line_vals=line_vals, + line_names=line_names, + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms / GFLOPS / GB/s", + plot_name=get_caller_name_no_ext() + "_act_mul", + args={}, + ) + + @triton.testing.perf_report([benchmark]) + def bench_act_mul(M, N, K, E, top_k, metric, model=None): + n_even = N if N % 2 == 0 else N - 1 + if n_even < 2: + return 0.0 + n_rows = M * top_k + d = n_even // 2 + x = torch.randn(n_rows, n_even, device="cuda", dtype=dtype) + out = torch.empty(n_rows, d, device="cuda", dtype=dtype) + + elem = torch.tensor([], dtype=dtype).element_size() + mem_read = n_rows * n_even * elem + mem_write = n_rows * d * elem + flops = float(n_rows * d * 8) + + fn = lambda: act_mul(x, activation, out=out) + + ms = triton.testing.do_bench(fn, warmup=25, rep=100) + bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 + gflops = flops / ms * 1e-6 + + if metric == "time": + return ms + elif metric == "gflops": + return gflops + elif metric == "bandwidth": + return bandwidth + else: + raise ValueError("Unknown metric: " + metric) + + bench_act_mul.run(save_path="." if args.o else None, print_data=True) + + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark MoE GEMM", @@ -300,6 +362,19 @@ def parse_args(): parser.add_argument("-dtype", default="fp16") parser.add_argument("-fp8_type", default="e5m2fnuz") parser.add_argument("-silu_fused", action="store_true", default=False) + parser.add_argument( + "-bench_act_mul", + action="store_true", + default=False, + help="Benchmark act_mul (no quant) using model M, N, top_k (same layout as silu-fused MoE).", + ) + parser.add_argument( + "-act_mul_activation", + type=str, + default="silu", + choices=["silu", "gelu", "gelu_tanh"], + help="Activation for -bench_act_mul.", + ) parser.add_argument( "-o", action="store_true", help="Write performance results to CSV file" ) @@ -310,6 +385,17 @@ def parse_args(): def main(): args = parse_args() + if args.bench_act_mul: + if args.print_vgpr: + + def fun(): + return run_act_mul_benchmark(args) + + print_vgpr(fun, get_caller_name_no_ext() + "_act_mul") + return 0 + run_act_mul_benchmark(args) + return 0 + if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") diff --git a/op_tests/op_benchmarks/triton/utils/model_configs.json b/op_tests/op_benchmarks/triton/utils/model_configs.json index f41efcd412..7f384521cd 100644 --- a/op_tests/op_benchmarks/triton/utils/model_configs.json +++ b/op_tests/op_benchmarks/triton/utils/model_configs.json @@ -61,8 +61,55 @@ "num_expert": 256, "top_k": 4 } + }, + "glm47fp8": { + "TP4": { + "hidden_size": 1280, + "intermediate_size": 768, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 160, + "top_k": 8, + "parallel": "TP4", + "notes": "GLM-4.7 MoE (zai-org/GLM-4.7): per-GPU tensor-parallel shard; K=5120/4, N=2*moe_intermediate/4=3072/4" + }, + "EP4": { + "hidden_size": 5120, + "intermediate_size": 3072, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 40, + "top_k": 8, + "parallel": "EP4", + "notes": "GLM-4.7 MoE: per-GPU expert-parallel shard; E=160/4, N=2*1536" + } + }, + "kimik25": { + "TP4": { + "hidden_size": 1792, + "intermediate_size": 1024, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 384, + "top_k": 8, + "parallel": "TP4", + "notes": "Kimi-K2.5 text MoE (moonshotai/Kimi-K2.5): TP4 shard; K=7168/4, N=2*moe_intermediate/4=4096/4" + }, + "EP4": { + "hidden_size": 7168, + "intermediate_size": 4096, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 96, + "top_k": 8, + "parallel": "EP4", + "notes": "Kimi-K2.5 text MoE: EP4 shard; E=384/4, N=2*2048" + } } } - - + \ No newline at end of file diff --git a/op_tests/triton_tests/test_activation.py b/op_tests/triton_tests/test_activation.py index 38dbd4779d..afd9c505fd 100644 --- a/op_tests/triton_tests/test_activation.py +++ b/op_tests/triton_tests/test_activation.py @@ -6,7 +6,7 @@ shuffle_scales, un_shuffle_scales, ) -from aiter.ops.triton.activation import act_mul_and_mxfp4_quant +from aiter.ops.triton.activation import act_mul, act_mul_and_mxfp4_quant import aiter.ops.triton.utils._triton.arch_info as arch_info DEBUG_MODE = False @@ -140,3 +140,39 @@ def test_act_mul_and_mxfp4_quant( torch.testing.assert_close(triton_out, torch_out) torch.testing.assert_close(triton_scale, torch_scale) + + +def torch_act_mul_ref(x: torch.Tensor, activation: str) -> torch.Tensor: + d = x.shape[-1] // 2 + a, b = x[:, :d], x[:, d:] + if activation == "silu": + y = F.silu(a) * b + elif activation == "gelu": + y = F.gelu(a) * b + else: + y = F.gelu(a, approximate="tanh") * b + return y.to(x.dtype) + + +@pytest.mark.parametrize( + "M, N_half", + [ + (4, 64), + (31, 128), + (128, 256), + (1, 7), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"]) +@pytest.mark.parametrize("use_out", [False, True]) +def test_act_mul_no_quant(M, N_half, dtype, activation, use_out): + N = 2 * N_half + x = torch.randn((M, N), dtype=dtype, device="cuda") + ref = torch_act_mul_ref(x, activation) + if use_out: + out = torch.empty_like(ref) + act_mul(x, activation, out=out) + else: + out = act_mul(x, activation) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) From 9c22cd2949cb7800062584ea477de4ef6da8f5f6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 2 Apr 2026 14:56:31 +0300 Subject: [PATCH 02/13] Update op_tests/op_benchmarks/triton/bench_moe.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- op_tests/op_benchmarks/triton/bench_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index 83b843c7f5..0fdb94f308 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -306,7 +306,8 @@ def bench_act_mul(M, N, K, E, top_k, metric, model=None): mem_write = n_rows * d * elem flops = float(n_rows * d * 8) - fn = lambda: act_mul(x, activation, out=out) + def fn(): + return act_mul(x, activation, out=out) ms = triton.testing.do_bench(fn, warmup=25, rep=100) bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 From 4bbd3d604a497f8cb9cdac635bd9df471b46868b Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:42:32 +0300 Subject: [PATCH 03/13] Update aiter/ops/triton/activation.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- aiter/ops/triton/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index 3545c3429b..edf0efb3eb 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -1,6 +1,5 @@ from typing import Literal, Optional import triton -import triton.language as tl import torch import aiter From 048736c4027b5088826e56b53e00ba7760c24dea Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:55:20 +0000 Subject: [PATCH 04/13] no need to move the b load after a activation --- aiter/ops/triton/_triton_kernels/activation.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index d1c7deba2d..e12ba603b8 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -121,27 +121,20 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( if EVEN_M_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - else: - x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] - a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to( - tl.float32 - ) - # a and b can share the same mask - - activated_a = _apply_activation_from_str(a, ACTIVATION) - - if EVEN_M_N: b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( tl.float32 ) else: x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] + a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to( + tl.float32 + ) # a and b can share the same mask b = tl.load( x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" ).to(tl.float32) - x = activated_a * b + x = _apply_activation_from_str(a, ACTIVATION) * b if DO_QUANT: out_tensor, bs_e8m0 = _mxfp4_quant_op( x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE From 1357f0edcf9e50662daedc757bbd02862920eeff Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:00:33 +0300 Subject: [PATCH 05/13] Update aiter/ops/triton/activation.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- aiter/ops/triton/activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index edf0efb3eb..d62b337fed 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -156,9 +156,9 @@ def act_mul( group_size = min(256, triton.next_power_of_2(N_half)) group_size = max(32, group_size) - scaleN = triton.cdiv(N, group_size) + scaleN = triton.cdiv(N_half, group_size) dummy_bs = torch.empty( - (M, triton.cdiv(N_half, group_size)), + (1, 1), dtype=torch.float32, device=x.device, ) From 775a25a1ab03d797b0ef5c852ae777ed98444aab Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:09:24 +0000 Subject: [PATCH 06/13] fix: act_mul handles higher-rank inputs by flattening/unflattening internally Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/6bb3b38c-af05-43d0-9fb0-86d98eb98852 Co-authored-by: juuso-oskari <40278371+juuso-oskari@users.noreply.github.com> --- aiter/ops/triton/activation.py | 27 ++++++++++++++++-------- op_tests/triton_tests/test_activation.py | 21 +++++++++++++++++- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index d62b337fed..0e21940682 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -138,19 +138,28 @@ def act_mul( Gated activation along the last dimension only (no quantization): ``act(x0) * x1`` where ``x`` is ``[..., 2 * d]`` split into two ``[..., d]`` halves. + Accepts any input rank; leading dimensions are flattened before the kernel and + restored in the returned tensor, which has shape ``[..., d]``. + Uses the same Triton path as ``act_mul_and_fp8_group_quant`` with quantization disabled. """ _LOGGER.info(f"ACT_MUL: x={tuple(x.shape)} activation={activation}") assert x.is_cuda and x.is_contiguous() - M, N = x.shape - assert N % 2 == 0 + assert x.shape[-1] % 2 == 0 + + # Flatten all leading dimensions so the kernel always sees a 2-D tensor. + orig_shape = x.shape + x_2d = x.view(-1, orig_shape[-1]) + M, N = x_2d.shape N_half = N // 2 + out_shape = orig_shape[:-1] + (N_half,) if out is None: - out = torch.empty((M, N_half), dtype=x.dtype, device=x.device) + out_2d = torch.empty((M, N_half), dtype=x.dtype, device=x.device) else: - assert out.shape == (M, N_half) + assert out.shape == out_shape assert out.dtype == x.dtype and out.is_contiguous() + out_2d = out.view(M, N_half) if group_size is None: group_size = min(256, triton.next_power_of_2(N_half)) @@ -170,11 +179,11 @@ def act_mul( triton.cdiv(N_half, BLOCK_SIZE_N), ) _act_mul_and_dynamic_fp8_group_quant_kernel[grid]( - x, - out, + x_2d, + out_2d, dummy_bs, - *x.stride(), - *out.stride(), + *x_2d.stride(), + *out_2d.stride(), *dummy_bs.stride(), N=N_half, ACTIVATION=activation, @@ -185,7 +194,7 @@ def act_mul( DTYPE_MIN=-DTYPE_MAX, DO_QUANT=False, ) - return out + return out_2d.view(out_shape) def act_mul_and_fp8_group_quant( diff --git a/op_tests/triton_tests/test_activation.py b/op_tests/triton_tests/test_activation.py index afd9c505fd..be2e365d47 100644 --- a/op_tests/triton_tests/test_activation.py +++ b/op_tests/triton_tests/test_activation.py @@ -144,7 +144,7 @@ def test_act_mul_and_mxfp4_quant( def torch_act_mul_ref(x: torch.Tensor, activation: str) -> torch.Tensor: d = x.shape[-1] // 2 - a, b = x[:, :d], x[:, d:] + a, b = x[..., :d], x[..., d:] if activation == "silu": y = F.silu(a) * b elif activation == "gelu": @@ -176,3 +176,22 @@ def test_act_mul_no_quant(M, N_half, dtype, activation, use_out): else: out = act_mul(x, activation) torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 4, 64), + (3, 5, 128), + (2, 3, 4, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("activation", ["silu", "gelu"]) +def test_act_mul_no_quant_higher_rank(shape, dtype, activation): + """act_mul should handle any input rank by flattening/unflattening internally.""" + x = torch.randn(shape, dtype=dtype, device="cuda") + ref = torch_act_mul_ref(x, activation) + out = act_mul(x, activation) + assert out.shape == ref.shape + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) From 85c792de0ee3cfc27f5e19af70f58ef25559e546 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:08:09 +0000 Subject: [PATCH 07/13] do mul in b dtype and use more universal x_out_ptr name --- .../ops/triton/_triton_kernels/activation.py | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 6bc5859291..42597ad1ff 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -78,7 +78,7 @@ def _apply_activation_from_str(x, activation: tl.constexpr): @triton.jit def _act_mul_and_dynamic_mxfp4_quant_kernel( x_ptr, - x_fp4_ptr, + x_out_ptr, bs_ptr, stride_x_m_in, stride_x_n_in, @@ -121,9 +121,7 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( if EVEN_M_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( - tl.float32 - ) + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg") else: x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to( @@ -132,9 +130,8 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( # a and b can share the same mask b = tl.load( x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" - ).to(tl.float32) - - x = _apply_activation_from_str(a, ACTIVATION) * b + ) + x = _apply_activation_from_str(a, ACTIVATION).to(b.dtype) * b if DO_QUANT: out_tensor, bs_e8m0 = _mxfp4_quant_op( x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE @@ -148,10 +145,10 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( ) if EVEN_M_N: - tl.store(x_fp4_ptr + out_offs, out_tensor) + tl.store(x_out_ptr + out_offs, out_tensor) else: out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] - tl.store(x_fp4_ptr + out_offs, out_tensor, mask=out_mask) + tl.store(x_out_ptr + out_offs, out_tensor, mask=out_mask) bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) @@ -197,12 +194,12 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( out_offs_m[:, None] * stride_x_fp4_m + out_offs_n[None, :] * stride_x_fp4_n ) - x_out = x.to(x_fp4_ptr.dtype.element_ty) + x_out = x.to(x_out_ptr.dtype.element_ty) if EVEN_M_N: - tl.store(x_fp4_ptr + out_offs, x_out) + tl.store(x_out_ptr + out_offs, x_out) else: out_mask = (out_offs_m < M)[:, None] & (out_offs_n < N)[None, :] - tl.store(x_fp4_ptr + out_offs, x_out, mask=out_mask) + tl.store(x_out_ptr + out_offs, x_out, mask=out_mask) @triton.heuristics( @@ -213,7 +210,7 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( @triton.jit def _act_mul_and_dynamic_fp8_group_quant_kernel( x_ptr, - x_fp8_ptr, + x_out_ptr, x_bs_ptr, stride_x_m_in, stride_x_n_in, @@ -247,18 +244,16 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( if EVEN_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( - tl.float32 - ) + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg") else: x_mask = x_offs_n < N a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to(tl.float32) # a and b can share the same mask b = tl.load( x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" - ).to(tl.float32) + ) - x = _apply_activation_from_str(a, ACTIVATION) * b + x = _apply_activation_from_str(a, ACTIVATION).to(b.dtype) * b out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n @@ -271,12 +266,12 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( x_bs = tl.ravel(x_bs) if EVEN_N: - tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty)) + tl.store(x_out_ptr + out_offs, x_fp8.to(x_out_ptr.dtype.element_ty)) else: out_mask = out_offs_n < N tl.store( - x_fp8_ptr + out_offs, - x_fp8.to(x_fp8_ptr.dtype.element_ty), + x_out_ptr + out_offs, + x_fp8.to(x_out_ptr.dtype.element_ty), mask=out_mask, ) @@ -292,9 +287,9 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( mask=bs_mask, ) else: - x_out = x.to(x_fp8_ptr.dtype.element_ty) + x_out = x.to(x_out_ptr.dtype.element_ty) if EVEN_N: - tl.store(x_fp8_ptr + out_offs, x_out) + tl.store(x_out_ptr + out_offs, x_out) else: out_mask = out_offs_n < N - tl.store(x_fp8_ptr + out_offs, x_out, mask=out_mask) + tl.store(x_out_ptr + out_offs, x_out, mask=out_mask) From 8e90c8c11a043cf137266fa2f1d90f0dad4bb622 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:10:31 +0000 Subject: [PATCH 08/13] black --- aiter/ops/triton/_triton_kernels/activation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 42597ad1ff..63d36a9ceb 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -249,9 +249,7 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( x_mask = x_offs_n < N a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to(tl.float32) # a and b can share the same mask - b = tl.load( - x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" - ) + b = tl.load(x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg") x = _apply_activation_from_str(a, ACTIVATION).to(b.dtype) * b From 0f0784943b82668ecee8ce4737ee12ce1976845a Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:24:49 +0000 Subject: [PATCH 09/13] revert back the b dtype mul --- .../ops/triton/_triton_kernels/activation.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 63d36a9ceb..4f5396368f 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -121,7 +121,9 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( if EVEN_M_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg") + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( + tl.float32 + ) else: x_mask = (x_offs_m < M)[:, None] & (x_offs_n < N)[None, :] a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to( @@ -130,8 +132,9 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( # a and b can share the same mask b = tl.load( x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" - ) - x = _apply_activation_from_str(a, ACTIVATION).to(b.dtype) * b + ).to(tl.float32) + + x = _apply_activation_from_str(a, ACTIVATION) * b if DO_QUANT: out_tensor, bs_e8m0 = _mxfp4_quant_op( x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE @@ -244,14 +247,18 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( if EVEN_N: a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) - b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg") + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( + tl.float32 + ) else: x_mask = x_offs_n < N a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to(tl.float32) # a and b can share the same mask - b = tl.load(x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg") + b = tl.load( + x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" + ).to(tl.float32) - x = _apply_activation_from_str(a, ACTIVATION).to(b.dtype) * b + x = _apply_activation_from_str(a, ACTIVATION) * b out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n From 321bcf586bc96132553c33d60d69ce0f7815f7f3 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 05:04:33 +0000 Subject: [PATCH 10/13] Remove DO_QUANT parameter from activation functions and update related kernel logic to ensure quantization is always applied. Adjust test cases to reflect the change in data type parameterization. --- .../ops/triton/_triton_kernels/activation.py | 115 ++++++++---------- aiter/ops/triton/activation.py | 1 - op_tests/triton_tests/test_activation.py | 4 +- 3 files changed, 52 insertions(+), 68 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 4f5396368f..5749257cf1 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -100,7 +100,6 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( scaleM_pad: tl.constexpr, scaleN_pad: tl.constexpr, SHUFFLE: tl.constexpr, - DO_QUANT: tl.constexpr, ): pid_m = tl.program_id(0) start_n = tl.program_id(1) * NUM_ITER @@ -135,74 +134,60 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( ).to(tl.float32) x = _apply_activation_from_str(a, ACTIVATION) * b - if DO_QUANT: - out_tensor, bs_e8m0 = _mxfp4_quant_op( - x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE - ) + out_tensor, bs_e8m0 = _mxfp4_quant_op( + x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) - out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - out_offs_n = pid_n * BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2) - out_offs = ( - out_offs_m[:, None] * stride_x_fp4_m - + out_offs_n[None, :] * stride_x_fp4_n - ) + out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + out_offs_n = pid_n * BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2) + out_offs = ( + out_offs_m[:, None] * stride_x_fp4_m + + out_offs_n[None, :] * stride_x_fp4_n + ) - if EVEN_M_N: - tl.store(x_out_ptr + out_offs, out_tensor) - else: - out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] - tl.store(x_out_ptr + out_offs, out_tensor, mask=out_mask) - - bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) - if SHUFFLE: - bs_offs_0 = bs_offs_m[:, None] // 32 - bs_offs_1 = bs_offs_m[:, None] % 32 - bs_offs_2 = bs_offs_1 % 16 - bs_offs_1 = bs_offs_1 // 16 - bs_offs_3 = bs_offs_n[None, :] // 8 - bs_offs_4 = bs_offs_n[None, :] % 8 - bs_offs_5 = bs_offs_4 % 4 - bs_offs_4 = bs_offs_4 // 4 - bs_offs = ( - bs_offs_1 - + bs_offs_4 * 2 - + bs_offs_2 * 2 * 2 - + bs_offs_5 * 2 * 2 * 16 - + bs_offs_3 * 2 * 2 * 16 * 4 - + bs_offs_0 * 2 * 16 * scaleN - ) - bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] - bs_mask = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[ - None, : - ] - bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127) - else: - bs_offs = ( - bs_offs_m[:, None] * stride_bs_m + bs_offs_n[None, :] * stride_bs_n - ) - bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] - if EVEN_M_N: - tl.store(bs_ptr + bs_offs, bs_e8m0) - else: - tl.store( - bs_ptr + bs_offs, - bs_e8m0, - mask=bs_mask, - ) + if EVEN_M_N: + tl.store(x_out_ptr + out_offs, out_tensor) else: - out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_offs = ( - out_offs_m[:, None] * stride_x_fp4_m - + out_offs_n[None, :] * stride_x_fp4_n + out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] + tl.store(x_out_ptr + out_offs, out_tensor, mask=out_mask) + + bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * scaleN + ) + bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] + bs_mask = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[ + None, : + ] + bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * stride_bs_m + bs_offs_n[None, :] * stride_bs_n + ) + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :] + if EVEN_M_N: + tl.store(bs_ptr + bs_offs, bs_e8m0) + else: + tl.store( + bs_ptr + bs_offs, + bs_e8m0, + mask=bs_mask, ) - x_out = x.to(x_out_ptr.dtype.element_ty) - if EVEN_M_N: - tl.store(x_out_ptr + out_offs, x_out) - else: - out_mask = (out_offs_m < M)[:, None] & (out_offs_n < N)[None, :] - tl.store(x_out_ptr + out_offs, x_out, mask=out_mask) @triton.heuristics( diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index 0e21940682..f56404eeba 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -122,7 +122,6 @@ def act_mul_and_mxfp4_quant( num_warps=NUM_WARPS, waves_per_eu=0, num_stages=1, - DO_QUANT=True, ) return x_fp4, blockscale_e8m0 diff --git a/op_tests/triton_tests/test_activation.py b/op_tests/triton_tests/test_activation.py index be2e365d47..946f3e525e 100644 --- a/op_tests/triton_tests/test_activation.py +++ b/op_tests/triton_tests/test_activation.py @@ -163,7 +163,7 @@ def torch_act_mul_ref(x: torch.Tensor, activation: str) -> torch.Tensor: (1, 7), ], ) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"]) @pytest.mark.parametrize("use_out", [False, True]) def test_act_mul_no_quant(M, N_half, dtype, activation, use_out): @@ -186,7 +186,7 @@ def test_act_mul_no_quant(M, N_half, dtype, activation, use_out): (2, 3, 4, 32), ], ) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("activation", ["silu", "gelu"]) def test_act_mul_no_quant_higher_rank(shape, dtype, activation): """act_mul should handle any input rank by flattening/unflattening internally.""" From 9c90e5406437e84a7d7266af3c8d61ded48ca407 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 05:11:31 +0000 Subject: [PATCH 11/13] renamed x_out_ptr back to x_fp4_ptr --- aiter/ops/triton/_triton_kernels/activation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 5749257cf1..2d5e2c36dd 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -78,7 +78,7 @@ def _apply_activation_from_str(x, activation: tl.constexpr): @triton.jit def _act_mul_and_dynamic_mxfp4_quant_kernel( x_ptr, - x_out_ptr, + x_fp4_ptr, bs_ptr, stride_x_m_in, stride_x_n_in, @@ -141,15 +141,14 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( out_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) out_offs_n = pid_n * BLOCK_SIZE_N // 2 + tl.arange(0, BLOCK_SIZE_N // 2) out_offs = ( - out_offs_m[:, None] * stride_x_fp4_m - + out_offs_n[None, :] * stride_x_fp4_n + out_offs_m[:, None] * stride_x_fp4_m + out_offs_n[None, :] * stride_x_fp4_n ) if EVEN_M_N: - tl.store(x_out_ptr + out_offs, out_tensor) + tl.store(x_fp4_ptr + out_offs, out_tensor) else: out_mask = (out_offs_m < M)[:, None] & (out_offs_n < (N // 2))[None, :] - tl.store(x_out_ptr + out_offs, out_tensor, mask=out_mask) + tl.store(x_fp4_ptr + out_offs, out_tensor, mask=out_mask) bs_offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) From 3a0c78f93c6585f6c43984e32cbc931a26a7fc41 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 05:15:05 +0000 Subject: [PATCH 12/13] move fp8_dtype = aiter.dtypes.fp8 after imports --- aiter/ops/triton/activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index f56404eeba..10378a7a0e 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -2,14 +2,14 @@ import triton import torch import aiter - -fp8_dtype = aiter.dtypes.fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.activation import ( _act_mul_and_dynamic_mxfp4_quant_kernel, _act_mul_and_dynamic_fp8_group_quant_kernel, ) +fp8_dtype = aiter.dtypes.fp8 + _LOGGER = AiterTritonLogger() From b4a4117d7dee9dffb8ba54b5580d992b0a684691 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 05:28:36 +0000 Subject: [PATCH 13/13] Refactor act_mul function to remove unused dummy_bs tensor and update kernel call to use None for batch strides. Adjust stride handling in the _act_mul_and_dynamic_fp8_group_quant_kernel to ensure proper quantization logic. --- aiter/ops/triton/_triton_kernels/activation.py | 4 ++-- aiter/ops/triton/activation.py | 10 +++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 2d5e2c36dd..2aa9bb9ca0 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -222,8 +222,6 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( stride_x_n = tl.cast(stride_x_n_in, tl.int64) stride_x_fp8_m = tl.cast(stride_x_fp8_m_in, tl.int64) stride_x_fp8_n = tl.cast(stride_x_fp8_n_in, tl.int64) - stride_bs_m = tl.cast(stride_bs_m_in, tl.int64) - stride_bs_n = tl.cast(stride_bs_n_in, tl.int64) NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE x_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -248,6 +246,8 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n if DO_QUANT: + stride_bs_m = tl.cast(stride_bs_m_in, tl.int64) + stride_bs_n = tl.cast(stride_bs_n_in, tl.int64) x_fp8, x_bs = _fp8_quant_op( x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN ) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index 10378a7a0e..b7cc17fdcd 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -165,11 +165,6 @@ def act_mul( group_size = max(32, group_size) scaleN = triton.cdiv(N_half, group_size) - dummy_bs = torch.empty( - (1, 1), - dtype=torch.float32, - device=x.device, - ) DTYPE_MAX = 1.0 BLOCK_SIZE_N = group_size @@ -180,10 +175,11 @@ def act_mul( _act_mul_and_dynamic_fp8_group_quant_kernel[grid]( x_2d, out_2d, - dummy_bs, + None, *x_2d.stride(), *out_2d.stride(), - *dummy_bs.stride(), + 0, # stride_bs_m_in + 0, # stride_bs_n_in N=N_half, ACTIVATION=activation, scaleN=scaleN,