From bb3e0a770752c3ea04853f0717f9235dcf7902e4 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 1 Apr 2026 11:43:26 +0000 Subject: [PATCH 1/5] silu_mul_fused kernel --- .../_triton_kernels/fusions/fused_silu_mul.py | 47 +++++++ aiter/ops/triton/fusions/fused_silu_mul.py | 123 +++++++++++++++++ op_tests/op_benchmarks/triton/bench_moe.py | 127 ++++++++++++++++++ .../triton/utils/model_configs.json | 24 ++++ .../fusions/test_fused_silu_mul.py | 120 +++++++++++++++++ 5 files changed, 441 insertions(+) create mode 100644 aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py create mode 100644 aiter/ops/triton/fusions/fused_silu_mul.py create mode 100644 op_tests/triton_tests/fusions/test_fused_silu_mul.py diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py b/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py new file mode 100644 index 0000000000..69795e84b6 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + +from aiter.ops.triton._triton_kernels.activation import _silu_exp2 + + +@triton.jit +def fused_silu_mul_kernel( + inp_ptr, + out_ptr, + n_rows, + n_cols, + row_stride_in, + col_stride_in, + row_stride_out, + col_stride_out, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SiLU on the first half of the last dimension, multiply by the second half. + Each row has 2 * n_cols input elements; writes n_cols outputs. + 2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N). + """ + m_pid = tl.program_id(0) + n_pid = tl.program_id(1) + m_offs = tl.arange(0, BLOCK_M) + n_offs = tl.arange(0, BLOCK_N) + row_idx = m_pid * BLOCK_M + m_offs + col_idx = n_pid * BLOCK_N + n_offs + + row_in = row_idx * row_stride_in + row_out = row_idx * row_stride_out + + first_half_ptrs = inp_ptr + row_in[:, None] + col_idx[None, :] * col_stride_in + second_half_ptrs = inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in + out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out + + mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :] + a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32) + b = tl.load(second_half_ptrs, mask=mask, other=0.0).to(tl.float32) + silu_a = _silu_exp2(a) + o = (silu_a * b).to(out_ptr.dtype.element_ty) + tl.store(out_ptrs, o, mask=mask) diff --git a/aiter/ops/triton/fusions/fused_silu_mul.py b/aiter/ops/triton/fusions/fused_silu_mul.py new file mode 100644 index 0000000000..14f25bab36 --- /dev/null +++ b/aiter/ops/triton/fusions/fused_silu_mul.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional + +import torch +import triton + +from aiter.ops.triton._triton_kernels.fusions.fused_silu_mul import ( + fused_silu_mul_kernel, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + + +def _pick_block_n(d: int, n_rows: int) -> int: + """Tile size along the reduced last dim (cap 1024); at least 32 for vectorization. + + Tuned on ROCm for MoE TP4 locals (GLM-4.7 ``d=384``, Kimi-K2.5 ``d=512``) and wide + MoE activations: ``n_rows`` selects decode vs prefill N-tiling (see sweep in repo + history / ``bench_moe.py -bench_silu_mul``). + """ + n = max(d, 1) + # Kimi-K2.5 TP4 (d=512): prefill favors one 512-wide N tile; decode keeps 256×2. + if n == 512: + return 512 if n_rows > 4096 else 256 + # GLM-4.7 TP4 (d=384): wider decode rows use 256×2; larger batches favor 128×3 N tiles. + if n == 384: + return 256 if n_rows <= 128 else 128 + upper = min(n, 1024) + p = 1 + while p * 2 <= upper: + p *= 2 + return max(32, p) + + +def _pick_block_m(n_rows: int, block_n: int, d: int) -> int: + """Row tile size: latency shapes use wide M tiles; prefill uses tuned (d, n_rows) pairs.""" + if n_rows <= 64: + return min(32, max(4, triton.next_power_of_2(n_rows))) + if d == 384 and n_rows > 128: + return 32 if n_rows > 8192 else 8 + if d == 512 and n_rows > 4096: + return 8 + if d == 512 and 128 < n_rows <= 4096: + return 8 + if block_n >= 1024: + return 8 + if block_n >= 512: + return 8 + return 16 + + +def _pick_num_warps(n_rows: int, block_m: int, block_n: int) -> int: + """ROCm: 8 warps for tiny full-wavefront decode tiles; 2 warps for larger tiles.""" + if n_rows <= 128 and block_m >= 16 and block_n >= 128: + return 8 + return 2 + + +def fused_silu_mul_last_dim( + x: torch.Tensor, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Fused SiLU-and-mul along the last dimension (same pattern as MoE silu-fused GEMM). + + ``x`` must be contiguous with even ``size(-1)``. For last size ``2 * d``, the first + ``d`` lanes are passed through SiLU (``_silu_exp2``); the second ``d`` lanes are the + multipliers. Output shape matches ``x`` except ``out.size(-1) == d``. + + Returns: + ``out`` if provided, else a newly allocated tensor. + """ + assert x.is_cuda, "fused_silu_mul_last_dim requires a CUDA tensor" + assert x.is_contiguous(), "x must be contiguous" + last = x.size(-1) + assert last % 2 == 0, "last dimension must be even (2 * d)" + d = last // 2 + leading = x.shape[:-1] + n_rows = x.numel() // (2 * d) + if n_rows == 0: + return torch.empty(*leading, d, dtype=x.dtype, device=x.device) if out is None else out + + _LOGGER.info( + f"FUSED_SILU_MUL_LAST_DIM: x={tuple(x.shape)} last_half={d} rows={n_rows}" + ) + + if out is None: + out = torch.empty(*leading, d, dtype=x.dtype, device=x.device) + else: + assert out.is_contiguous(), "out must be contiguous" + assert out.shape == (*leading, d), "out shape must match x with last dim halved" + assert out.dtype == x.dtype and out.device == x.device + + row_stride_in = 2 * d + col_stride_in = 1 + row_stride_out = d + col_stride_out = 1 + + block_n = _pick_block_n(d, n_rows) + block_m = _pick_block_m(n_rows, block_n, d) + grid_m = triton.cdiv(n_rows, block_m) + grid_n = triton.cdiv(d, block_n) + num_warps = _pick_num_warps(n_rows, block_m, block_n) + + grid = (grid_m, grid_n) + fused_silu_mul_kernel[grid]( + x, + out, + n_rows, + d, + row_stride_in, + col_stride_in, + row_stride_out, + col_stride_out, + BLOCK_M=block_m, + BLOCK_N=block_n, + num_warps=num_warps, + waves_per_eu=0, + ) + return out diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index 3594278162..e8463040c2 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -5,6 +5,7 @@ from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu +from aiter.ops.triton.fusions.fused_silu_mul import fused_silu_mul_last_dim from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16 from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( get_model_configs, @@ -41,6 +42,46 @@ def model_benchmark_configs(args): return moe_configs +def silu_mul_benchmark_configs(args): + """Like ``model_benchmark_configs`` but supports multiple M via ``-silu_mul_M_list``.""" + configs = get_model_configs( + config_path=args.model_configs, + models="mixtral" if args.model is None else args.model, + ) + if not configs: + return [] + no_bench_stage2 = args.no_bench_stage2 + if args.silu_mul_M_list: + default_ms = [ + int(x.strip()) + for x in args.silu_mul_M_list.split(",") + if x.strip() + ] + elif args.M: + default_ms = [args.M] + else: + default_ms = [4096] + + moe_configs = [] + for model_name, config in configs.items(): + ms_model = config.get("silu_mul_benchmark_M", default_ms) + if not isinstance(ms_model, list): + ms_model = [int(ms_model)] + else: + ms_model = [int(m) for m in ms_model] + for M in ms_model: + N1 = config["intermediate_size"] + K1 = config["hidden_size"] + E = config["num_expert"] + top_k = config["top_k"] + moe_configs.append((model_name, M, N1, K1, E, top_k)) + if no_bench_stage2: + N2 = config["hidden_size"] + K2 = config["intermediate_size"] // 2 + moe_configs.append((model_name, M, N2, K2, E, top_k)) + return moe_configs + + def fused_moe( M, N, @@ -262,6 +303,68 @@ def bench_moe_gemm(M, N, K, E, top_k, metric, model=None): bench_moe_gemm.run(save_path="." if args.o else None, print_data=True) +def run_silu_mul_benchmark(args): + """Benchmark last-dim fused SiLU-and-mul (same activation as silu-fused MoE).""" + print_time = args.print_time + dtype = str_to_torch_dtype[args.dtype] + + if print_time: + line_names = ["Time_(ms)"] + line_vals = ["time"] + else: + line_names = ["Time_(ms)", "GFLOPS", "Bandwidth_(GB/s)"] + line_vals = ["time", "gflops", "bandwidth"] + + x_vals_list = silu_mul_benchmark_configs(args) + x_names = ["model", "M", "N", "K", "E", "top_k"] + + benchmark = triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg="metric", + line_vals=line_vals, + line_names=line_names, + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms / GFLOPS / GB/s", + plot_name=get_caller_name_no_ext() + "_silu_mul", + args={}, + ) + + @triton.testing.perf_report([benchmark]) + def bench_silu_mul(M, N, K, E, top_k, metric, model=None): + # Match MoE post-GEMM layout: (M * top_k, N); N must be even for gate/up pairs. + n_even = N if N % 2 == 0 else N - 1 + if n_even < 2: + return 0.0 + n_rows = M * top_k + d = n_even // 2 + x = torch.randn(n_rows, n_even, device="cuda", dtype=dtype) + out = torch.empty(n_rows, d, device="cuda", dtype=dtype) + + elem = torch.tensor([], dtype=dtype).element_size() + mem_read = n_rows * n_even * elem + mem_write = n_rows * d * elem + # Rough op count: SiLU + mul per output element + flops = float(n_rows * d * 8) + + fn = lambda: fused_silu_mul_last_dim(x, out) + + ms = triton.testing.do_bench(fn, warmup=25, rep=100) + bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 + gflops = flops / ms * 1e-6 + + if metric == "time": + return ms + elif metric == "gflops": + return gflops + elif metric == "bandwidth": + return bandwidth + else: + raise ValueError("Unknown metric: " + metric) + + bench_silu_mul.run(save_path="." if args.o else None, print_data=True) + + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark MoE GEMM", @@ -300,6 +403,19 @@ def parse_args(): parser.add_argument("-dtype", default="fp16") parser.add_argument("-fp8_type", default="e5m2fnuz") parser.add_argument("-silu_fused", action="store_true", default=False) + parser.add_argument( + "-bench_silu_mul", + action="store_true", + default=False, + help="Benchmark fused last-dim SiLU-and-mul only (uses model M, N, top_k).", + ) + parser.add_argument( + "-silu_mul_M_list", + type=str, + default=None, + help="Comma-separated token counts M for silu_mul bench (e.g. 4,8193,7238). " + "Row count is M * top_k. Implies multiple table rows when set.", + ) parser.add_argument( "-o", action="store_true", help="Write performance results to CSV file" ) @@ -310,6 +426,17 @@ def parse_args(): def main(): args = parse_args() + if args.bench_silu_mul: + if args.print_vgpr: + + def fun(): + return run_silu_mul_benchmark(args) + + print_vgpr(fun, get_caller_name_no_ext() + "_silu_mul") + return 0 + run_silu_mul_benchmark(args) + return 0 + if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") diff --git a/op_tests/op_benchmarks/triton/utils/model_configs.json b/op_tests/op_benchmarks/triton/utils/model_configs.json index f41efcd412..b3eff5edae 100644 --- a/op_tests/op_benchmarks/triton/utils/model_configs.json +++ b/op_tests/op_benchmarks/triton/utils/model_configs.json @@ -61,6 +61,30 @@ "num_expert": 256, "top_k": 4 } + }, + "glm47fp8": { + "tp4": { + "hidden_size": 5120, + "intermediate_size": 768, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 160, + "top_k": 8, + "silu_mul_benchmark_M": [4, 8193, 7238] + } + }, + "kimik25": { + "tp4": { + "hidden_size": 7168, + "intermediate_size": 1024, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 384, + "top_k": 8, + "silu_mul_benchmark_M": [4, 8193, 7238] + } } } diff --git a/op_tests/triton_tests/fusions/test_fused_silu_mul.py b/op_tests/triton_tests/fusions/test_fused_silu_mul.py new file mode 100644 index 0000000000..f35908fc38 --- /dev/null +++ b/op_tests/triton_tests/fusions/test_fused_silu_mul.py @@ -0,0 +1,120 @@ +import torch +import pytest + +from aiter.ops.triton.fusions.fused_silu_mul import fused_silu_mul_last_dim + +LOG2_E = 1.44269504089 + +# GLM-4.7-FP8 MoE (e.g. zai-org/GLM-4.7-FP8): moe_intermediate_size=1536, top_k=8. +# Column-parallel TP4: local d = 1536 // 4 = 384, fused silu-mul input last dim = 768. +_GLM47_TP4_LAST = 768 +_GLM47_TOP_K = 8 + +# Kimi-K2.5 MoE (moonshotai/Kimi-K2.5 text_config): moe_intermediate_size=2048, top_k=8. +# TP4: local d = 2048 // 4 = 512, last dim = 1024. +_KIMI_K25_TP4_LAST = 1024 +_KIMI_K25_TOP_K = 8 + + +def silu_exp2_ref(t: torch.Tensor) -> torch.Tensor: + """Match ``_silu_exp2`` in Triton (same as MoE silu-fused path).""" + x = t.float() + return x / (1.0 + torch.exp2(-(x * LOG2_E))) + + +def torch_silu_mul_last_dim_ref(x: torch.Tensor) -> torch.Tensor: + d = x.size(-1) // 2 + a, b = x[..., :d], x[..., d:] + return (silu_exp2_ref(a) * b).to(x.dtype) + + +@pytest.mark.parametrize( + "shape", + [ + (4, 64), + (128, 256), + (31, 500), + (2, 16, 128), + (1, 3, 7, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("use_explicit_out", [False, True]) +def test_fused_silu_mul_last_dim(shape, dtype, use_explicit_out): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + x = torch.randn(shape, dtype=dtype, device="cuda") + ref = torch_silu_mul_last_dim_ref(x) + if use_explicit_out: + out = torch.empty_like(ref) + fused_silu_mul_last_dim(x, out) + else: + out = fused_silu_mul_last_dim(x) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + + +def test_fused_silu_mul_requires_even_last_dim(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + x = torch.randn(2, 3, device="cuda") + with pytest.raises(AssertionError, match="even"): + fused_silu_mul_last_dim(x) + + +@pytest.mark.parametrize( + "n_rows,last_dim", + [ + # Decode M=4 → rows M * top_k + pytest.param(4 * _GLM47_TOP_K, _GLM47_TP4_LAST, id="glm47_tp4_decode4"), + pytest.param(4 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_decode4"), + # Medium prefill / batched decode + pytest.param(256 * _GLM47_TOP_K, _GLM47_TP4_LAST, id="glm47_tp4_rows256x8"), + pytest.param(256 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_rows256x8"), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_silu_mul_tp4_moe_shapes(n_rows, last_dim, dtype): + """MoE fused silu×mul tensor as (tokens * top_k, 2 * local_d) under TP4.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + shape = (n_rows, last_dim) + x = torch.randn(shape, dtype=dtype, device="cuda") + ref = torch_silu_mul_last_dim_ref(x) + out = fused_silu_mul_last_dim(x) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "n_rows,last_dim", + [ + pytest.param( + (8190 + 3) * _GLM47_TOP_K, + _GLM47_TP4_LAST, + id="glm47_tp4_pref8190_dec3", + ), + pytest.param( + (7235 + 3) * _GLM47_TOP_K, + _GLM47_TP4_LAST, + id="glm47_tp4_pref7235_dec3", + ), + pytest.param( + (8190 + 3) * _KIMI_K25_TOP_K, + _KIMI_K25_TP4_LAST, + id="kimi_k25_tp4_pref8190_dec3", + ), + pytest.param( + (7235 + 3) * _KIMI_K25_TOP_K, + _KIMI_K25_TP4_LAST, + id="kimi_k25_tp4_pref7235_dec3", + ), + ], +) +def test_fused_silu_mul_tp4_prefill_bf16(n_rows, last_dim): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + dtype = torch.bfloat16 + shape = (n_rows, last_dim) + x = torch.randn(shape, dtype=dtype, device="cuda") + ref = torch_silu_mul_last_dim_ref(x) + out = fused_silu_mul_last_dim(x) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) From 5e04c56b31d1eee3531a6742fe7206ea124b39b1 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 1 Apr 2026 18:01:48 +0300 Subject: [PATCH 2/5] Update op_tests/op_benchmarks/triton/bench_moe.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- op_tests/op_benchmarks/triton/bench_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index e8463040c2..6f04c28ae7 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -347,7 +347,8 @@ def bench_silu_mul(M, N, K, E, top_k, metric, model=None): # Rough op count: SiLU + mul per output element flops = float(n_rows * d * 8) - fn = lambda: fused_silu_mul_last_dim(x, out) + def fn(): + return fused_silu_mul_last_dim(x, out) ms = triton.testing.do_bench(fn, warmup=25, rep=100) bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 From 4d8bcde4913bcd578d12e9953bd0c5f4b0be2982 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 1 Apr 2026 15:11:24 +0000 Subject: [PATCH 3/5] code refine and linter --- .../_triton_kernels/fusions/fused_silu_mul.py | 8 ++++--- aiter/ops/triton/fusions/fused_silu_mul.py | 6 ++++- op_tests/op_benchmarks/triton/bench_moe.py | 4 +--- .../triton/utils/model_configs.json | 22 ++++++++++++++++++- .../fusions/test_fused_silu_mul.py | 8 +++++-- 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py b/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py index 69795e84b6..6d67f5baba 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py @@ -36,12 +36,14 @@ def fused_silu_mul_kernel( row_out = row_idx * row_stride_out first_half_ptrs = inp_ptr + row_in[:, None] + col_idx[None, :] * col_stride_in - second_half_ptrs = inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in + second_half_ptrs = ( + inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in + ) out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :] a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32) - b = tl.load(second_half_ptrs, mask=mask, other=0.0).to(tl.float32) - silu_a = _silu_exp2(a) + silu_a = _silu_exp2(a).to(inp_ptr.dtype.element_ty) + b = tl.load(second_half_ptrs, mask=mask, other=0.0) o = (silu_a * b).to(out_ptr.dtype.element_ty) tl.store(out_ptrs, o, mask=mask) diff --git a/aiter/ops/triton/fusions/fused_silu_mul.py b/aiter/ops/triton/fusions/fused_silu_mul.py index 14f25bab36..9b86aebe09 100644 --- a/aiter/ops/triton/fusions/fused_silu_mul.py +++ b/aiter/ops/triton/fusions/fused_silu_mul.py @@ -81,7 +81,11 @@ def fused_silu_mul_last_dim( leading = x.shape[:-1] n_rows = x.numel() // (2 * d) if n_rows == 0: - return torch.empty(*leading, d, dtype=x.dtype, device=x.device) if out is None else out + return ( + torch.empty(*leading, d, dtype=x.dtype, device=x.device) + if out is None + else out + ) _LOGGER.info( f"FUSED_SILU_MUL_LAST_DIM: x={tuple(x.shape)} last_half={d} rows={n_rows}" diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index e8463040c2..d5baa7303a 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -53,9 +53,7 @@ def silu_mul_benchmark_configs(args): no_bench_stage2 = args.no_bench_stage2 if args.silu_mul_M_list: default_ms = [ - int(x.strip()) - for x in args.silu_mul_M_list.split(",") - if x.strip() + int(x.strip()) for x in args.silu_mul_M_list.split(",") if x.strip() ] elif args.M: default_ms = [args.M] diff --git a/op_tests/op_benchmarks/triton/utils/model_configs.json b/op_tests/op_benchmarks/triton/utils/model_configs.json index b3eff5edae..d8371fc6c5 100644 --- a/op_tests/op_benchmarks/triton/utils/model_configs.json +++ b/op_tests/op_benchmarks/triton/utils/model_configs.json @@ -72,6 +72,16 @@ "num_expert": 160, "top_k": 8, "silu_mul_benchmark_M": [4, 8193, 7238] + }, + "ep4": { + "hidden_size": 5120, + "intermediate_size": 3072, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 40, + "top_k": 2, + "silu_mul_benchmark_M": [4, 8193, 7238] } }, "kimik25": { @@ -83,7 +93,17 @@ "vocab_size": 163840, "num_expert": 384, "top_k": 8, - "silu_mul_benchmark_M": [4, 8193, 7238] + "silu_mul_benchmark_M": [4] + }, + "ep4": { + "hidden_size": 7168, + "intermediate_size": 4096, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 96, + "top_k": 2, + "silu_mul_benchmark_M": [4] } } } diff --git a/op_tests/triton_tests/fusions/test_fused_silu_mul.py b/op_tests/triton_tests/fusions/test_fused_silu_mul.py index f35908fc38..7e95e27c7d 100644 --- a/op_tests/triton_tests/fusions/test_fused_silu_mul.py +++ b/op_tests/triton_tests/fusions/test_fused_silu_mul.py @@ -66,10 +66,14 @@ def test_fused_silu_mul_requires_even_last_dim(): [ # Decode M=4 → rows M * top_k pytest.param(4 * _GLM47_TOP_K, _GLM47_TP4_LAST, id="glm47_tp4_decode4"), - pytest.param(4 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_decode4"), + pytest.param( + 4 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_decode4" + ), # Medium prefill / batched decode pytest.param(256 * _GLM47_TOP_K, _GLM47_TP4_LAST, id="glm47_tp4_rows256x8"), - pytest.param(256 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_rows256x8"), + pytest.param( + 256 * _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST, id="kimi_k25_tp4_rows256x8" + ), ], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) From 7fb268aa1499a10b7aff8402733c66a9a5dc1731 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 05:45:26 +0000 Subject: [PATCH 4/5] Implement fused SiLU and multiplication kernel in Triton, replacing previous implementation. Update activation module to include new fused_silu_mul function. Adjust benchmarks and tests to utilize the new function. Remove deprecated fused_silu_mul files. --- .../ops/triton/_triton_kernels/activation.py | 42 ++++++ .../_triton_kernels/fusions/fused_silu_mul.py | 49 ------- aiter/ops/triton/activation.py | 116 +++++++++++++++- aiter/ops/triton/fusions/fused_silu_mul.py | 127 ------------------ op_tests/op_benchmarks/triton/bench_moe.py | 4 +- .../fusions/test_fused_silu_mul.py | 16 +-- 6 files changed, 165 insertions(+), 189 deletions(-) delete mode 100644 aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py delete mode 100644 aiter/ops/triton/fusions/fused_silu_mul.py diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 57a7c32644..0dd7241560 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -14,6 +14,48 @@ def _silu(x): return _silu_exp2(x) +@triton.jit +def fused_silu_mul_kernel( + inp_ptr, + out_ptr, + n_rows, + n_cols, + row_stride_in, + col_stride_in, + row_stride_out, + col_stride_out, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SiLU on the first half of the last dimension, multiply by the second half. + Each row has 2 * n_cols input elements; writes n_cols outputs. + 2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N). + """ + m_pid = tl.program_id(0) + n_pid = tl.program_id(1) + m_offs = tl.arange(0, BLOCK_M) + n_offs = tl.arange(0, BLOCK_N) + row_idx = m_pid * BLOCK_M + m_offs + col_idx = n_pid * BLOCK_N + n_offs + + row_in = row_idx * row_stride_in + row_out = row_idx * row_stride_out + + first_half_ptrs = inp_ptr + row_in[:, None] + col_idx[None, :] * col_stride_in + second_half_ptrs = ( + inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in + ) + out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out + + mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :] + a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32) + silu_a = _silu_exp2(a).to(inp_ptr.dtype.element_ty) + b = tl.load(second_half_ptrs, mask=mask, other=0.0) + o = (silu_a * b).to(out_ptr.dtype.element_ty) + tl.store(out_ptrs, o, mask=mask) + + @triton.jit def _tanh(x): return 2 * tl.sigmoid(2 * x) - 1 diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py b/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py deleted file mode 100644 index 6d67f5baba..0000000000 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_silu_mul.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -import triton -import triton.language as tl - -from aiter.ops.triton._triton_kernels.activation import _silu_exp2 - - -@triton.jit -def fused_silu_mul_kernel( - inp_ptr, - out_ptr, - n_rows, - n_cols, - row_stride_in, - col_stride_in, - row_stride_out, - col_stride_out, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """ - SiLU on the first half of the last dimension, multiply by the second half. - Each row has 2 * n_cols input elements; writes n_cols outputs. - 2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N). - """ - m_pid = tl.program_id(0) - n_pid = tl.program_id(1) - m_offs = tl.arange(0, BLOCK_M) - n_offs = tl.arange(0, BLOCK_N) - row_idx = m_pid * BLOCK_M + m_offs - col_idx = n_pid * BLOCK_N + n_offs - - row_in = row_idx * row_stride_in - row_out = row_idx * row_stride_out - - first_half_ptrs = inp_ptr + row_in[:, None] + col_idx[None, :] * col_stride_in - second_half_ptrs = ( - inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in - ) - out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out - - mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :] - a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32) - silu_a = _silu_exp2(a).to(inp_ptr.dtype.element_ty) - b = tl.load(second_half_ptrs, mask=mask, other=0.0) - o = (silu_a * b).to(out_ptr.dtype.element_ty) - tl.store(out_ptrs, o, mask=mask) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index b52cf465eb..d433136fda 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -1,16 +1,17 @@ -from typing import Literal +from typing import Literal, Optional import triton import triton.language as tl import torch import aiter - -fp8_dtype = aiter.dtypes.fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.activation import ( _act_mul_and_dynamic_mxfp4_quant_kernel, _act_mul_and_dynamic_fp8_group_quant_kernel, + fused_silu_mul_kernel, ) +fp8_dtype = aiter.dtypes.fp8 + _LOGGER = AiterTritonLogger() @@ -198,3 +199,112 @@ def act_mul_and_fp8_group_quant( ) return x_fp8, out_bs + + +def fused_silu_mul( + x: torch.Tensor, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Fused SiLU-and-mul along the last dimension (same pattern as MoE silu-fused GEMM). + + ``x`` must be contiguous with even ``size(-1)``. For last size ``2 * d``, the first + ``d`` lanes are passed through SiLU (``_silu_exp2``); the second ``d`` lanes are the + multipliers. Output shape matches ``x`` except ``out.size(-1) == d``. + + Returns: + ``out`` if provided, else a newly allocated tensor. + """ + + def _pick_block_n(d: int, n_rows: int) -> int: + """Tile size along the reduced last dim (cap 1024); at least 32 for vectorization. + + Tuned on ROCm for MoE TP4 locals (GLM-4.7 ``d=384``, Kimi-K2.5 ``d=512``) and wide + MoE activations: ``n_rows`` selects decode vs prefill N-tiling (see sweep in repo + history / ``bench_moe.py -bench_silu_mul``). + """ + n = max(d, 1) + # Kimi-K2.5 TP4 (d=512): prefill favors one 512-wide N tile; decode keeps 256×2. + if n == 512: + return 512 if n_rows > 4096 else 256 + # GLM-4.7 TP4 (d=384): wider decode rows use 256×2; larger batches favor 128×3 N tiles. + if n == 384: + return 256 if n_rows <= 128 else 128 + upper = min(n, 1024) + p = 1 + while p * 2 <= upper: + p *= 2 + return max(32, p) + + def _pick_block_m(n_rows: int, block_n: int, d: int) -> int: + """Row tile size: latency shapes use wide M tiles; prefill uses tuned (d, n_rows) pairs.""" + if n_rows <= 64: + return min(32, max(4, triton.next_power_of_2(n_rows))) + if d == 384 and n_rows > 128: + return 32 if n_rows > 8192 else 8 + if d == 512 and n_rows > 4096: + return 8 + if d == 512 and 128 < n_rows <= 4096: + return 8 + if block_n >= 1024: + return 8 + if block_n >= 512: + return 8 + return 16 + + def _pick_num_warps(n_rows: int, block_m: int, block_n: int) -> int: + """ROCm: 8 warps for tiny full-wavefront decode tiles; 2 warps for larger tiles.""" + if n_rows <= 128 and block_m >= 16 and block_n >= 128: + return 8 + return 2 + + assert x.is_cuda, "fused_silu_mul requires a CUDA tensor" + assert x.is_contiguous(), "x must be contiguous" + last = x.size(-1) + assert last % 2 == 0, "last dimension must be even (2 * d)" + d = last // 2 + leading = x.shape[:-1] + n_rows = x.numel() // (2 * d) + if n_rows == 0: + return ( + torch.empty(*leading, d, dtype=x.dtype, device=x.device) + if out is None + else out + ) + + _LOGGER.info(f"fused_silu_mul: x={tuple(x.shape)} last_half={d} rows={n_rows}") + + if out is None: + out = torch.empty(*leading, d, dtype=x.dtype, device=x.device) + else: + assert out.is_contiguous(), "out must be contiguous" + assert out.shape == (*leading, d), "out shape must match x with last dim halved" + assert out.dtype == x.dtype and out.device == x.device + + row_stride_in = 2 * d + col_stride_in = 1 + row_stride_out = d + col_stride_out = 1 + + block_n = _pick_block_n(d, n_rows) + block_m = _pick_block_m(n_rows, block_n, d) + grid_m = triton.cdiv(n_rows, block_m) + grid_n = triton.cdiv(d, block_n) + num_warps = _pick_num_warps(n_rows, block_m, block_n) + + grid = (grid_m, grid_n) + fused_silu_mul_kernel[grid]( + x, + out, + n_rows, + d, + row_stride_in, + col_stride_in, + row_stride_out, + col_stride_out, + BLOCK_M=block_m, + BLOCK_N=block_n, + num_warps=num_warps, + waves_per_eu=0, + ) + return out diff --git a/aiter/ops/triton/fusions/fused_silu_mul.py b/aiter/ops/triton/fusions/fused_silu_mul.py deleted file mode 100644 index 9b86aebe09..0000000000 --- a/aiter/ops/triton/fusions/fused_silu_mul.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -from typing import Optional - -import torch -import triton - -from aiter.ops.triton._triton_kernels.fusions.fused_silu_mul import ( - fused_silu_mul_kernel, -) -from aiter.ops.triton.utils.logger import AiterTritonLogger - -_LOGGER = AiterTritonLogger() - - -def _pick_block_n(d: int, n_rows: int) -> int: - """Tile size along the reduced last dim (cap 1024); at least 32 for vectorization. - - Tuned on ROCm for MoE TP4 locals (GLM-4.7 ``d=384``, Kimi-K2.5 ``d=512``) and wide - MoE activations: ``n_rows`` selects decode vs prefill N-tiling (see sweep in repo - history / ``bench_moe.py -bench_silu_mul``). - """ - n = max(d, 1) - # Kimi-K2.5 TP4 (d=512): prefill favors one 512-wide N tile; decode keeps 256×2. - if n == 512: - return 512 if n_rows > 4096 else 256 - # GLM-4.7 TP4 (d=384): wider decode rows use 256×2; larger batches favor 128×3 N tiles. - if n == 384: - return 256 if n_rows <= 128 else 128 - upper = min(n, 1024) - p = 1 - while p * 2 <= upper: - p *= 2 - return max(32, p) - - -def _pick_block_m(n_rows: int, block_n: int, d: int) -> int: - """Row tile size: latency shapes use wide M tiles; prefill uses tuned (d, n_rows) pairs.""" - if n_rows <= 64: - return min(32, max(4, triton.next_power_of_2(n_rows))) - if d == 384 and n_rows > 128: - return 32 if n_rows > 8192 else 8 - if d == 512 and n_rows > 4096: - return 8 - if d == 512 and 128 < n_rows <= 4096: - return 8 - if block_n >= 1024: - return 8 - if block_n >= 512: - return 8 - return 16 - - -def _pick_num_warps(n_rows: int, block_m: int, block_n: int) -> int: - """ROCm: 8 warps for tiny full-wavefront decode tiles; 2 warps for larger tiles.""" - if n_rows <= 128 and block_m >= 16 and block_n >= 128: - return 8 - return 2 - - -def fused_silu_mul_last_dim( - x: torch.Tensor, - out: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Fused SiLU-and-mul along the last dimension (same pattern as MoE silu-fused GEMM). - - ``x`` must be contiguous with even ``size(-1)``. For last size ``2 * d``, the first - ``d`` lanes are passed through SiLU (``_silu_exp2``); the second ``d`` lanes are the - multipliers. Output shape matches ``x`` except ``out.size(-1) == d``. - - Returns: - ``out`` if provided, else a newly allocated tensor. - """ - assert x.is_cuda, "fused_silu_mul_last_dim requires a CUDA tensor" - assert x.is_contiguous(), "x must be contiguous" - last = x.size(-1) - assert last % 2 == 0, "last dimension must be even (2 * d)" - d = last // 2 - leading = x.shape[:-1] - n_rows = x.numel() // (2 * d) - if n_rows == 0: - return ( - torch.empty(*leading, d, dtype=x.dtype, device=x.device) - if out is None - else out - ) - - _LOGGER.info( - f"FUSED_SILU_MUL_LAST_DIM: x={tuple(x.shape)} last_half={d} rows={n_rows}" - ) - - if out is None: - out = torch.empty(*leading, d, dtype=x.dtype, device=x.device) - else: - assert out.is_contiguous(), "out must be contiguous" - assert out.shape == (*leading, d), "out shape must match x with last dim halved" - assert out.dtype == x.dtype and out.device == x.device - - row_stride_in = 2 * d - col_stride_in = 1 - row_stride_out = d - col_stride_out = 1 - - block_n = _pick_block_n(d, n_rows) - block_m = _pick_block_m(n_rows, block_n, d) - grid_m = triton.cdiv(n_rows, block_m) - grid_n = triton.cdiv(d, block_n) - num_warps = _pick_num_warps(n_rows, block_m, block_n) - - grid = (grid_m, grid_n) - fused_silu_mul_kernel[grid]( - x, - out, - n_rows, - d, - row_stride_in, - col_stride_in, - row_stride_out, - col_stride_out, - BLOCK_M=block_m, - BLOCK_N=block_n, - num_warps=num_warps, - waves_per_eu=0, - ) - return out diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index 8d47b2f62d..52cdb0d751 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -5,7 +5,7 @@ from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu -from aiter.ops.triton.fusions.fused_silu_mul import fused_silu_mul_last_dim +from aiter.ops.triton.activation import fused_silu_mul from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16 from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( get_model_configs, @@ -346,7 +346,7 @@ def bench_silu_mul(M, N, K, E, top_k, metric, model=None): flops = float(n_rows * d * 8) def fn(): - return fused_silu_mul_last_dim(x, out) + return fused_silu_mul(x, out) ms = triton.testing.do_bench(fn, warmup=25, rep=100) bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 diff --git a/op_tests/triton_tests/fusions/test_fused_silu_mul.py b/op_tests/triton_tests/fusions/test_fused_silu_mul.py index 7e95e27c7d..cc5e9b7b03 100644 --- a/op_tests/triton_tests/fusions/test_fused_silu_mul.py +++ b/op_tests/triton_tests/fusions/test_fused_silu_mul.py @@ -1,7 +1,7 @@ import torch import pytest -from aiter.ops.triton.fusions.fused_silu_mul import fused_silu_mul_last_dim +from aiter.ops.triton.activation import fused_silu_mul LOG2_E = 1.44269504089 @@ -38,18 +38,18 @@ def torch_silu_mul_last_dim_ref(x: torch.Tensor) -> torch.Tensor: (1, 3, 7, 32), ], ) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("use_explicit_out", [False, True]) -def test_fused_silu_mul_last_dim(shape, dtype, use_explicit_out): +def test_fused_silu_mul(shape, dtype, use_explicit_out): if not torch.cuda.is_available(): pytest.skip("CUDA required") x = torch.randn(shape, dtype=dtype, device="cuda") ref = torch_silu_mul_last_dim_ref(x) if use_explicit_out: out = torch.empty_like(ref) - fused_silu_mul_last_dim(x, out) + fused_silu_mul(x, out) else: - out = fused_silu_mul_last_dim(x) + out = fused_silu_mul(x) torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) @@ -58,7 +58,7 @@ def test_fused_silu_mul_requires_even_last_dim(): pytest.skip("CUDA required") x = torch.randn(2, 3, device="cuda") with pytest.raises(AssertionError, match="even"): - fused_silu_mul_last_dim(x) + fused_silu_mul(x) @pytest.mark.parametrize( @@ -84,7 +84,7 @@ def test_fused_silu_mul_tp4_moe_shapes(n_rows, last_dim, dtype): shape = (n_rows, last_dim) x = torch.randn(shape, dtype=dtype, device="cuda") ref = torch_silu_mul_last_dim_ref(x) - out = fused_silu_mul_last_dim(x) + out = fused_silu_mul(x) torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) @@ -120,5 +120,5 @@ def test_fused_silu_mul_tp4_prefill_bf16(n_rows, last_dim): shape = (n_rows, last_dim) x = torch.randn(shape, dtype=dtype, device="cuda") ref = torch_silu_mul_last_dim_ref(x) - out = fused_silu_mul_last_dim(x) + out = fused_silu_mul(x) torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) From 97ade7e0ca2588b4112b4a9a77d7a62087328313 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 15 May 2026 08:47:35 +0300 Subject: [PATCH 5/5] Update aiter/ops/triton/activation.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- aiter/ops/triton/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index d433136fda..a757a37550 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -1,6 +1,5 @@ from typing import Literal, Optional import triton -import triton.language as tl import torch import aiter from aiter.ops.triton.utils.logger import AiterTritonLogger