Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
8aeaf2a
Add act_mul without quant (DO_QUANT), model configs, benchmarks
Chi-Chu319 Apr 2, 2026
9c22cd2
Update op_tests/op_benchmarks/triton/bench_moe.py
Chi-Chu319 Apr 2, 2026
4bbd3d6
Update aiter/ops/triton/activation.py
juuso-oskari Apr 7, 2026
3f06beb
Merge branch 'main' into feature/act-mul-no-quant
juuso-oskari Apr 7, 2026
048736c
no need to move the b load after a activation
juuso-oskari Apr 7, 2026
8c333b9
Merge branch 'feature/act-mul-no-quant' of https://github.com/ROCm/ai…
juuso-oskari Apr 7, 2026
1357f0e
Update aiter/ops/triton/activation.py
juuso-oskari Apr 7, 2026
775a25a
fix: act_mul handles higher-rank inputs by flattening/unflattening in…
Copilot Apr 7, 2026
85c792d
do mul in b dtype and use more universal x_out_ptr name
juuso-oskari Apr 7, 2026
8e90c8c
black
juuso-oskari Apr 7, 2026
0f07849
revert back the b dtype mul
juuso-oskari Apr 7, 2026
53171cd
Merge branch 'main' into feature/act-mul-no-quant
Chi-Chu319 May 13, 2026
daa2a75
Merge branch 'main' into feature/act-mul-no-quant
Chi-Chu319 May 15, 2026
321bcf5
Remove DO_QUANT parameter from activation functions and update relate…
Chi-Chu319 May 15, 2026
85941be
Merge branch 'feature/act-mul-no-quant' of https://github.com/ROCm/ai…
Chi-Chu319 May 15, 2026
9c90e54
renamed x_out_ptr back to x_fp4_ptr
Chi-Chu319 May 15, 2026
3a0c78f
move fp8_dtype = aiter.dtypes.fp8 after imports
Chi-Chu319 May 15, 2026
b4a4117
Refactor act_mul function to remove unused dummy_bs tensor and update…
Chi-Chu319 May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions aiter/ops/triton/_triton_kernels/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel(
).to(tl.float32)

x = _apply_activation_from_str(a, ACTIVATION) * b

out_tensor, bs_e8m0 = _mxfp4_quant_op(
x, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE
)
Expand Down Expand Up @@ -183,7 +182,6 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel(
if EVEN_M_N:
tl.store(bs_ptr + bs_offs, bs_e8m0)
else:

tl.store(
bs_ptr + bs_offs,
bs_e8m0,
Expand All @@ -199,7 +197,7 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel(
@triton.jit
def _act_mul_and_dynamic_fp8_group_quant_kernel(
x_ptr,
x_fp8_ptr,
x_out_ptr,
x_bs_ptr,
stride_x_m_in,
stride_x_n_in,
Expand All @@ -215,6 +213,7 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel(
DTYPE_MAX: tl.constexpr,
DTYPE_MIN: tl.constexpr,
EVEN_N: tl.constexpr,
DO_QUANT: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
Expand All @@ -223,8 +222,6 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel(
stride_x_n = tl.cast(stride_x_n_in, tl.int64)
stride_x_fp8_m = tl.cast(stride_x_fp8_m_in, tl.int64)
stride_x_fp8_n = tl.cast(stride_x_fp8_n_in, tl.int64)
stride_bs_m = tl.cast(stride_bs_m_in, tl.int64)
stride_bs_n = tl.cast(stride_bs_n_in, tl.int64)
NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE

x_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
Expand All @@ -245,31 +242,43 @@ def _act_mul_and_dynamic_fp8_group_quant_kernel(

x = _apply_activation_from_str(a, ACTIVATION) * b

x_fp8, x_bs = _fp8_quant_op(
x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN
)
x_fp8 = tl.ravel(x_fp8)
x_bs = tl.ravel(x_bs)

out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n

if EVEN_N:
tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty))
else:
out_mask = out_offs_n < N
tl.store(
x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty), mask=out_mask
if DO_QUANT:
Comment thread
Chi-Chu319 marked this conversation as resolved.
stride_bs_m = tl.cast(stride_bs_m_in, tl.int64)
stride_bs_n = tl.cast(stride_bs_n_in, tl.int64)
x_fp8, x_bs = _fp8_quant_op(
x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN
)
x_fp8 = tl.ravel(x_fp8)
x_bs = tl.ravel(x_bs)

bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS)
bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n
if EVEN_N:
tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty))
if EVEN_N:
tl.store(x_out_ptr + out_offs, x_fp8.to(x_out_ptr.dtype.element_ty))
else:
out_mask = out_offs_n < N
tl.store(
x_out_ptr + out_offs,
x_fp8.to(x_out_ptr.dtype.element_ty),
mask=out_mask,
)

bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS)
bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n
if EVEN_N:
tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty))
else:
bs_mask = bs_offs_n < scaleN
tl.store(
x_bs_ptr + bs_offs,
x_bs.to(x_bs_ptr.dtype.element_ty),
mask=bs_mask,
)
else:
bs_mask = bs_offs_n < scaleN
tl.store(
x_bs_ptr + bs_offs,
x_bs.to(x_bs_ptr.dtype.element_ty),
mask=bs_mask,
)
x_out = x.to(x_out_ptr.dtype.element_ty)
if EVEN_N:
tl.store(x_out_ptr + out_offs, x_out)
else:
out_mask = out_offs_n < N
tl.store(x_out_ptr + out_offs, x_out, mask=out_mask)
73 changes: 69 additions & 4 deletions aiter/ops/triton/activation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Literal
from typing import Literal, Optional
import triton
import triton.language as tl
import torch
import aiter

fp8_dtype = aiter.dtypes.fp8
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton._triton_kernels.activation import (
_act_mul_and_dynamic_mxfp4_quant_kernel,
_act_mul_and_dynamic_fp8_group_quant_kernel,
)

fp8_dtype = aiter.dtypes.fp8

_LOGGER = AiterTritonLogger()


Expand Down Expand Up @@ -128,6 +127,71 @@ def act_mul_and_mxfp4_quant(
return x_fp4, blockscale_e8m0


def act_mul(
x: torch.Tensor,
activation: Literal["silu", "gelu", "gelu_tanh"],
out: Optional[torch.Tensor] = None,
group_size: Optional[int] = None,
) -> torch.Tensor:
"""
Gated activation along the last dimension only (no quantization): ``act(x0) * x1``
where ``x`` is ``[..., 2 * d]`` split into two ``[..., d]`` halves.

Accepts any input rank; leading dimensions are flattened before the kernel and
restored in the returned tensor, which has shape ``[..., d]``.

Uses the same Triton path as ``act_mul_and_fp8_group_quant`` with quantization disabled.
"""
_LOGGER.info(f"ACT_MUL: x={tuple(x.shape)} activation={activation}")
assert x.is_cuda and x.is_contiguous()
assert x.shape[-1] % 2 == 0

# Flatten all leading dimensions so the kernel always sees a 2-D tensor.
orig_shape = x.shape
x_2d = x.view(-1, orig_shape[-1])
M, N = x_2d.shape
N_half = N // 2
out_shape = orig_shape[:-1] + (N_half,)

if out is None:
out_2d = torch.empty((M, N_half), dtype=x.dtype, device=x.device)
else:
assert out.shape == out_shape
assert out.dtype == x.dtype and out.is_contiguous()
out_2d = out.view(M, N_half)

if group_size is None:
group_size = min(256, triton.next_power_of_2(N_half))
group_size = max(32, group_size)

scaleN = triton.cdiv(N_half, group_size)
DTYPE_MAX = 1.0
BLOCK_SIZE_N = group_size

grid = (
M,
triton.cdiv(N_half, BLOCK_SIZE_N),
)
_act_mul_and_dynamic_fp8_group_quant_kernel[grid](
x_2d,
out_2d,
None,
*x_2d.stride(),
*out_2d.stride(),
0, # stride_bs_m_in
0, # stride_bs_n_in
N=N_half,
ACTIVATION=activation,
scaleN=scaleN,
BLOCK_SIZE_N=BLOCK_SIZE_N,
QUANT_BLOCK_SIZE=group_size,
DTYPE_MAX=DTYPE_MAX,
DTYPE_MIN=-DTYPE_MAX,
DO_QUANT=False,
)
return out_2d.view(out_shape)


def act_mul_and_fp8_group_quant(
x: torch.Tensor,
activation: Literal["silu", "gelu", "gelu_tanh"],
Expand Down Expand Up @@ -195,6 +259,7 @@ def act_mul_and_fp8_group_quant(
# num_warps=NUM_WARPS,
# waves_per_eu=0,
# num_stages=1,
DO_QUANT=True,
)

return x_fp8, out_bs
87 changes: 87 additions & 0 deletions op_tests/op_benchmarks/triton/bench_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiter.ops.triton.utils.types import torch_to_triton_dtype, str_to_torch_dtype
from aiter.ops.triton.moe.moe_op import fused_moe as triton_moe
from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu as triton_moe_silu
from aiter.ops.triton.activation import act_mul
from op_tests.triton_tests.moe.test_moe import input_helper, input_helper_int4_w4a16
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
get_model_configs,
Expand Down Expand Up @@ -262,6 +263,68 @@ def bench_moe_gemm(M, N, K, E, top_k, metric, model=None):
bench_moe_gemm.run(save_path="." if args.o else None, print_data=True)


def run_act_mul_benchmark(args):
"""Benchmark ``act_mul`` (SiLU/GELU * gate) on tensors shaped like MoE activations."""
print_time = args.print_time
dtype = str_to_torch_dtype[args.dtype]
activation = args.act_mul_activation

if print_time:
line_names = ["Time_(ms)"]
line_vals = ["time"]
else:
line_names = ["Time_(ms)", "GFLOPS", "Bandwidth_(GB/s)"]
line_vals = ["time", "gflops", "bandwidth"]

x_vals_list = model_benchmark_configs(args)
x_names = ["model", "M", "N", "K", "E", "top_k"]

benchmark = triton.testing.Benchmark(
x_names=x_names,
x_vals=x_vals_list,
line_arg="metric",
line_vals=line_vals,
line_names=line_names,
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms / GFLOPS / GB/s",
plot_name=get_caller_name_no_ext() + "_act_mul",
args={},
)

@triton.testing.perf_report([benchmark])
def bench_act_mul(M, N, K, E, top_k, metric, model=None):
n_even = N if N % 2 == 0 else N - 1
if n_even < 2:
return 0.0
n_rows = M * top_k
d = n_even // 2
x = torch.randn(n_rows, n_even, device="cuda", dtype=dtype)
out = torch.empty(n_rows, d, device="cuda", dtype=dtype)

elem = torch.tensor([], dtype=dtype).element_size()
mem_read = n_rows * n_even * elem
mem_write = n_rows * d * elem
flops = float(n_rows * d * 8)

def fn():
return act_mul(x, activation, out=out)

ms = triton.testing.do_bench(fn, warmup=25, rep=100)
bandwidth = (mem_read + mem_write) / (ms * 1e-3) * 1e-9
gflops = flops / ms * 1e-6

if metric == "time":
return ms
elif metric == "gflops":
return gflops
elif metric == "bandwidth":
return bandwidth
else:
raise ValueError("Unknown metric: " + metric)

bench_act_mul.run(save_path="." if args.o else None, print_data=True)


def parse_args():
parser = argparse.ArgumentParser(
prog="Benchmark MoE GEMM",
Expand Down Expand Up @@ -300,6 +363,19 @@ def parse_args():
parser.add_argument("-dtype", default="fp16")
parser.add_argument("-fp8_type", default="e5m2fnuz")
parser.add_argument("-silu_fused", action="store_true", default=False)
parser.add_argument(
"-bench_act_mul",
action="store_true",
default=False,
help="Benchmark act_mul (no quant) using model M, N, top_k (same layout as silu-fused MoE).",
)
parser.add_argument(
"-act_mul_activation",
type=str,
default="silu",
choices=["silu", "gelu", "gelu_tanh"],
help="Activation for -bench_act_mul.",
)
parser.add_argument(
"-o", action="store_true", help="Write performance results to CSV file"
)
Expand All @@ -310,6 +386,17 @@ def parse_args():
def main():
args = parse_args()

if args.bench_act_mul:
if args.print_vgpr:

def fun():
return run_act_mul_benchmark(args)

print_vgpr(fun, get_caller_name_no_ext() + "_act_mul")
return 0
run_act_mul_benchmark(args)
return 0

if args.print_vgpr:
print("Retrieving VGPR usage for Triton kernels...")

Expand Down
51 changes: 49 additions & 2 deletions op_tests/op_benchmarks/triton/utils/model_configs.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,55 @@
"num_expert": 256,
"top_k": 4
}
},
"glm47fp8": {
"TP4": {
"hidden_size": 1280,
"intermediate_size": 768,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"vocab_size": 151552,
"num_expert": 160,
"top_k": 8,
"parallel": "TP4",
"notes": "GLM-4.7 MoE (zai-org/GLM-4.7): per-GPU tensor-parallel shard; K=5120/4, N=2*moe_intermediate/4=3072/4"
},
"EP4": {
"hidden_size": 5120,
"intermediate_size": 3072,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"vocab_size": 151552,
"num_expert": 40,
"top_k": 8,
"parallel": "EP4",
"notes": "GLM-4.7 MoE: per-GPU expert-parallel shard; E=160/4, N=2*1536"
}
},
"kimik25": {
"TP4": {
"hidden_size": 1792,
"intermediate_size": 1024,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"vocab_size": 163840,
"num_expert": 384,
"top_k": 8,
"parallel": "TP4",
"notes": "Kimi-K2.5 text MoE (moonshotai/Kimi-K2.5): TP4 shard; K=7168/4, N=2*moe_intermediate/4=4096/4"
},
"EP4": {
"hidden_size": 7168,
"intermediate_size": 4096,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"vocab_size": 163840,
"num_expert": 96,
"top_k": 8,
"parallel": "EP4",
"notes": "Kimi-K2.5 text MoE: EP4 shard; E=384/4, N=2*2048"
}
}
}




Loading
Loading