
[TRITON] gfx1201: gemm_a8w8 tuning configs (Mistral-3 / Qwen3 shapes)#3168

Open
carlushuang wants to merge 1 commit into main from carhuang/gfx1201_silu_and_mul_and_a8w8_configs


@carlushuang
Collaborator

Summary

Two small additions that let aiter run on RDNA4 (gfx1201, RX 9070 XT family) without the calling project having to maintain its own kernel/config replicas.

  1. aiter.ops.triton.activation.silu_and_mul: a Triton implementation of the existing HIP silu_and_mul, with the same (out, x) signature so callers can dispatch by arch without changing call sites. The HIP kernel does not compile on RDNA4 because activation_kernels.cu uses v_pk_mul_f32, an instruction that exists only on CDNA (gfx9*) and gfx1250.

  2. Five gfx1201-GEMM-A8W8*.json tuning configs for the per-tensor FP8 gemm_a8w8 Triton kernel: one gfx1201 default plus four per-shape files. Without these, gemm_config_utils falls through to the cross-arch default (GROUP_SIZE_M=4), which leaves 75% of M-dim launch slots idle on RDNA4 at decode bs=1..32. Each per-shape config is hand-tuned on RX 9070 XT for one of the four projection shapes used by Mistral-3-8B / Qwen3-8B-FP8 (qkv, o, gate_up, down).
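
One way to read the 75% figure is as simple slot accounting: with a single real M-tile per call and a group of four M-slots, three of the four slots have no work. The sketch below only reproduces that accounting. It is not a model of how Triton actually schedules programs, the helper name `idle_fraction` is hypothetical, and `block_size_m=32` is an assumed tile height chosen so that bs=1..32 fits in one tile.

```python
import math


def idle_fraction(m: int, block_size_m: int, group_size_m: int) -> float:
    """Fraction of M-dim slots with no real tile, under the accounting above.

    Assumption: slots are counted in whole groups of `group_size_m`.
    """
    real_tiles = math.ceil(m / block_size_m)
    # Round the tile count up to a whole number of groups.
    slots = math.ceil(real_tiles / group_size_m) * group_size_m
    return 1.0 - real_tiles / slots


if __name__ == "__main__":
    for bs in (1, 8, 32):
        frac = idle_fraction(bs, block_size_m=32, group_size_m=4)
        print(f"bs={bs:2d}  idle fraction with GROUP_SIZE_M=4: {frac:.2f}")
```

Under the same accounting, a group size of 1 leaves no idle slots for these batch sizes, which is consistent with the motivation for a per-arch config.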

Headline numbers (gfx1201, RX 9070 XT, ROCm 7.x)

silu_and_mul triton vs torch fallback (the only other option on gfx1201):

| Shape (M, 2H) | Triton | Torch | Δ |
|---|---|---|---|
| (8, 28672) | 9.1us | 10.0us | -9% |
| (32, 28672) | 9.2us | 13.2us | -30% |
| (1024, 28672) | 153us | 205us | -25% |

gemm_a8w8 per-shape configs vs the cross-arch default at decode bs=1:

| Shape | Default | Tuned |
|---|---|---|
| qkv | 163us | 33us |
| o | 45us | 28us |
| gate_up | 229us | 211us |
| down | 107us | 36us |
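
For readers who prefer ratios, the per-shape timings above can be turned into speedups with a few lines of arithmetic. The numbers are copied from the table; the helper name `speedup` is illustrative only.

```python
# Decode bs=1 kernel times in microseconds: (default, tuned), from the table above.
timings_us = {
    "qkv": (163, 33),
    "o": (45, 28),
    "gate_up": (229, 211),
    "down": (107, 36),
}


def speedup(default_us: float, tuned_us: float) -> float:
    """Return how many times faster the tuned config is than the default."""
    return default_us / tuned_us


if __name__ == "__main__":
    for shape, (default_us, tuned_us) in timings_us.items():
        change = (tuned_us - default_us) / default_us
        print(f"{shape:8s} {speedup(default_us, tuned_us):.2f}x  ({change:+.0%})")
```

The gain is uneven across shapes: qkv and down improve the most, while gate_up moves by under 10%.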

Test plan

  • silu_and_mul: bf16 + fp16, 2H non-power-of-2 included; relative err <1% vs F.silu(a)*b (triton accumulates in fp32, so it is in fact more accurate than the bf16 reference)
  • gemm_a8w8: get_gemm_config("GEMM-A8W8", M, N, K) returns the new specialized blocks for all 4 (N, K) pairs; kernel output has 0 abs-err vs the BF16 reference (dequant FP8 then matmul) at bs in {1, 8, 32}
  • No regression risk for other archs — both additions are arch-specific files / new symbols; nothing existing is renamed or removed
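
The silu_and_mul check in the first bullet can be sketched in plain Python. This is a minimal reference, not the Triton kernel or the PR's test code. It assumes the input of shape (M, 2H) is split so that the first H columns are `a` and the last H columns are `b`; that split is an assumption, since the description only states the comparison against F.silu(a)*b. The names `silu_and_mul_ref` and `max_rel_err` are hypothetical.

```python
import math


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def silu_and_mul_ref(out, x):
    """Fill out[i][j] = silu(x[i][j]) * x[i][j + H] for x of shape (M, 2H).

    Mirrors the (out, x) calling convention; the a/b column split is assumed.
    """
    for i, row in enumerate(x):
        h = len(row) // 2
        for j in range(h):
            out[i][j] = silu(row[j]) * row[j + h]
    return out


def max_rel_err(got, ref, eps: float = 1e-6) -> float:
    """Largest elementwise relative error; eps guards against division by zero."""
    return max(
        abs(g - r) / max(abs(r), eps)
        for got_row, ref_row in zip(got, ref)
        for g, r in zip(got_row, ref_row)
    )


if __name__ == "__main__":
    x = [[1.0, -0.5, 0.25, 2.0, 3.0, -1.0]]  # M=1, 2H=6 (non-power-of-2)
    ref = silu_and_mul_ref([[0.0] * 3], x)
    print(ref, max_rel_err(ref, ref) < 0.01)
```

A kernel under test would write into its own `out` buffer, and `max_rel_err(kernel_out, ref)` would then be compared against the 1% bound.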

Context

Used by ROCm/ATOM in ROCm/ATOM#749 (gfx1201 / Mistral-3-8B + Qwen3-8B-FP8 enablement). Once this PR lands and the aiter pin in ATOM is bumped, ATOM can delete its _silu_mul_triton and _gfx1201_gemm_a8w8_config replicas.

@carlushuang carlushuang requested a review from a team May 13, 2026 14:24
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3168 --add-label <label>

Comment thread aiter/ops/triton/activation.py
carlushuang added a commit to ROCm/ATOM that referenced this pull request May 13, 2026
… aiter PR #3168

Two atom-side replicas removed; the kernels and tuning configs now live
upstream in aiter (ROCm/aiter#3168).

atom/model_ops/activation.py
- Delete _silu_mul_kernel (the @triton.jit) and the _silu_mul_triton
  wrapper. Replace the gfx1201 SiluAndMul.forward dispatch with a call
  to aiter.ops.triton.activation.silu_and_mul (added in aiter PR).
- Drop now-unused triton imports.

atom/model_ops/linear.py
- Delete _gfx1201_gemm_a8w8_config and the two helpers it pulled in
  (_gfx1201_parse_gemm_config_value, _gfx1201_apply_gemm_config_spec,
  plus the ATOM_GFX1201_GEMM_A8W8_CONFIG[_<SHAPE>] env-var override
  hooks). aiter's get_gemm_config now auto-loads our 4 specialized
  gfx1201-GEMM-A8W8-N=X-K=Y.json files plus the gfx1201 default,
  so atom no longer needs per-shape dispatch logic.
- Drop the config=cfg kwarg at the gemm_a8w8 call site; aiter resolves
  the config from arch + M, N, K on its own.

Net: -182 LOC. Behavior is bit-identical: the aiter PR ports the same
kernel and the same JSON config values, verified end-to-end against
the prior baseline:

  Mistral-3-8B gsm8k 5-shot, n=200:  0.765 / 0.765  within 2 sigma of 0.83 baseline
  Mistral-3-8B TPOT BS=1/8/16:       18.4 / 19.8 / 21.8 ms  matches baseline within 0.1 ms
  Qwen3-8B-FP8 gsm8k 5-shot, n=200:  0.91  / 0.90   within 1 sigma of 0.925 baseline
  Qwen3-8B-FP8 TPOT BS=1/8/16:       18.5 / 19.9 / 21.5 ms  matches the quiet-host baseline

Note: this commit assumes aiter PR #3168 is merged or that the docker
base image has it staged in /app/aiter-test/aiter/. Until then atom
will ImportError on aiter.ops.triton.activation.silu_and_mul on
gfx1201; non-gfx1201 paths are unaffected.
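
The "within N sigma" statements for the gsm8k scores can be sanity-checked with a binomial standard error. The sketch below assumes each score is an independent binomial estimate and that the baseline was also measured at n=200; the commit message does not state the baseline sample size, so that part is an assumption. The function name `accuracy_gap_in_sigma` is hypothetical.

```python
import math


def accuracy_gap_in_sigma(acc_new: float, acc_base: float,
                          n_new: int, n_base: int) -> float:
    """Gap between two accuracies in units of the standard error of their difference.

    Each accuracy is treated as a binomial proportion with variance p * (1 - p) / n.
    """
    var_new = acc_new * (1.0 - acc_new) / n_new
    var_base = acc_base * (1.0 - acc_base) / n_base
    return abs(acc_new - acc_base) / math.sqrt(var_new + var_base)


if __name__ == "__main__":
    # (new score, baseline) pairs taken from the results listed above.
    runs = {
        "Mistral-3-8B": (0.765, 0.83),
        "Qwen3-8B-FP8": (0.90, 0.925),
    }
    for model, (new, base) in runs.items():
        z = accuracy_gap_in_sigma(new, base, n_new=200, n_base=200)
        print(f"{model}: gap = {z:.2f} sigma")
```

Under these assumptions the Mistral-3-8B gap stays below 2 sigma and the Qwen3-8B-FP8 gap below 1 sigma. Treating the baseline as an exact value instead would give a larger gap in sigma units, so the reading matters.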
brunomazzottiamd

This comment was marked as resolved.

@Chi-Chu319
Contributor

Hi, could this be a duplicate of #2592?

…hapes

Drops 5 JSON configs into aiter/ops/triton/configs/gemm/:

- gfx1201-GEMM-A8W8.json                     (default)
- gfx1201-GEMM-A8W8-N=6144-K=4096.json       (Mistral-3 / Qwen3 qkv_proj)
- gfx1201-GEMM-A8W8-N=4096-K=4096.json       (o_proj)
- gfx1201-GEMM-A8W8-N=28672-K=4096.json      (gate_up_proj for Mistral-3)
- gfx1201-GEMM-A8W8-N=4096-K=14336.json      (down_proj for Mistral-3)

Without a per-arch config file, aiter/ops/triton/utils/gemm_config_utils
falls through to the cross-arch default, which on gfx1201 selects
GROUP_SIZE_M=4. That is a reasonable choice on CDNA where 4 M-tiles of
work fit naturally per workgroup, but it leaves 75% of the M-dim launch
slots idle on RDNA4 at decode bs=1..32 (only 1 real M-tile per call).

Each shape is hand-tuned on RX 9070 XT (cold-cache, 30-iter bench).
Headline kernel-time deltas vs the cross-arch default at decode bs=1:

  qkv      163us -> 33us
  o         45us -> 28us
  gate_up  229us -> 211us
  down     107us -> 36us

The "any" key plus matching M_LEQ behavior in get_gemm_config means a
single tuned entry per (N, K) covers our full BS=1..32 sweep. Verified
correct against the BF16 reference (dequant FP8 then matmul) at 0.0
abs error for all 4 shapes at bs in {1, 8, 32}.
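
The lookup described in this commit message can be pictured as two steps: pick the most specific config file for (arch, N, K), then pick an entry inside it by M. The sketch below is a hypothetical illustration, not aiter's get_gemm_config. The file names follow the pattern of the five files listed above, while the `M_LEQ_<bound>` key format, the function names, and the placeholder entry values are assumptions.

```python
def candidate_config_files(arch: str, n: int, k: int, name: str = "GEMM-A8W8"):
    """Return config file names to try, most specific first."""
    return [
        f"{arch}-{name}-N={n}-K={k}.json",  # per-shape file
        f"{arch}-{name}.json",              # per-arch default
    ]


def select_entry(config: dict, m: int):
    """Pick the entry for batch size m.

    Assumed key scheme: "M_LEQ_<bound>" entries are tried in ascending order
    of bound, and "any" is used when no bound matches.
    """
    prefix = "M_LEQ_"
    bounds = sorted(int(key[len(prefix):]) for key in config if key.startswith(prefix))
    for bound in bounds:
        if m <= bound:
            return config[f"{prefix}{bound}"]
    return config["any"]


if __name__ == "__main__":
    print(candidate_config_files("gfx1201", 6144, 4096))
    # Placeholder entry: a single "any" key serves every batch size in the sweep.
    tuned = {"any": "tuned-entry"}
    print([select_entry(tuned, bs) for bs in (1, 8, 32)])
```

With only an "any" key present, every M in the sweep resolves to the same entry, which is the behavior the commit message relies on for BS=1..32.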
@carlushuang carlushuang force-pushed the carhuang/gfx1201_silu_and_mul_and_a8w8_configs branch from dfdc7cd to 9fe8b5a Compare May 16, 2026 08:22
@carlushuang carlushuang changed the title [TRITON] gfx1201: silu_and_mul kernel + gemm_a8w8 tuning configs [TRITON] gfx1201: gemm_a8w8 tuning configs (Mistral-3 / Qwen3 shapes) May 16, 2026
@brunomazzottiamd
Contributor

Hi @carlushuang. Can you please edit the PR description to remove the content related to the SiLU + Mul Triton kernel? Thank you.

3 participants