Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions aiter/ops/triton/_triton_kernels/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,48 @@ def _silu(x):
return _silu_exp2(x)


@triton.jit
def fused_silu_mul_kernel(
    inp_ptr,
    out_ptr,
    n_rows,
    n_cols,
    row_stride_in,
    col_stride_in,
    row_stride_out,
    col_stride_out,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    SiLU on the first half of the last dimension, multiply by the second half.

    Each row has 2 * n_cols input elements; writes n_cols outputs.
    2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N).

    Args:
        inp_ptr: Base pointer of the input; element (r, c) is read at
            ``r * row_stride_in + c * col_stride_in``.
        out_ptr: Base pointer of the output; element (r, c) is written at
            ``r * row_stride_out + c * col_stride_out``.
        n_rows: Number of rows to process.
        n_cols: Number of output columns (half of the input row width).
        row_stride_in / col_stride_in: Input strides, in elements.
        row_stride_out / col_stride_out: Output strides, in elements.
        BLOCK_M: Rows handled by one program (compile-time constant).
        BLOCK_N: Output columns handled by one program (compile-time constant).

    The activation is evaluated in float32, rounded back to the input element
    type, multiplied by the second-half value and finally cast to the output
    element type.
    """
    m_pid = tl.program_id(0)
    n_pid = tl.program_id(1)
    row_idx = m_pid * BLOCK_M + tl.arange(0, BLOCK_M)
    col_idx = n_pid * BLOCK_N + tl.arange(0, BLOCK_N)

    # Widen the row index before scaling by the row stride: with 32-bit index
    # arithmetic the product wraps once a tensor holds 2**31 or more elements,
    # which would silently address the wrong memory.
    row_idx_wide = row_idx.to(tl.int64)
    row_in = row_idx_wide * row_stride_in
    row_out = row_idx_wide * row_stride_out

    gate_cols = col_idx[None, :] * col_stride_in
    # The multiplier half starts n_cols elements after the gate half.
    up_cols = (n_cols + col_idx)[None, :] * col_stride_in

    first_half_ptrs = inp_ptr + row_in[:, None] + gate_cols
    second_half_ptrs = inp_ptr + row_in[:, None] + up_cols
    out_ptrs = out_ptr + row_out[:, None] + col_idx[None, :] * col_stride_out

    # Guard the ragged edge tiles in both dimensions.
    mask = (row_idx < n_rows)[:, None] & (col_idx < n_cols)[None, :]

    a = tl.load(first_half_ptrs, mask=mask, other=0.0).to(tl.float32)
    silu_a = _silu_exp2(a).to(inp_ptr.dtype.element_ty)
    b = tl.load(second_half_ptrs, mask=mask, other=0.0)
    o = (silu_a * b).to(out_ptr.dtype.element_ty)
    tl.store(out_ptrs, o, mask=mask)


@triton.jit
def _tanh(x):
    """Return tanh(x), computed through the identity tanh(x) = 2 * sigmoid(2x) - 1."""
    doubled = 2 * x
    return 2 * tl.sigmoid(doubled) - 1
Expand Down
117 changes: 113 additions & 4 deletions aiter/ops/triton/activation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Literal
from typing import Literal, Optional
import triton
import triton.language as tl
import torch
import aiter

fp8_dtype = aiter.dtypes.fp8
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton._triton_kernels.activation import (
_act_mul_and_dynamic_mxfp4_quant_kernel,
_act_mul_and_dynamic_fp8_group_quant_kernel,
fused_silu_mul_kernel,
)

fp8_dtype = aiter.dtypes.fp8

_LOGGER = AiterTritonLogger()


Expand Down Expand Up @@ -198,3 +198,112 @@ def act_mul_and_fp8_group_quant(
)

return x_fp8, out_bs


def _silu_mul_block_n(d: int, n_rows: int) -> int:
    """Pick the column tile width for ``fused_silu_mul``.

    Tile size along the halved last dim (cap 1024); at least 32 for
    vectorization. Tuned on ROCm for MoE TP4 locals (GLM-4.7 ``d=384``,
    Kimi-K2.5 ``d=512``) and wide MoE activations: ``n_rows`` selects decode vs
    prefill N-tiling (see ``bench_moe.py -bench_silu_mul``).
    """
    n = max(d, 1)
    # Kimi-K2.5 TP4 (d=512): prefill favors one 512-wide N tile; decode keeps 256x2.
    if n == 512:
        return 512 if n_rows > 4096 else 256
    # GLM-4.7 TP4 (d=384): small row counts use 256x2; larger batches favor 128x3.
    if n == 384:
        return 256 if n_rows <= 128 else 128
    # Largest power of two not exceeding min(n, 1024), floored at 32.
    upper = min(n, 1024)
    return max(32, 1 << (upper.bit_length() - 1))


def _silu_mul_block_m(n_rows: int, block_n: int, d: int) -> int:
    """Pick the row tile height for ``fused_silu_mul``.

    Small row counts get a tile sized to the row count (clamped to [4, 32]);
    larger ones use the tuned ``(d, n_rows)`` pairs, then fall back on the
    column tile width.
    """
    if n_rows <= 64:
        return min(32, max(4, triton.next_power_of_2(n_rows)))
    if n_rows > 128:
        if d == 384:
            return 32 if n_rows > 8192 else 8
        if d == 512:
            return 8
    # Wide column tiles keep the row tile short; narrow ones can afford 16 rows.
    return 8 if block_n >= 512 else 16


def _silu_mul_num_warps(n_rows: int, block_m: int, block_n: int) -> int:
    """ROCm tuning: 8 warps for small-row launches with large tiles, otherwise 2."""
    if n_rows <= 128 and block_m >= 16 and block_n >= 128:
        return 8
    return 2


def fused_silu_mul(
    x: torch.Tensor,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused SiLU-and-mul along the last dimension (same pattern as MoE silu-fused GEMM).

    ``x`` must be contiguous with even ``size(-1)``. For last size ``2 * d``, the first
    ``d`` lanes are passed through SiLU (``_silu_exp2``); the second ``d`` lanes are the
    multipliers. Output shape matches ``x`` except ``out.size(-1) == d``.

    Args:
        x: Contiguous device tensor whose last dimension is ``2 * d``.
        out: Optional contiguous destination of shape ``(*x.shape[:-1], d)`` with
            the same dtype and device as ``x``.

    Returns:
        ``out`` if provided, else a newly allocated tensor.

    Raises:
        AssertionError: If ``x`` is not a contiguous device tensor with an even
            last dimension, or if ``out`` does not match the expected
            layout/shape/dtype/device.
    """
    assert x.is_cuda, "fused_silu_mul requires a CUDA tensor"
    assert x.is_contiguous(), "x must be contiguous"
    last = x.size(-1)
    assert last % 2 == 0, "last dimension must be even (2 * d)"
    d = last // 2
    leading = x.shape[:-1]

    # Nothing to compute for an empty input. Checking numel() also covers
    # d == 0, which would otherwise divide by zero when deriving n_rows.
    if x.numel() == 0:
        if out is None:
            out = torch.empty(*leading, d, dtype=x.dtype, device=x.device)
        return out

    # Collapse all leading dimensions into a single row axis.
    n_rows = x.numel() // last

    _LOGGER.info(f"fused_silu_mul: x={tuple(x.shape)} last_half={d} rows={n_rows}")

    if out is None:
        out = torch.empty(*leading, d, dtype=x.dtype, device=x.device)
    else:
        assert out.is_contiguous(), "out must be contiguous"
        assert out.shape == (*leading, d), "out shape must match x with last dim halved"
        assert out.dtype == x.dtype and out.device == x.device

    # Both tensors are contiguous, so the flattened 2D strides are fixed.
    row_stride_in = last
    col_stride_in = 1
    row_stride_out = d
    col_stride_out = 1

    block_n = _silu_mul_block_n(d, n_rows)
    block_m = _silu_mul_block_m(n_rows, block_n, d)
    num_warps = _silu_mul_num_warps(n_rows, block_m, block_n)

    grid = (triton.cdiv(n_rows, block_m), triton.cdiv(d, block_n))
    fused_silu_mul_kernel[grid](
        x,
        out,
        n_rows,
        d,
        row_stride_in,
        col_stride_in,
        row_stride_out,
        col_stride_out,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=num_warps,
        # Backend launch option carried over unchanged from the tuned configuration.
        waves_per_eu=0,
    )
    return out
126 changes: 126 additions & 0 deletions op_tests/op_benchmarks/triton/bench_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype
from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe
from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu
from aiter.ops.triton.activation import fused_silu_mul
from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
get_model_configs,
Expand Down Expand Up @@ -41,6 +42,44 @@ def model_benchmark_configs(args):
return moe_configs


def silu_mul_benchmark_configs(args):
    """Build the ``(model, M, N, K, E, top_k)`` tuples for the silu_mul benchmark.

    Like ``model_benchmark_configs`` but supports multiple M values. The token
    counts come from, in order of precedence: the model's own
    ``silu_mul_benchmark_M`` entry, ``-silu_mul_M_list``, ``-M``, and finally
    4096.

    Returns an empty list when no model configuration is found.
    """
    model_table = get_model_configs(
        config_path=args.model_configs,
        models="mixtral" if args.model is None else args.model,
    )
    if not model_table:
        return []

    # NOTE(review): stage-2 shapes are appended when this flag is truthy, which
    # reads inverted relative to its name - confirm against the argparse setup.
    add_stage2 = args.no_bench_stage2

    if args.silu_mul_M_list:
        pieces = args.silu_mul_M_list.split(",")
        fallback_ms = [int(piece.strip()) for piece in pieces if piece.strip()]
    elif args.M:
        fallback_ms = [args.M]
    else:
        fallback_ms = [4096]

    shapes = []
    for model_name, cfg in model_table.items():
        # NOTE(review): a per-model list in the JSON wins over the CLI list.
        raw_ms = cfg.get("silu_mul_benchmark_M", fallback_ms)
        if isinstance(raw_ms, list):
            token_counts = [int(m) for m in raw_ms]
        else:
            token_counts = [int(raw_ms)]

        for M in token_counts:
            inter = cfg["intermediate_size"]
            hidden = cfg["hidden_size"]
            E = cfg["num_expert"]
            top_k = cfg["top_k"]
            shapes.append((model_name, M, inter, hidden, E, top_k))
            if add_stage2:
                # Second GEMM shape: N = hidden size, K = half the intermediate size.
                shapes.append(
                    (
                        model_name,
                        M,
                        cfg["hidden_size"],
                        cfg["intermediate_size"] // 2,
                        E,
                        top_k,
                    )
                )
    return shapes


def fused_moe(
M,
N,
Expand Down Expand Up @@ -262,6 +301,69 @@ def bench_moe_gemm(M, N, K, E, top_k, metric, model=None):
bench_moe_gemm.run(save_path="." if args.o else None, print_data=True)


def run_silu_mul_benchmark(args):
    """Benchmark last-dim fused SiLU-and-mul (same activation as silu-fused MoE).

    Reports time only when ``args.print_time`` is set, otherwise time, GFLOPS
    and bandwidth. Results are printed and, when ``args.o`` is set, also saved
    to the current directory.
    """
    time_only = args.print_time
    dtype = str_to_torch_dtype[args.dtype]

    if time_only:
        line_names, line_vals = ["Time_(ms)"], ["time"]
    else:
        line_names = ["Time_(ms)", "GFLOPS", "Bandwidth_(GB/s)"]
        line_vals = ["time", "gflops", "bandwidth"]

    shapes = silu_mul_benchmark_configs(args)

    report = triton.testing.Benchmark(
        x_names=["model", "M", "N", "K", "E", "top_k"],
        x_vals=shapes,
        line_arg="metric",
        line_vals=line_vals,
        line_names=line_names,
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
        ylabel="ms / GFLOPS / GB/s",
        plot_name=get_caller_name_no_ext() + "_silu_mul",
        args={},
    )

    @triton.testing.perf_report([report])
    def bench_silu_mul(M, N, K, E, top_k, metric, model=None):
        # Match MoE post-GEMM layout: (M * top_k, N); N must be even for gate/up pairs,
        # so an odd N is rounded down by one.
        width = N - 1 if N % 2 else N
        if width < 2:
            return 0.0
        rows = M * top_k
        half = width // 2
        x = torch.randn(rows, width, device="cuda", dtype=dtype)
        out = torch.empty(rows, half, device="cuda", dtype=dtype)

        itemsize = x.element_size()
        bytes_read = rows * width * itemsize
        bytes_written = rows * half * itemsize
        # Rough op count: SiLU + mul per output element
        flops = float(rows * half * 8)

        ms = triton.testing.do_bench(
            lambda: fused_silu_mul(x, out), warmup=25, rep=100
        )

        results = {
            "time": ms,
            "gflops": flops / ms * 1e-6,
            "bandwidth": (bytes_read + bytes_written) / (ms * 1e-3) * 1e-9,
        }
        if metric not in results:
            raise ValueError("Unknown metric: " + metric)
        return results[metric]

    bench_silu_mul.run(save_path="." if args.o else None, print_data=True)


def parse_args():
parser = argparse.ArgumentParser(
prog="Benchmark MoE GEMM",
Expand Down Expand Up @@ -300,6 +402,19 @@ def parse_args():
parser.add_argument("-dtype", default="fp16")
parser.add_argument("-fp8_type", default="e5m2fnuz")
parser.add_argument("-silu_fused", action="store_true", default=False)
parser.add_argument(
"-bench_silu_mul",
action="store_true",
default=False,
help="Benchmark fused last-dim SiLU-and-mul only (uses model M, N, top_k).",
)
parser.add_argument(
"-silu_mul_M_list",
type=str,
default=None,
help="Comma-separated token counts M for silu_mul bench (e.g. 4,8193,7238). "
"Row count is M * top_k. Implies multiple table rows when set.",
)
parser.add_argument(
"-o", action="store_true", help="Write performance results to CSV file"
)
Expand All @@ -310,6 +425,17 @@ def parse_args():
def main():
args = parse_args()

if args.bench_silu_mul:
if args.print_vgpr:

def fun():
return run_silu_mul_benchmark(args)

print_vgpr(fun, get_caller_name_no_ext() + "_silu_mul")
return 0
run_silu_mul_benchmark(args)
return 0

if args.print_vgpr:
print("Retrieving VGPR usage for Triton kernels...")

Expand Down
44 changes: 44 additions & 0 deletions op_tests/op_benchmarks/triton/utils/model_configs.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,50 @@
"num_expert": 256,
"top_k": 4
}
},
"glm47fp8": {
"tp4": {
"hidden_size": 5120,
"intermediate_size": 768,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"vocab_size": 151552,
"num_expert": 160,
"top_k": 8,
"silu_mul_benchmark_M": [4, 8193, 7238]
},
"ep4": {
"hidden_size": 5120,
"intermediate_size": 3072,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"vocab_size": 151552,
"num_expert": 40,
"top_k": 2,
"silu_mul_benchmark_M": [4, 8193, 7238]
}
},
"kimik25": {
"tp4": {
"hidden_size": 7168,
"intermediate_size": 1024,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"vocab_size": 163840,
"num_expert": 384,
"top_k": 8,
"silu_mul_benchmark_M": [4]
},
"ep4": {
"hidden_size": 7168,
"intermediate_size": 4096,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"vocab_size": 163840,
"num_expert": 96,
"top_k": 2,
"silu_mul_benchmark_M": [4]
}
}
}

Expand Down
Loading
Loading