@triton.jit
def fused_silu_mul_kernel(
    inp_ptr,
    out_ptr,
    n_rows,
    n_cols,
    row_stride_in,
    col_stride_in,
    row_stride_out,
    col_stride_out,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    SiLU on the first half of the last dimension, multiply by the second half.

    Each row has 2 * n_cols input elements; writes n_cols outputs.
    2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N).

    Args:
        inp_ptr: Base pointer of the input; element (r, c) is read at
            ``r * row_stride_in + c * col_stride_in``.
        out_ptr: Base pointer of the output; element (r, c) is written at
            ``r * row_stride_out + c * col_stride_out``.
        n_rows: Number of rows to process.
        n_cols: Number of output columns (half of the input row width).
        row_stride_in, col_stride_in: Input strides, in elements.
        row_stride_out, col_stride_out: Output strides, in elements.
        BLOCK_M, BLOCK_N: Compile-time tile sizes along rows / columns.

    SiLU is evaluated in float32, cast back to the input element type, and
    then multiplied by the second half; the product is cast to the output
    element type before the store.
    """
    m_pid = tl.program_id(0)
    n_pid = tl.program_id(1)
    row_idx = m_pid * BLOCK_M + tl.arange(0, BLOCK_M)
    col_idx = n_pid * BLOCK_N + tl.arange(0, BLOCK_N)

    # Promote the row index to int64 before scaling by the row stride.
    # With 32-bit indices the product wraps once n_rows * row_stride reaches
    # 2**31 elements, which would silently address the wrong memory.
    row_idx_wide = row_idx.to(tl.int64)
    row_in = row_idx_wide * row_stride_in
    row_out = row_idx_wide * row_stride_out

    # First half: columns [0, n_cols); second half: columns [n_cols, 2 * n_cols).
    first_half_ptrs = inp_ptr + row_in[:, None] + col_idx[None, :] * col_stride_in
    second_half_ptrs = (
        inp_ptr + row_in[:, None] + (n_cols + col_idx)[None, :] * col_stride_in
    )
    out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out

    # Lanes past the row/column bounds load 0.0 and are never stored.
    mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :]
    a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32)
    silu_a = _silu_exp2(a).to(inp_ptr.dtype.element_ty)
    b = tl.load(second_half_ptrs, mask=mask, other=0.0)
    o = (silu_a * b).to(out_ptr.dtype.element_ty)
    tl.store(out_ptrs, o, mask=mask)
def fused_silu_mul(
    x: torch.Tensor,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused SiLU-and-mul along the last dimension (same pattern as MoE silu-fused GEMM).

    ``x`` must be a contiguous CUDA tensor with even ``size(-1)``. For a last size of
    ``2 * d``, the first ``d`` lanes are passed through SiLU (``_silu_exp2``) and the
    second ``d`` lanes are the multipliers. The output shape matches ``x`` except that
    ``out.size(-1) == d``.

    Args:
        x: Input of shape ``(..., 2 * d)``.
        out: Optional preallocated, contiguous destination of shape ``(..., d)`` with
            the same dtype and device as ``x``.

    Returns:
        ``out`` if provided, else a newly allocated tensor. When ``x`` has no elements
        (including ``size(-1) == 0``) no kernel is launched and a provided ``out`` is
        returned as-is, without validation.

    Raises:
        AssertionError: If ``x`` is not a contiguous CUDA tensor with an even last
            dimension, or (for non-empty ``x``) if ``out`` is not contiguous or does
            not match the expected shape, dtype, or device.
    """

    def _pick_block_n(d: int, n_rows: int) -> int:
        """Tile size along the halved last dim (cap 1024); at least 32 for vectorization.

        Tuned on ROCm for MoE TP4 locals (GLM-4.7 ``d=384``, Kimi-K2.5 ``d=512``) and wide
        MoE activations: ``n_rows`` selects decode vs prefill N-tiling (see sweep in repo
        history / ``bench_moe.py -bench_silu_mul``).
        """
        n = max(d, 1)
        # Kimi-K2.5 TP4 (d=512): prefill favors one 512-wide N tile; decode keeps 256×2.
        if n == 512:
            return 512 if n_rows > 4096 else 256
        # GLM-4.7 TP4 (d=384): wider decode rows use 256×2; larger batches favor 128×3 N tiles.
        if n == 384:
            return 256 if n_rows <= 128 else 128
        upper = min(n, 1024)
        # Largest power of two that does not exceed ``upper`` (upper >= 1 here).
        return max(32, 1 << (upper.bit_length() - 1))

    def _pick_block_m(n_rows: int, block_n: int, d: int) -> int:
        """Row tile size: latency shapes use wide M tiles; prefill uses tuned (d, n_rows) pairs."""
        if n_rows <= 64:
            return min(32, max(4, triton.next_power_of_2(n_rows)))
        if d == 384 and n_rows > 128:
            return 32 if n_rows > 8192 else 8
        # d == 512 above 128 rows, and any tile at least 512 columns wide, use 8 rows.
        if (d == 512 and n_rows > 128) or block_n >= 512:
            return 8
        return 16

    def _pick_num_warps(n_rows: int, block_m: int, block_n: int) -> int:
        """ROCm: 8 warps for tiny full-wavefront decode tiles; 2 warps for larger tiles."""
        if n_rows <= 128 and block_m >= 16 and block_n >= 128:
            return 8
        return 2

    # Validation stays as asserts: existing tests expect AssertionError here.
    assert x.is_cuda, "fused_silu_mul requires a CUDA tensor"
    assert x.is_contiguous(), "x must be contiguous"
    last = x.size(-1)
    assert last % 2 == 0, "last dimension must be even (2 * d)"
    d = last // 2
    leading = x.shape[:-1]
    # ``last`` may be 0 (an even size); dividing by it would raise ZeroDivisionError,
    # so treat a zero-width last dimension as an empty input.
    n_rows = x.numel() // last if last else 0
    if n_rows == 0:
        # Nothing to compute: skip the launch (a zero-sized grid is not launched).
        return (
            torch.empty(*leading, d, dtype=x.dtype, device=x.device)
            if out is None
            else out
        )

    _LOGGER.info(f"fused_silu_mul: x={tuple(x.shape)} last_half={d} rows={n_rows}")

    if out is None:
        out = torch.empty(*leading, d, dtype=x.dtype, device=x.device)
    else:
        assert out.is_contiguous(), "out must be contiguous"
        assert out.shape == (*leading, d), "out shape must match x with last dim halved"
        assert out.dtype == x.dtype and out.device == x.device

    # Both tensors are contiguous, so they are addressed as dense 2D (n_rows, width).
    row_stride_in = 2 * d
    col_stride_in = 1
    row_stride_out = d
    col_stride_out = 1

    block_n = _pick_block_n(d, n_rows)
    block_m = _pick_block_m(n_rows, block_n, d)
    num_warps = _pick_num_warps(n_rows, block_m, block_n)

    grid = (triton.cdiv(n_rows, block_m), triton.cdiv(d, block_n))
    fused_silu_mul_kernel[grid](
        x,
        out,
        n_rows,
        d,
        row_stride_in,
        col_stride_in,
        row_stride_out,
        col_stride_out,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=num_warps,
        waves_per_eu=0,
    )
    return out
def run_silu_mul_benchmark(args):
    """Benchmark the standalone last-dim fused SiLU-and-mul op.

    Builds one benchmark row per entry returned by ``silu_mul_benchmark_configs(args)``
    and reports time only when ``args.print_time`` is set, otherwise time, GFLOPS and
    bandwidth. Results are written to the current directory when ``args.o`` is set.
    """
    print_time = args.print_time
    dtype = str_to_torch_dtype[args.dtype]

    # (metric key, column label) pairs, in reporting order.
    reported = [("time", "Time_(ms)")]
    if not print_time:
        reported += [("gflops", "GFLOPS"), ("bandwidth", "Bandwidth_(GB/s)")]
    line_vals = [key for key, _ in reported]
    line_names = [label for _, label in reported]

    benchmark = triton.testing.Benchmark(
        x_names=["model", "M", "N", "K", "E", "top_k"],
        x_vals=silu_mul_benchmark_configs(args),
        line_arg="metric",
        line_vals=line_vals,
        line_names=line_names,
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
        ylabel="ms / GFLOPS / GB/s",
        plot_name=get_caller_name_no_ext() + "_silu_mul",
        args={},
    )

    @triton.testing.perf_report([benchmark])
    def bench_silu_mul(M, N, K, E, top_k, metric, model=None):
        # Input is laid out as (M * top_k, N); round N down to an even width so the
        # row splits into two equal halves.
        n_even = N - (N % 2)
        if n_even < 2:
            return 0.0
        n_rows = M * top_k
        d = n_even // 2
        x = torch.randn(n_rows, n_even, device="cuda", dtype=dtype)
        out = torch.empty(n_rows, d, device="cuda", dtype=dtype)

        elem = torch.tensor([], dtype=dtype).element_size()
        # Bytes read (full input) plus bytes written (halved output).
        bytes_moved = (n_rows * n_even + n_rows * d) * elem
        # Rough op count: 8 operations per output element.
        flops = float(n_rows * d * 8)

        ms = triton.testing.do_bench(
            lambda: fused_silu_mul(x, out), warmup=25, rep=100
        )
        results = {
            "time": ms,
            "gflops": flops / ms * 1e-6,
            "bandwidth": bytes_moved / (ms * 1e-3) * 1e-9,
        }
        if metric not in results:
            raise ValueError("Unknown metric: " + metric)
        return results[metric]

    bench_silu_mul.run(save_path="." if args.o else None, print_data=True)
Implies multiple table rows when set.", + ) parser.add_argument( "-o", action="store_true", help="Write performance results to CSV file" ) @@ -310,6 +425,17 @@ def parse_args(): def main(): args = parse_args() + if args.bench_silu_mul: + if args.print_vgpr: + + def fun(): + return run_silu_mul_benchmark(args) + + print_vgpr(fun, get_caller_name_no_ext() + "_silu_mul") + return 0 + run_silu_mul_benchmark(args) + return 0 + if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") diff --git a/op_tests/op_benchmarks/triton/utils/model_configs.json b/op_tests/op_benchmarks/triton/utils/model_configs.json index f41efcd412..d8371fc6c5 100644 --- a/op_tests/op_benchmarks/triton/utils/model_configs.json +++ b/op_tests/op_benchmarks/triton/utils/model_configs.json @@ -61,6 +61,50 @@ "num_expert": 256, "top_k": 4 } + }, + "glm47fp8": { + "tp4": { + "hidden_size": 5120, + "intermediate_size": 768, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 160, + "top_k": 8, + "silu_mul_benchmark_M": [4, 8193, 7238] + }, + "ep4": { + "hidden_size": 5120, + "intermediate_size": 3072, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 40, + "top_k": 2, + "silu_mul_benchmark_M": [4, 8193, 7238] + } + }, + "kimik25": { + "tp4": { + "hidden_size": 7168, + "intermediate_size": 1024, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 384, + "top_k": 8, + "silu_mul_benchmark_M": [4] + }, + "ep4": { + "hidden_size": 7168, + "intermediate_size": 4096, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 96, + "top_k": 2, + "silu_mul_benchmark_M": [4] + } } } diff --git a/op_tests/triton_tests/fusions/test_fused_silu_mul.py b/op_tests/triton_tests/fusions/test_fused_silu_mul.py new file mode 100644 index 0000000000..cc5e9b7b03 --- /dev/null +++ b/op_tests/triton_tests/fusions/test_fused_silu_mul.py 
import torch
import pytest

from aiter.ops.triton.activation import fused_silu_mul

LOG2_E = 1.44269504089

# GLM-4.7-FP8 MoE (e.g. zai-org/GLM-4.7-FP8): moe_intermediate_size=1536, top_k=8.
# Column-parallel TP4: local d = 1536 // 4 = 384, fused silu-mul input last dim = 768.
_GLM47_TP4_LAST = 768
_GLM47_TOP_K = 8

# Kimi-K2.5 MoE (moonshotai/Kimi-K2.5 text_config): moe_intermediate_size=2048, top_k=8.
# TP4: local d = 2048 // 4 = 512, last dim = 1024.
_KIMI_K25_TP4_LAST = 1024
_KIMI_K25_TOP_K = 8

# (id prefix, top_k, input last dim) for the TP4 MoE shapes exercised below.
_TP4_MODELS = (
    ("glm47", _GLM47_TOP_K, _GLM47_TP4_LAST),
    ("kimi_k25", _KIMI_K25_TOP_K, _KIMI_K25_TP4_LAST),
)


def silu_exp2_ref(t: torch.Tensor) -> torch.Tensor:
    """Float32 SiLU written with ``exp2``, mirroring the Triton ``_silu_exp2`` form."""
    t32 = t.float()
    denom = 1.0 + torch.exp2(-(t32 * LOG2_E))
    return t32 / denom


def torch_silu_mul_last_dim_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference: SiLU of the first half of the last dim times the second half."""
    half = x.size(-1) // 2
    gate = x[..., :half]
    up = x[..., half:]
    return (silu_exp2_ref(gate) * up).to(x.dtype)


def _skip_without_cuda():
    """Skip the calling test when no CUDA device is available."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")


def _assert_matches_reference(shape, dtype):
    """Run ``fused_silu_mul`` on random data of ``shape`` and compare to the reference."""
    x = torch.randn(shape, dtype=dtype, device="cuda")
    expected = torch_silu_mul_last_dim_ref(x)
    actual = fused_silu_mul(x)
    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize(
    "shape",
    [
        (4, 64),
        (128, 256),
        (31, 500),
        (2, 16, 128),
        (1, 3, 7, 32),
    ],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_explicit_out", [False, True])
def test_fused_silu_mul(shape, dtype, use_explicit_out):
    _skip_without_cuda()
    x = torch.randn(shape, dtype=dtype, device="cuda")
    ref = torch_silu_mul_last_dim_ref(x)
    if not use_explicit_out:
        out = fused_silu_mul(x)
    else:
        # Caller-provided destination: the result must land in this buffer.
        out = torch.empty_like(ref)
        fused_silu_mul(x, out)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)


def test_fused_silu_mul_requires_even_last_dim():
    _skip_without_cuda()
    odd = torch.randn(2, 3, device="cuda")
    with pytest.raises(AssertionError, match="even"):
        fused_silu_mul(odd)


@pytest.mark.parametrize(
    "n_rows,last_dim",
    [
        # Decode M=4 → rows M * top_k, then medium prefill / batched decode (256 tokens).
        pytest.param(tokens * top_k, last, id=f"{tag}_tp4_{suffix}")
        for tokens, suffix in ((4, "decode4"), (256, "rows256x8"))
        for tag, top_k, last in _TP4_MODELS
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_silu_mul_tp4_moe_shapes(n_rows, last_dim, dtype):
    """MoE fused silu×mul tensor as (tokens * top_k, 2 * local_d) under TP4."""
    _skip_without_cuda()
    _assert_matches_reference((n_rows, last_dim), dtype)


@pytest.mark.parametrize(
    "n_rows,last_dim",
    [
        pytest.param(
            (prefill + 3) * top_k,
            last,
            id=f"{tag}_tp4_pref{prefill}_dec3",
        )
        for tag, top_k, last in _TP4_MODELS
        for prefill in (8190, 7235)
    ],
)
def test_fused_silu_mul_tp4_prefill_bf16(n_rows, last_dim):
    _skip_without_cuda()
    _assert_matches_reference((n_rows, last_dim), torch.bfloat16)