Open
46 commits
93e6013
models: add Mistral3 (Ministral 3) text-only support
carlushuang May 9, 2026
4f848a9
attentions: scaffold a torch-native backend for gfx1201
carlushuang May 9, 2026
ba82ba9
wip: notes for next session on torch-native attn backend
carlushuang May 9, 2026
c983d98
attentions: torch-native impl + RMSNorm fallback for gfx1201
carlushuang May 10, 2026
1ebd66e
wip: NEXT_SESSION update — RMSNorm fallback in, FP8 GEMM is next blocker
carlushuang May 10, 2026
e2a0e1b
model_ops: gfx1201 fallbacks for FP8 GEMM, SiLU+Mul, sampler
carlushuang May 10, 2026
8f4099f
wip: NEXT_SESSION v3 — prefill works end-to-end; decode is the only p…
carlushuang May 10, 2026
b277edc
attn: full prefill + decode + KV cache; FP8 dequant uint8 reinterpret
carlushuang May 10, 2026
7c29b78
recipes: Ministral-3-8B on gfx1201 with torch-native attention; remov…
carlushuang May 10, 2026
eb05533
linear: route gfx1201 FP8 GEMM through aiter triton gemm_a8w8
carlushuang May 10, 2026
962c31b
recipes: update Ministral-3-8B with triton FP8 GEMM perf + gsm8k 76.5…
carlushuang May 10, 2026
ddd8c5e
attn: wire aiter triton paged_attention_decode + pivot KV layout
carlushuang May 10, 2026
4e9d262
recipes: document triton paged_attention_decode (~20% e2e win)
carlushuang May 10, 2026
c8c7e18
attn: wire aiter triton context_attention_fwd into prefill
carlushuang May 10, 2026
8a830e0
recipes: triton context_attention_fwd prefill (gsm8k 78.5% n=200)
carlushuang May 10, 2026
2402b21
attn: triton kv-cache write kernel; linear: BF16 unquantized gfx1201 …
carlushuang May 10, 2026
367f006
recipes: triton kv-write kernel + BF16 linear fallback
carlushuang May 10, 2026
53324e9
attn/norm/act: triton kernels + cudagraph foundation
carlushuang May 10, 2026
2db0e08
recipes: triton RMSNorm + SiLU+Mul
carlushuang May 10, 2026
f84cfc1
attn: enable cudagraph at decode bs<=2 (24% TPOT, 3.3x TTFT)
carlushuang May 10, 2026
52697ce
model_ops: delete torch reference fallbacks (triton-only on gfx1201)
carlushuang May 10, 2026
afe6c55
attn/cudagraph: warmup on capture stream, twice (SGLang pattern)
carlushuang May 10, 2026
f6e8ae0
linear: fuse dynamic FP8 quant + cache per-channel weight scale
carlushuang May 10, 2026
28ce324
recipes: document fused FP8 quant win + TP=2 host blocker
carlushuang May 10, 2026
620c65f
attentions: rename torch_native_attn -> gfx1201_triton_attn
carlushuang May 11, 2026
ada3fd0
recipes: full bisection results + TP=2 dual-path failure analysis
carlushuang May 11, 2026
97482a9
attentions: rename gfx1201_triton_attn -> native_triton_attn
carlushuang May 11, 2026
83aaa7d
recipes: roofline analysis + cross-GPU comparison
carlushuang May 11, 2026
0597938
linear: hand-tuned gemm_a8w8 config per (M,N,K) for gfx1201
carlushuang May 11, 2026
d3badad
Merge remote-tracking branch 'origin/main' into carhuang/support_gfx1…
carlushuang May 11, 2026
cb27314
moe: guard newer-aiter imports for older rocm/atom-dev:latest builds
carlushuang May 11, 2026
bd1311b
attn: fix bs >= 3 cudagraph corruption (NaN-from-padding in pa_decode)
carlushuang May 11, 2026
cc6648a
linear: switch dynamic FP8 quant to per-token (1 kernel vs 2)
carlushuang May 11, 2026
f22eeef
style: black + ruff cleanup for CI Pre Checkin
carlushuang May 11, 2026
a412f74
qwen3: enable Qwen3-8B-FP8 (block-128) on gfx1201 via gemm_a16w8_bloc…
carlushuang May 12, 2026
419ef32
gfx1201: ship aiter-config setup script + document required-setup ste…
carlushuang May 12, 2026
18d01ff
layernorm: drop replicated triton rmsnorm kernels, call aiter triton …
carlushuang May 12, 2026
17f939a
Merge remote-tracking branch 'origin/main' into carhuang/support_gfx1…
carlushuang May 12, 2026
15db101
gfx1201: re-tune down_proj gemm_a8w8 config + add reusable bench script
carlushuang May 12, 2026
eaf492e
gfx1201: speed up native triton decode path
chuanbowang2026 May 13, 2026
aa1f776
recipes: document ATOM_GFX1201_LM_HEAD_FP8 + perf table after the spe…
carlushuang May 13, 2026
390cd79
Merge remote-tracking branch 'origin/main' into carhuang/support_gfx1…
carlushuang May 13, 2026
c009ef6
gfx1201: drop _silu_mul_triton + _gfx1201_gemm_a8w8_config, depend on…
carlushuang May 13, 2026
fe04782
fix: replace _is_gfx1201 hasattr-cached detection with module-level c…
chuanbowang2026 May 14, 2026
42d90aa
style: black format gfx1201 detection cleanup
chuanbowang2026 May 14, 2026
24681e1
Merge branch 'main' into carhuang/support_gfx1201_mistral3
chuanbowang2026 May 18, 2026
1 change: 1 addition & 0 deletions atom/config.py
@@ -512,6 +512,7 @@ def _remap_layer_name(name: str) -> list[str]:
"kimi_k25": "text_config",
"qwen3_5": "text_config",
"qwen3_5_moe": "text_config",
"mistral3": "text_config",
}

# multimodal models fully supported by plugin mode
2 changes: 2 additions & 0 deletions atom/model_engine/model_runner.py
@@ -67,6 +67,8 @@
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
"MiMoV2FlashForCausalLM": "atom.models.mimo_v2_flash.MiMoV2FlashForCausalLM",
"Mistral3ForConditionalGeneration": "atom.models.mistral3.Mistral3TextOnly",
"MistralForCausalLM": "atom.models.mistral3.Mistral3ForCausalLM",
}
# seed = 34567
# np.random.seed(seed)
38 changes: 31 additions & 7 deletions atom/model_ops/activation.py
@@ -2,17 +2,28 @@
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import torch
from typing import Optional
from torch import nn
import torch.nn.functional as F
from aiter import silu_and_mul
from atom.config import QuantizationConfig
from atom.quant_spec import LayerQuantConfig
from aiter.jit.utils.torch_guard import torch_compile_guard

from aiter import (
    QuantType,
    silu_and_mul,
)
from aiter.jit.utils.torch_guard import torch_compile_guard
from atom.config import QuantizationConfig
from atom.quant_spec import LayerQuantConfig
from torch import nn
from typing import Optional


def _detect_gfx1201() -> bool:
    try:
        return (torch.cuda.get_device_properties(0).gcnArchName or "").startswith(
            "gfx1201"
        )
    except Exception:
        return False


_IS_GFX1201: bool = _detect_gfx1201()


def mxfp4_act_mul_quant_fuse_fake(
@@ -84,6 +95,19 @@ def forward_native(
    def forward(
        self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # gfx1201 (RDNA4): aiter prebuilt silu_and_mul HIP kernel has no
        # gfx1201 code object (CDNA-only v_pk_mul_f32). Use the portable
        # triton silu_and_mul added in aiter PR #3168 (which mirrors the
        # HIP signature out=fn(x)).
        if _IS_GFX1201:
            from aiter.ops.triton.activation import (
                silu_and_mul as _aiter_silu_mul_triton,
            )

            half = x.shape[-1] // 2
            out = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
            _aiter_silu_mul_triton(out, x)
            return out
        # fp8 quantization
        if x_scale is not None and self.fused_quant:
            from aiter.ops.triton.fused_fp8_quant import (
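
Review note: the gfx1201 branch above allocates an output whose last dimension is half that of the input. A plain-Python reference for one row may make that shape easier to follow. It assumes the common SiLU-and-mul convention, where the first half of the last dimension is the gate and the second half is the multiplier; the diff does not show the kernel body, so treat that split as an assumption about the aiter kernel. `silu_and_mul_reference` is a hypothetical name.

```python
import math
from typing import List


def silu(v: float) -> float:
    """SiLU (swish): v * sigmoid(v), written to avoid overflow in exp()."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def silu_and_mul_reference(row: List[float]) -> List[float]:
    """Apply SiLU to the first half of `row` and multiply by the second half."""
    if len(row) % 2 != 0:
        raise ValueError("last dimension must be even")
    half = len(row) // 2  # same split as `x.shape[-1] // 2` in the diff
    gate, up = row[:half], row[half:]
    return [silu(g) * u for g, u in zip(gate, up)]
```

A row of length 4 yields an output of length 2, which matches the `(*x.shape[:-1], half)` shape passed to `torch.empty` in the branch.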