[gfx1201] Mistral-3 + Qwen3-8B-FP8 on RDNA4 via native triton attention#811
Open
carlushuang wants to merge 3 commits into
Open
[gfx1201] Mistral-3 + Qwen3-8B-FP8 on RDNA4 via native triton attention#811carlushuang wants to merge 3 commits into
carlushuang wants to merge 3 commits into
Conversation
4 tasks
…iton attention backend Bring up two natively-FP8 8B models on AMD RX 9070 XT (gfx1201) with an all-Triton stack and cudagraph on. Auto-routed on gfx1201; opt-in elsewhere via ATOM_NATIVE_TRITON_ATTN=1. Models - New Mistral3 text-only model (strips Pixtral vision tower) - Qwen3-8B-FP8 (block-128) via gemm_a16w8_blockscale (PREQUANT=False); the fp8 weight is cast to bf16 inside the kernel so tl.dot(bf16, bf16) is used (Triton on this gfx1201 build does not implement tl.dot(fp8, fp8)) Attention - New NativeTritonBackend: JIT triton kernels in place of AITER prebuilt HIP .so (gfx1201 has no prebuilt code objects) - CUDAGraph-correct decode at all bs incl. NaN-from-padding fix in prepare_decode (was silently producing wrong logits for scheduled_bs < captured_bs) GEMM / activation / FP8 - Per-shape gemm_a8w8 configs (GROUP_SIZE_M=1) + per-token FP8 quant (1 kernel, no atomic) via aiter PR #3168 - SiluAndMul routed through aiter fused_silu_mul (PR #2578, merged) - Removes torch fallbacks across linear/embed_head/sampler/paged_attention (every fallback contained .item()/.cpu() syncs that silently break cudagraph capture on ROCm) - Critical: aiter dtypes.d_dtypes[fp8] == torch.uint8. FP8 weights MUST be .view(torch.float8_e4m3fn) before .to(bf16) or byte values 0-255 decode as integers and outputs look numerically reasonable but are garbage Config / setup - scripts/gfx1201/setup_aiter_configs.sh aliases gfx1250 GEMM tuned configs to gfx1201 names (aiter ships zero gfx1201 configs in rocm/atom-dev:latest); without this the autotuner falls back to a default that is 30-50% slower at 8B-class shapes - scripts/gfx1201/gemm_a8w8_sweep.py for shape-by-shape tuning - recipes/Ministral-3-8B.md + recipes/Qwen3-8B-FP8.md with serve cmd, env vars, perf+accuracy table, debug journey notes Results (single GPU, BF16 KV, cudagraph, conc=1, OSL=256) - Ministral-3-8B: TPOT 22.0 ms, 43 tok/s, gsm8k 5-shot n=200 = 0.785 - Qwen3-8B-FP8: TPOT 21.6 ms, 45 tok/s, gsm8k 5-shot n=50 = 0.86 Depends on aiter PRs #3168 (gfx1201 gemm_a8w8 tuning configs) and #2578 (silu_mul_fused, already merged upstream).
…m_batched_tokens)
df15f7c to
a988bf8
Compare
a988bf8 to
f2b1551
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
gfx1201 (RX 9070 XT, RDNA4) support for ATOM. All-triton inference path — no prebuilt HIP kernels required.
Models verified:
What this PR adds:
NativeTritonAttentionBackend— pure-triton prefill/decode attention, KV cache write, auto-selected on gfx1201max_num_blocks_per_seqsized frommax(max_model_len, max_num_batched_tokens)ATOM_GFX1201_LM_HEAD_FP8=1)recipes/Ministral-3-8B.md,recipes/Qwen3-8B-FP8.mdDependencies:
Tested on: RX 9070 XT (gfx1201), single GPU,
rocm/atom-dev:latest