
[gfx1201] Mistral-3 + Qwen3-8B-FP8 on RDNA4 via native triton attention #811

Open
carlushuang wants to merge 3 commits into main from carhuang/support_gfx1201_mistral3_rebased

carlushuang commented May 16, 2026

gfx1201 (RX 9070 XT, RDNA4) support for ATOM. All-triton inference path — no prebuilt HIP kernels required.

Models verified:

  • Qwen3-8B-FP8 (single GPU, FP8 KV) — TTFT 167ms, TPOT 28ms @ bs=1
  • Ministral-3-3B-Reasoning (single GPU, BF16) — TTFT 267ms, TPOT 18ms @ bs=1
  • Ministral-3-8B-Instruct (single GPU, BF16 KV) — see recipe for full perf table

What this PR adds:

  • NativeTritonAttentionBackend — pure-triton prefill/decode attention, KV cache write, auto-selected on gfx1201
  • Triton RMSNorm, SiLU+Mul, dynamic FP8 quant kernels (in-tree, no aiter HIP deps)
  • Block table buffer fix: max_num_blocks_per_seq sized from max(max_model_len, max_num_batched_tokens)
  • CUDAGraph support at all decode batch sizes (NaN-from-padding fix)
  • Optional lm_head FP8 quantization (ATOM_GFX1201_LM_HEAD_FP8=1)
  • Deployment recipes: recipes/Ministral-3-8B.md, recipes/Qwen3-8B-FP8.md
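
The block table sizing rule above can be sketched as follows. This is a minimal illustration, not the code in this PR: the helper name and the `block_size` parameter (tokens per KV cache block) are assumptions, and only the `max(max_model_len, max_num_batched_tokens)` rule comes from the description.

```python
def max_num_blocks_per_seq(max_model_len: int,
                           max_num_batched_tokens: int,
                           block_size: int) -> int:
    """Number of block-table columns needed per sequence.

    Hypothetical helper: sizes the table from the larger of the two limits,
    then rounds up to whole KV cache blocks.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    longest = max(max_model_len, max_num_batched_tokens)
    return (longest + block_size - 1) // block_size  # ceiling division


# A table sized only from max_model_len would be too narrow here.
print(max_num_blocks_per_seq(4096, 8192, 16))
```

Sizing from `max_model_len` alone would give 256 columns in this example, while the larger limit requires 512.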

Dependencies:

  • aiter PR #3236 (opus.hpp gfx1201 support)
  • aiter PR #3234 (gfx1201 GEMM tuning configs)

Tested on: RX 9070 XT (gfx1201), single GPU, rocm/atom-dev:latest

carlushuang changed the title from "Add Mistral-3-8B + Qwen3-8B-FP8 + native triton attention backend for gfx1201 (RDNA4 / RX 9070 XT)" to "[gfx1201] Mistral-3 + Qwen3-8B-FP8 on RDNA4 via native triton attention" on May 18, 2026
…iton attention backend

Bring up two natively-FP8 8B models on AMD RX 9070 XT (gfx1201) with an
all-Triton stack and cudagraph on. Auto-routed on gfx1201; opt-in elsewhere
via ATOM_NATIVE_TRITON_ATTN=1.
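
The routing rule in the paragraph above (automatic on gfx1201, opt-in elsewhere) could look roughly like this. The function name and the way the architecture string is obtained are assumptions for illustration; only the `ATOM_NATIVE_TRITON_ATTN` variable and the gfx1201 auto-routing come from the commit message.

```python
import os


def use_native_triton_attention(gfx_arch: str, env=None) -> bool:
    """Hypothetical selector: True if the native Triton backend should be used."""
    env = os.environ if env is None else env
    if gfx_arch == "gfx1201":
        return True  # auto-routed on RDNA4 (RX 9070 XT)
    # Other architectures must opt in explicitly.
    return env.get("ATOM_NATIVE_TRITON_ATTN", "0") == "1"


print(use_native_triton_attention("gfx1201", env={}))
print(use_native_triton_attention("gfx942", env={"ATOM_NATIVE_TRITON_ATTN": "1"}))
```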

Models
- New Mistral3 text-only model (strips Pixtral vision tower)
- Qwen3-8B-FP8 (block-128) via gemm_a16w8_blockscale (PREQUANT=False); the
  fp8 weight is cast to bf16 inside the kernel so tl.dot(bf16, bf16) is
  used (Triton on this gfx1201 build does not implement tl.dot(fp8, fp8))
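
As a reference for what the a16w8 block-scale path computes, here is a plain-Python sketch. It assumes one scale per 128x128 weight tile (the usual reading of "block-128") and that the FP8 weights have already been decoded to floats. The function names are hypothetical, and this is not the Triton kernel itself.

```python
def dequant_blockscale(w_q, scales, block=128):
    """Scale each weight by the scale of the (block x block) tile it belongs to.

    w_q:    [N][K] weight values already decoded from FP8
    scales: [ceil(N/block)][ceil(K/block)] per-tile scales (assumed layout)
    """
    return [[w_q[n][k] * scales[n // block][k // block]
             for k in range(len(w_q[n]))]
            for n in range(len(w_q))]


def gemm_a16w8_reference(x, w_q, scales, block=128):
    """y = x @ W^T, with W dequantized first so the dot product runs in
    high precision on both operands (mirrors the bf16 x bf16 tl.dot idea)."""
    w = dequant_blockscale(w_q, scales, block)
    return [[sum(xi * wi for xi, wi in zip(x_row, w_row)) for w_row in w]
            for x_row in x]


# Tiny example with block=2 so the tiling is visible.
x = [[1.0, 1.0, 1.0, 1.0]]
w_q = [[1.0, 1.0, 2.0, 2.0],
       [1.0, 1.0, 2.0, 2.0]]
scales = [[0.5, 2.0]]  # one row of tiles, two tiles along K
print(gemm_a16w8_reference(x, w_q, scales, block=2))
```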

Attention
- New NativeTritonBackend: JIT triton kernels in place of AITER prebuilt
  HIP .so (gfx1201 has no prebuilt code objects)
- CUDAGraph-correct decode at all bs incl. NaN-from-padding fix in
  prepare_decode (was silently producing wrong logits for
  scheduled_bs < captured_bs)
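
The commit message does not spell out the prepare_decode fix itself. One common way to keep padded slots harmless when a captured graph runs with fewer live sequences is to give them a valid, non-empty context, so the attention softmax never normalizes over zero keys. The sketch below illustrates that idea only; the function name, the dummy block index, and the choice of context length 1 are assumptions, not the PR's code.

```python
def pad_decode_batch(context_lens, block_tables, captured_bs, dummy_block=0):
    """Pad decode metadata from scheduled_bs rows up to captured_bs rows.

    Padded rows get context_len=1 and a block table pointing at a dummy block
    (assumed to exist), so a kernel that processes them reads valid memory and
    does not divide by an empty softmax sum.
    """
    scheduled_bs = len(context_lens)
    if scheduled_bs > captured_bs:
        raise ValueError("scheduled batch exceeds captured graph batch size")
    pad = captured_bs - scheduled_bs
    width = len(block_tables[0]) if block_tables else 1
    padded_lens = list(context_lens) + [1] * pad
    padded_tables = [list(row) for row in block_tables]
    padded_tables += [[dummy_block] * width for _ in range(pad)]
    return padded_lens, padded_tables


lens, tables = pad_decode_batch([5, 9], [[3, 0], [4, 7]], captured_bs=4)
print(lens)
print(tables)
```

Outputs of the padded rows are discarded by the caller; the point is only that they stay finite.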

GEMM / activation / FP8
- Per-shape gemm_a8w8 configs (GROUP_SIZE_M=1) + per-token FP8 quant
  (1 kernel, no atomic) via aiter PR #3168
- SiluAndMul routed through aiter fused_silu_mul (PR #2578, merged)
- Removes torch fallbacks across linear/embed_head/sampler/paged_attention
  (every fallback contained .item()/.cpu() syncs that silently break
  cudagraph capture on ROCm)
- Critical: aiter dtypes.d_dtypes[fp8] == torch.uint8. FP8 weights MUST
  be .view(torch.float8_e4m3fn) before .to(bf16); otherwise the raw
  bytes 0-255 are converted as integers, and the outputs look
  numerically plausible but are garbage
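
To make the uint8 pitfall concrete, the sketch below decodes a few bytes both ways in plain Python: as integers (what a value conversion from uint8 does) and as float8 E4M3FN bit patterns (what a reinterpreting view does). The decoder follows the standard E4M3FN layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities). It is an illustration, not code from this PR.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Interpret one byte as a float8 E4M3FN value."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # the only NaN encoding; E4M3FN has no inf
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


raw = [0x38, 0x40, 0xB8, 0x7E]
wrong = [float(b) for b in raw]          # value conversion from uint8
right = [decode_e4m3fn(b) for b in raw]  # bit reinterpretation as FP8
print(wrong)  # [56.0, 64.0, 184.0, 126.0]
print(right)  # [1.0, 2.0, -1.0, 448.0]
```

The integer-converted values are finite and smooth, which is why the resulting outputs can look plausible at a glance.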

Config / setup
- scripts/gfx1201/setup_aiter_configs.sh aliases gfx1250 GEMM tuned
  configs to gfx1201 names (aiter ships zero gfx1201 configs in
  rocm/atom-dev:latest); without this the autotuner falls back to a
  default that is 30-50% slower at 8B-class shapes
- scripts/gfx1201/gemm_a8w8_sweep.py for shape-by-shape tuning
- recipes/Ministral-3-8B.md + recipes/Qwen3-8B-FP8.md with serve cmd,
  env vars, perf+accuracy table, debug journey notes
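
The aliasing step performed by setup_aiter_configs.sh can be sketched in Python as below. The real script is a shell script whose contents are not shown here; the assumption that config files carry the architecture name in their filename, the directory layout, and the function name are all hypothetical.

```python
from pathlib import Path


def alias_configs(config_dir, src_arch="gfx1250", dst_arch="gfx1201"):
    """Create dst_arch-named symlinks for every src_arch-named config file.

    Existing targets are left untouched, so the function is safe to re-run.
    Returns the list of links created.
    """
    created = []
    for src in sorted(Path(config_dir).iterdir()):  # snapshot before linking
        if not src.is_file() or src_arch not in src.name:
            continue
        dst = src.with_name(src.name.replace(src_arch, dst_arch))
        if dst.exists() or dst.is_symlink():
            continue
        dst.symlink_to(src.name)  # relative link within the same directory
        created.append(dst)
    return created
```

Symlinks keep the aliases in sync if the source configs are later updated; copying the files instead would work the same way on filesystems without symlink support.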

Results (single GPU, BF16 KV, cudagraph, conc=1, OSL=256)
- Ministral-3-8B: TPOT 22.0 ms, 43 tok/s, gsm8k 5-shot n=200 = 0.785
- Qwen3-8B-FP8:   TPOT 21.6 ms, 45 tok/s, gsm8k 5-shot n=50  = 0.86

Depends on aiter PRs #3168 (gfx1201 gemm_a8w8 tuning configs) and #2578
(silu_mul_fused, already merged upstream).
carlushuang force-pushed the carhuang/support_gfx1201_mistral3_rebased branch 2 times, most recently from df15f7c to a988bf8, on May 18, 2026 00:50
carlushuang force-pushed the carhuang/support_gfx1201_mistral3_rebased branch from a988bf8 to f2b1551 on May 18, 2026 00:51