diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 57a7c32644..2aa9bb9ca0 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -134,7 +134,6 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( ).to(tl.float32) x = _apply_activation_from_str(a, ACTIVATION) * b - out_tensor, bs_e8m0 = _mxfp4_quant_op( x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE ) @@ -183,7 +182,6 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( if EVEN_M_N: tl.store(bs_ptr + bs_offs, bs_e8m0) else: - tl.store( bs_ptr + bs_offs, bs_e8m0, @@ -199,7 +197,7 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( @triton.jit def _act_mul_and_dynamic_fp8_group_quant_kernel( x_ptr, - x_fp8_ptr, + x_out_ptr, x_bs_ptr, stride_x_m_in, stride_x_n_in, @@ -215,6 +213,7 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( DTYPE_MAX: tl.constexpr, DTYPE_MIN: tl.constexpr, EVEN_N: tl.constexpr, + DO_QUANT: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -223,8 +222,6 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( stride_x_n = tl.cast(stride_x_n_in, tl.int64) stride_x_fp8_m = tl.cast(stride_x_fp8_m_in, tl.int64) stride_x_fp8_n = tl.cast(stride_x_fp8_n_in, tl.int64) - stride_bs_m = tl.cast(stride_bs_m_in, tl.int64) - stride_bs_n = tl.cast(stride_bs_n_in, tl.int64) NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE x_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -245,31 +242,43 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel( x = _apply_activation_from_str(a, ACTIVATION) * b - x_fp8, x_bs = _fp8_quant_op( - x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - x_fp8 = tl.ravel(x_fp8) - x_bs = tl.ravel(x_bs) - out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n - if EVEN_N: - tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty)) - else: - out_mask = out_offs_n < N - tl.store( - x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty), mask=out_mask + if DO_QUANT: + stride_bs_m = tl.cast(stride_bs_m_in, tl.int64) + stride_bs_n = tl.cast(stride_bs_n_in, tl.int64) + x_fp8, x_bs = _fp8_quant_op( + x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN ) + x_fp8 = tl.ravel(x_fp8) + x_bs = tl.ravel(x_bs) - bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) - bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n - if EVEN_N: - tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty)) + if EVEN_N: + tl.store(x_out_ptr + out_offs, x_fp8.to(x_out_ptr.dtype.element_ty)) + else: + out_mask = out_offs_n < N + tl.store( + x_out_ptr + out_offs, + x_fp8.to(x_out_ptr.dtype.element_ty), + mask=out_mask, + ) + + bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) + bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n + if EVEN_N: + tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty)) + else: + bs_mask = bs_offs_n < scaleN + tl.store( + x_bs_ptr + bs_offs, + x_bs.to(x_bs_ptr.dtype.element_ty), + mask=bs_mask, + ) else: - bs_mask = bs_offs_n < scaleN - tl.store( - x_bs_ptr + bs_offs, - x_bs.to(x_bs_ptr.dtype.element_ty), - mask=bs_mask, - ) + x_out = x.to(x_out_ptr.dtype.element_ty) + if EVEN_N: + tl.store(x_out_ptr + out_offs, x_out) + else: + out_mask = out_offs_n < N + tl.store(x_out_ptr + out_offs, x_out, mask=out_mask) diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py index b52cf465eb..b7cc17fdcd 100644 --- a/aiter/ops/triton/activation.py +++ b/aiter/ops/triton/activation.py @@ -1,16 +1,15 @@ -from typing import Literal +from typing import Literal, Optional import triton -import triton.language as tl import torch import aiter - -fp8_dtype = aiter.dtypes.fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.activation import ( _act_mul_and_dynamic_mxfp4_quant_kernel, _act_mul_and_dynamic_fp8_group_quant_kernel, ) +fp8_dtype = aiter.dtypes.fp8 + _LOGGER = AiterTritonLogger() @@ -128,6 +127,71 @@ def act_mul_and_mxfp4_quant( return x_fp4, blockscale_e8m0 +def act_mul( + x: torch.Tensor, + activation: Literal["silu", "gelu", "gelu_tanh"], + out: Optional[torch.Tensor] = None, + group_size: Optional[int] = None, +) -> torch.Tensor: + """ + Gated activation along the last dimension only (no quantization): ``act(x0) * x1`` + where ``x`` is ``[..., 2 * d]`` split into two ``[..., d]`` halves. + + Accepts any input rank; leading dimensions are flattened before the kernel and + restored in the returned tensor, which has shape ``[..., d]``. + + Uses the same Triton path as ``act_mul_and_fp8_group_quant`` with quantization disabled. + """ + _LOGGER.info(f"ACT_MUL: x={tuple(x.shape)} activation={activation}") + assert x.is_cuda and x.is_contiguous() + assert x.shape[-1] % 2 == 0 + + # Flatten all leading dimensions so the kernel always sees a 2-D tensor. + orig_shape = x.shape + x_2d = x.view(-1, orig_shape[-1]) + M, N = x_2d.shape + N_half = N // 2 + out_shape = orig_shape[:-1] + (N_half,) + + if out is None: + out_2d = torch.empty((M, N_half), dtype=x.dtype, device=x.device) + else: + assert out.shape == out_shape + assert out.dtype == x.dtype and out.is_contiguous() + out_2d = out.view(M, N_half) + + if group_size is None: + group_size = min(256, triton.next_power_of_2(N_half)) + group_size = max(32, group_size) + + scaleN = triton.cdiv(N_half, group_size) + DTYPE_MAX = 1.0 + BLOCK_SIZE_N = group_size + + grid = ( + M, + triton.cdiv(N_half, BLOCK_SIZE_N), + ) + _act_mul_and_dynamic_fp8_group_quant_kernel[grid]( + x_2d, + out_2d, + None, + *x_2d.stride(), + *out_2d.stride(), + 0, # stride_bs_m_in + 0, # stride_bs_n_in + N=N_half, + ACTIVATION=activation, + scaleN=scaleN, + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + DO_QUANT=False, + ) + return out_2d.view(out_shape) + + def act_mul_and_fp8_group_quant( x: torch.Tensor, activation: Literal["silu", "gelu", "gelu_tanh"], @@ -195,6 +259,7 @@ def act_mul_and_fp8_group_quant( # num_warps=NUM_WARPS, # waves_per_eu=0, # num_stages=1, + DO_QUANT=True, ) return x_fp8, out_bs diff --git a/op_tests/op_benchmarks/triton/bench_moe.py b/op_tests/op_benchmarks/triton/bench_moe.py index 3594278162..0fdb94f308 100644 --- a/op_tests/op_benchmarks/triton/bench_moe.py +++ b/op_tests/op_benchmarks/triton/bench_moe.py @@ -5,6 +5,7 @@ from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu +from aiter.ops.triton.activation import act_mul from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16 from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( get_model_configs, @@ -262,6 +263,68 @@ def bench_moe_gemm(M, N, K, E, top_k, metric, model=None): bench_moe_gemm.run(save_path="." if args.o else None, print_data=True) +def run_act_mul_benchmark(args): + """Benchmark ``act_mul`` (SiLU/GELU * gate) on tensors shaped like MoE activations.""" + print_time = args.print_time + dtype = str_to_torch_dtype[args.dtype] + activation = args.act_mul_activation + + if print_time: + line_names = ["Time_(ms)"] + line_vals = ["time"] + else: + line_names = ["Time_(ms)", "GFLOPS", "Bandwidth_(GB/s)"] + line_vals = ["time", "gflops", "bandwidth"] + + x_vals_list = model_benchmark_configs(args) + x_names = ["model", "M", "N", "K", "E", "top_k"] + + benchmark = triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg="metric", + line_vals=line_vals, + line_names=line_names, + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms / GFLOPS / GB/s", + plot_name=get_caller_name_no_ext() + "_act_mul", + args={}, + ) + + @triton.testing.perf_report([benchmark]) + def bench_act_mul(M, N, K, E, top_k, metric, model=None): + n_even = N if N % 2 == 0 else N - 1 + if n_even < 2: + return 0.0 + n_rows = M * top_k + d = n_even // 2 + x = torch.randn(n_rows, n_even, device="cuda", dtype=dtype) + out = torch.empty(n_rows, d, device="cuda", dtype=dtype) + + elem = torch.tensor([], dtype=dtype).element_size() + mem_read = n_rows * n_even * elem + mem_write = n_rows * d * elem + flops = float(n_rows * d * 8) + + def fn(): + return act_mul(x, activation, out=out) + + ms = triton.testing.do_bench(fn, warmup=25, rep=100) + bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9 + gflops = flops / ms * 1e-6 + + if metric == "time": + return ms + elif metric == "gflops": + return gflops + elif metric == "bandwidth": + return bandwidth + else: + raise ValueError("Unknown metric: " + metric) + + bench_act_mul.run(save_path="." if args.o else None, print_data=True) + + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark MoE GEMM", @@ -300,6 +363,19 @@ def parse_args(): parser.add_argument("-dtype", default="fp16") parser.add_argument("-fp8_type", default="e5m2fnuz") parser.add_argument("-silu_fused", action="store_true", default=False) + parser.add_argument( + "-bench_act_mul", + action="store_true", + default=False, + help="Benchmark act_mul (no quant) using model M, N, top_k (same layout as silu-fused MoE).", + ) + parser.add_argument( + "-act_mul_activation", + type=str, + default="silu", + choices=["silu", "gelu", "gelu_tanh"], + help="Activation for -bench_act_mul.", + ) parser.add_argument( "-o", action="store_true", help="Write performance results to CSV file" ) @@ -310,6 +386,17 @@ def parse_args(): def main(): args = parse_args() + if args.bench_act_mul: + if args.print_vgpr: + + def fun(): + return run_act_mul_benchmark(args) + + print_vgpr(fun, get_caller_name_no_ext() + "_act_mul") + return 0 + run_act_mul_benchmark(args) + return 0 + if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") diff --git a/op_tests/op_benchmarks/triton/utils/model_configs.json b/op_tests/op_benchmarks/triton/utils/model_configs.json index f41efcd412..7f384521cd 100644 --- a/op_tests/op_benchmarks/triton/utils/model_configs.json +++ b/op_tests/op_benchmarks/triton/utils/model_configs.json @@ -61,8 +61,55 @@ "num_expert": 256, "top_k": 4 } + }, + "glm47fp8": { + "TP4": { + "hidden_size": 1280, + "intermediate_size": 768, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 160, + "top_k": 8, + "parallel": "TP4", + "notes": "GLM-4.7 MoE (zai-org/GLM-4.7): per-GPU tensor-parallel shard; K=5120/4, N=2*moe_intermediate/4=3072/4" + }, + "EP4": { + "hidden_size": 5120, + "intermediate_size": 3072, + "num_attention_heads": 96, + "num_key_value_heads": 8, + "vocab_size": 151552, + "num_expert": 40, + "top_k": 8, + "parallel": "EP4", + "notes": "GLM-4.7 MoE: per-GPU expert-parallel shard; E=160/4, N=2*1536" + } + }, + "kimik25": { + "TP4": { + "hidden_size": 1792, + "intermediate_size": 1024, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 384, + "top_k": 8, + "parallel": "TP4", + "notes": "Kimi-K2.5 text MoE (moonshotai/Kimi-K2.5): TP4 shard; K=7168/4, N=2*moe_intermediate/4=4096/4" + }, + "EP4": { + "hidden_size": 7168, + "intermediate_size": 4096, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "vocab_size": 163840, + "num_expert": 96, + "top_k": 8, + "parallel": "EP4", + "notes": "Kimi-K2.5 text MoE: EP4 shard; E=384/4, N=2*2048" + } } } - - + \ No newline at end of file diff --git a/op_tests/triton_tests/test_activation.py b/op_tests/triton_tests/test_activation.py index 38dbd4779d..946f3e525e 100644 --- a/op_tests/triton_tests/test_activation.py +++ b/op_tests/triton_tests/test_activation.py @@ -6,7 +6,7 @@ shuffle_scales, un_shuffle_scales, ) -from aiter.ops.triton.activation import act_mul_and_mxfp4_quant +from aiter.ops.triton.activation import act_mul, act_mul_and_mxfp4_quant import aiter.ops.triton.utils._triton.arch_info as arch_info DEBUG_MODE = False @@ -140,3 +140,58 @@ def test_act_mul_and_mxfp4_quant( torch.testing.assert_close(triton_out, torch_out) torch.testing.assert_close(triton_scale, torch_scale) + + +def torch_act_mul_ref(x: torch.Tensor, activation: str) -> torch.Tensor: + d = x.shape[-1] // 2 + a, b = x[..., :d], x[..., d:] + if activation == "silu": + y = F.silu(a) * b + elif activation == "gelu": + y = F.gelu(a) * b + else: + y = F.gelu(a, approximate="tanh") * b + return y.to(x.dtype) + + +@pytest.mark.parametrize( + "M, N_half", + [ + (4, 64), + (31, 128), + (128, 256), + (1, 7), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"]) +@pytest.mark.parametrize("use_out", [False, True]) +def test_act_mul_no_quant(M, N_half, dtype, activation, use_out): + N = 2 * N_half + x = torch.randn((M, N), dtype=dtype, device="cuda") + ref = torch_act_mul_ref(x, activation) + if use_out: + out = torch.empty_like(ref) + act_mul(x, activation, out=out) + else: + out = act_mul(x, activation) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 4, 64), + (3, 5, 128), + (2, 3, 4, 32), + ], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("activation", ["silu", "gelu"]) +def test_act_mul_no_quant_higher_rank(shape, dtype, activation): + """act_mul should handle any input rank by flattening/unflattening internally.""" + x = torch.randn(shape, dtype=dtype, device="cuda") + ref = torch_act_mul_ref(x, activation) + out = act_mul(x, activation) + assert out.shape == ref.shape + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)