[feat][ATOM-vLLM][Attention Refactor] Reconstruct the Attention Arch #750
zejunchen-zejun wants to merge 10 commits into
Conversation
from atom.plugin.prepare import is_plugin_mode
from atom.utils import CpuGpuBuffer
from atom.utils.block_convert import (
    block_table_convert_triton,
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)
aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant imported but unused
Suggested change: delete the unused import (the three lines quoted above).
from functools import partial as functools_partial
from atom.model_ops.linear import use_triton_gemm
if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import (
            fused_gemm_a8w8_blockscale_preshuffle_split_cat,
        )
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import (
            fused_gemm_afp4wfp4_preshuffle_split_cat,
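For readers following the diff: the guard above is a conditional optional-import pattern. A minimal sketch of how such a guard is typically closed out is below; the availability flag and the `except ImportError` fallback are assumptions for illustration, not the PR's actual code.

```python
# Illustrative sketch only: the flag name and fallback behavior are
# assumptions, not taken from this PR.
from atom.model_ops.linear import use_triton_gemm

_FUSED_SPLIT_CAT_AVAILABLE = False

if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import (
            fused_gemm_a8w8_blockscale_preshuffle_split_cat,
        )
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import (
            fused_gemm_afp4wfp4_preshuffle_split_cat,
        )
        _FUSED_SPLIT_CAT_AVAILABLE = True
    except ImportError:
        # Older aiter builds may not ship the fused split/cat kernels;
        # callers would then fall back to the unfused GEMM + concat path.
        pass
```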
from aiter import get_mla_metadata_v1
from atom.utils.block_convert import kv_indices_generate_triton
from atom.utils.block_convert import kv_indices_generate_triton
from atom.utils.forward_context import Context
from atom.model_ops.attention_mla import MLAAttention, _MLA_MIN_HEADS
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)
from aiter.mla import mla_decode_fwd
aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant imported but unused
Suggested change: delete the unused import and keep only
from aiter.mla import mla_decode_fwd
from aiter.mla import mla_decode_fwd
from aiter import (
from aiter import (
    fused_qk_rope_concat_and_cache_mla,
    cp_gather_indexer_k_quant_cache,
    dtypes,
    indexer_k_quant_and_cache,
    top_k_per_row_decode,
    top_k_per_row_prefill,
)
aiter.fused_qk_rope_concat_and_cache_mla imported but unused
Suggested change:
from aiter import (
    cp_gather_indexer_k_quant_cache,
    dtypes,
    indexer_k_quant_and_cache,
    top_k_per_row_decode,
)
from aiter import (
    fused_qk_rope_concat_and_cache_mla,
    cp_gather_indexer_k_quant_cache,
    dtypes,
    indexer_k_quant_and_cache,
    top_k_per_row_decode,
    top_k_per_row_prefill,
)
aiter.top_k_per_row_prefill imported but unused
Suggested change:
from aiter import (
    cp_gather_indexer_k_quant_cache,
    dtypes,
    indexer_k_quant_and_cache,
    top_k_per_row_decode,
)
| """vLLM-facing sparse MLA backend surface for ATOM attention layers.""" | ||
|
|
||
| @staticmethod | ||
| def get_builder_cls() -> Type["AiterMLASparseMetadataBuilder"]: |
| """vLLM-facing sparse MLA indexer backend surface.""" | ||
|
|
||
| @staticmethod | ||
| def get_builder_cls() -> Type["AiterMLASparseIndexerMetadataBuilder"]: |
divide the atom-vllm metadata from atom metadata
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Force-pushed from 3b061e9 to b832aee
This PR refactors the attention architecture for ATOM-vLLM.
Here is the RFC: #758
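To make the scope of the refactor concrete, below is a minimal sketch of the vLLM-facing backend surface the new architecture exposes. The class and builder names are modeled on the AiterMLASparse* snippets quoted in the review above; the method body and docstrings are assumptions for illustration, not the PR's actual implementation.

```python
from typing import Type


class AiterMLASparseMetadataBuilder:
    """Builds per-step attention metadata (block tables, seq lens, ...)."""


class AiterMLASparseBackend:
    """vLLM-facing sparse MLA backend surface for ATOM attention layers."""

    @staticmethod
    def get_builder_cls() -> Type["AiterMLASparseMetadataBuilder"]:
        # vLLM asks the backend for its metadata builder class, instantiates
        # it once, and invokes it each step to assemble the metadata consumed
        # by the attention kernels. (Sketch only; see the RFC for the real
        # interface.)
        return AiterMLASparseMetadataBuilder
```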
Here are the validation results:
ATOM-vLLM CI:
DeepSeek-R1-FP8 TP8 / atom-vllm CI: 0.9484457922668689 >= 0.93
gpt-oss-120b TP1 / atom-vllm CI: FAILED - aiter JIT header missing
Kimi-K2-Thinking-MXFP4 TP4 / atom-vllm CI: FAILED - server not ready
Qwen3.5-35B-A3B-FP8 TP2 / atom-vllm CI: 0.7862016679302501 >= 0.76
ATOM-vLLM nightly validation:
Qwen3-235B-A22B-Instruct-2507-FP8 TP8+EP8 / atom-vllm nightly: FAILED - max_qlen=None
Qwen3-Next-80B-A3B-Instruct-FP8 TP1 / atom-vllm nightly: FAILED - tuple index error
Qwen3-Next-80B-A3B-Instruct-FP8 TP4 / atom-vllm nightly: FAILED - tuple index error
Qwen3.5-397B-A17B-FP8 TP8 / atom-vllm nightly: 0.8688400303260045 >= 0.83
Qwen3.5-397B-A17B TP8 / atom-vllm nightly: 0.8688400303260045 >= 0.83
Qwen3.5-397B-A17B-MXFP4 TP4 / atom-vllm nightly: 0.8468536770280516 >= 0.83
Meta-Llama-3.1-405B-Instruct-FP8 TP8 / atom-vllm nightly: FAILED - fp8 gemm dtype mismatch
Llama-3.1-8B-Instruct TP1 / atom-vllm nightly: FAILED - max_qlen=None
Kimi-K2-Thinking-MXFP4 TP8 / atom-vllm nightly: 0.931008339651251 >= 0.90
Kimi-K2.5-MXFP4 TP8 / atom-vllm nightly: FAILED - vision config missing hidden_size
DeepSeek-R1-FP8 TP8 / atom-vllm nightly: 0.9507202426080363 >= 0.93
DeepSeek-R1-0528-MXFP4 TP8 / atom-vllm nightly: 0.9370735405610311 >= 0.93
DeepSeek-V3.2-FP8 TP8 / atom-vllm nightly: 0.9492039423805914 >= 0.93
gpt-oss-120b TP1 / atom-vllm nightly: FAILED - aiter JIT header missing
gpt-oss-120b TP2 / atom-vllm nightly: FAILED - aiter JIT header missing
GLM-5.1-FP8 TP8 / atom-vllm nightly: 0.9423805913570887 >= 0.88
ATOM CI and nightly:
Meta-Llama-3-8B-Instruct / native atom: FAILED - local model missing
DeepSeek-R1-0528 / native atom: 0.9522365428354814 >= 0.94
DeepSeek-V4-Pro / native atom: 0.9552691432903715 >= 0.92
DeepSeek-R1-0528 MTP / native atom: 0.9461713419257013 >= 0.94
gpt-oss-120b / native atom: FAILED - no result JSON
Llama-3.3-70B-Instruct-MXFP4-Preview / native atom: FAILED - local model missing
DeepSeek-R1-0528-FP4 / native atom: 0.9492039423805914 >= 0.93
DeepSeek-R1-0528-FP4 MTP / native atom: 0.9401061410159212 >= 0.93
Qwen3-235B-A22B-Instruct-2507-FP8 / native atom: 0.8953752843062927 >= 0.87
Qwen3-Next-80B-A3B-Thinking / native atom: 0.6732373009855952 >= 0.65
gpt-oss-120b 2GPU / native atom: FAILED - no result JSON
Qwen3-235B-A22B-Instruct-2507-MXFP4 / native atom: 0.8764215314632298 >= 0.87
Kimi-K2.5-MXFP4 / native atom: 0.9423805913570887 >= 0.92
Kimi-K2.5-MXFP4 Eagle3 / native atom: 0.935557240333586 >= 0.91
GLM-5-FP8 / native atom: 0.9416224412433661 >= 0.93
GLM-5.1-FP8 / native atom: 0.8893100833965125 >= 0.875
GLM-5.1-MXFP4 MTP / native atom: 0.8809704321455648 >= 0.87
GLM-5.1-MXFP4 / native atom: 0.88855193328279 >= 0.87
Qwen3.5-397B-A17B-FP8 / native atom: 0.8786959818043972 >= 0.85
Qwen3.5-397B-A17B-FP8 MTP / native atom: 0.8673237300985596 >= 0.85
Qwen3.5-397B-A17B-MXFP4 / native atom: 0.8620166793025019 >= 0.835
Qwen3.5-397B-A17B-MXFP4 MTP / native atom: FAILED - 0.8339651250947687 < 0.835
MiniMax-M2.5 / native atom: 0.9317664897649734 >= 0.92
MiniMax-M2.5-MXFP4 / native atom: 0.9196360879454132 >= 0.91