diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index 3883dd694..4670a5cdb 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -1,5 +1,7 @@ from .paged_attention import PagedAttention -from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention +from atom.plugin.sglang.attention_backend.full_attention.radix_attention import ( + RadixAttention, +) # This global class is used to construct the attention op in model, # it can be assigned to different attention ops. diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index e4e7cc47e..ba2bec968 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -16,7 +16,9 @@ import atom.model_ops as ops from atom.model_ops.paged_attention import PagedAttention from atom.model_ops.attention_mha import PagedAttentionImpl -from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention +from atom.plugin.sglang.attention_backend.full_attention.radix_attention import ( + RadixAttention, +) from atom.utils.forward_context import AttentionMetaData, Context from .backends import AttentionBackend, CommonAttentionBuilder diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 86f06623d..043f14133 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -45,7 +45,7 @@ def _register_custom_attention_to_sglang() -> None: from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) - from atom.plugin.sglang.attention_backend.sgl_attn_backend import ( + from atom.plugin.sglang.attention_backend.full_attention.full_attention_backend import ( ATOMAttnBackendForSgl, ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/__init__.py b/atom/plugin/sglang/attention_backend/full_attention/__init__.py new file mode 100644 index 000000000..2b9f6f272 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/__init__.py @@ -0,0 +1,8 @@ +from .radix_attention import RadixAttention +from .full_attention_backend import ATOMAttnBackendForSgl, ForwardMetadata + +__all__ = [ + "RadixAttention", + "ATOMAttnBackendForSgl", + "ForwardMetadata", +] diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py similarity index 99% rename from atom/plugin/sglang/attention_backend/sgl_attn_backend.py rename to atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index a8ebc1122..2b58f6fed 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -1,10 +1,10 @@ from __future__ import annotations -# sglang-specific attention backend replacing sglang's built-in AiterAttnBackend. -# Shared by ALL models (DeepSeek, Qwen3, etc.) — handles KV cache writes, -# page-table fixup, pa_persistent_fwd decode path, and MLA prefill kernels. -# Sits at the lowest layer of the attention stack: sglang's RadixAttention -# delegates the actual kernel dispatch here. +# SGLang full-attention backend replacing sglang's built-in AiterAttnBackend. +# Shared by ALL full-attention models (DeepSeek, Qwen3, etc.) — handles KV +# cache writes, page-table fixup, pa_persistent_fwd decode path, and MLA +# prefill kernels. Sits at the lowest layer of the attention stack: +# sglang's RadixAttention delegates the actual kernel dispatch here. # # TODO: rewrite this file once sglang's attention flow is unified into ATOM's # attention layer — KV cache management and attention kernel dispatch will then @@ -47,7 +47,7 @@ except ImportError as e: raise ImportError( "Failed to import 'aiter', which provides AMD-specific attention kernels " - "required by sgl_attn_backend. Please ensure 'aiter' is installed and " + "required by full_attention_backend. Please ensure 'aiter' is installed and " f"available on your AMD system. Original import error: {e}" ) from e diff --git a/atom/plugin/sglang/attention_backend/radix_attention.py b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py similarity index 100% rename from atom/plugin/sglang/attention_backend/radix_attention.py rename to atom/plugin/sglang/attention_backend/full_attention/radix_attention.py diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 3f2c743b1..dbace2e64 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -386,7 +386,7 @@ def __init__( # Apply ds model-specific sglang patches (attn dispatch, weight hooks, etc.) # TODO: will remove this after sglang supports atom attention backend if self.model_arch_spec.apply_deepseek_patch: - from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( + from atom.plugin.sglang.models.deepseek_mla import ( setup_deepseek_for_sglang, ) diff --git a/atom/plugin/sglang/models/deepseek_mla.py b/atom/plugin/sglang/models/deepseek_mla.py new file mode 100644 index 000000000..a8d4deb1a --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Model-level DeepSeek MLA patching for SGLang plugin mode. + +This module owns the monkey-patch entrypoints that adapt DeepSeek MLA models to +SGLang plugin mode. The heavy DeepSeek-specific forward and weight helpers live +in `atom.plugin.sglang.models.deepseek_mla_forward`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch + +from atom.plugin.sglang.models.deepseek_mla_forward import ( + forward_sgl_plugin_mode, + init_sgl_attrs, + process_mla_kv_b_proj_after_loading, +) + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +def setup_deepseek_for_sglang(model) -> None: + """Patch a DeepSeek V2/V3 model for SGLang plugin mode.""" + config = model.config + + # Store atom_config for the OOT wrapper before install-time hooks run. + if not hasattr(model, "atom_config"): + from atom.config import get_current_atom_config + + model.atom_config = get_current_atom_config() + + kv_cache_dtype = model.atom_config.kv_cache_dtype + + # Initialise SGLang's MLA TP context before patching per-layer forwards. + from sglang.srt.configs.model_config import is_deepseek_nsa + from sglang.srt.layers.communicator import get_attn_tp_context + + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + for module in model.modules(): + if isinstance(module, DeepseekV2MLAAttention): + _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) + + +def _patch_mla_attention_for_sglang( + attn: "DeepseekV2MLAAttention", + config: Any, + kv_cache_dtype: str = "bf16", +) -> None: + """Patch one DeepSeek MLA layer for SGLang plugin mode.""" + init_sgl_attrs(attn, config, kv_cache_dtype) + + def patched_forward( + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + from atom.plugin.sglang.models.base_model_wrapper import ( + get_current_forward_batch, + ) + + kwargs["forward_batch"] = get_current_forward_batch() + return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) + + attn.forward = patched_forward + attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( + attn + ) diff --git a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py b/atom/plugin/sglang/models/deepseek_mla_forward.py similarity index 86% rename from atom/plugin/sglang/attention_backend/sgl_attention_mla.py rename to atom/plugin/sglang/models/deepseek_mla_forward.py index 1dae9349e..25f1ef79a 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -1,17 +1,19 @@ -"""Sglang-specific MLA forward and weight processing for DeepseekV2/V3. +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -DeepSeek MLA (Multi-Latent Attention) forward logic for sglang plugin mode: +"""Model-specific DeepSeek MLA helpers for SGLang plugin mode. + +DeepSeek MLA (Multi-Latent Attention) forward logic for SGLang plugin mode: absorbed BMM computation, MHA/MLA path dispatch (prefill -> MHA, decode -> MLA), -kv_b_proj weight splitting (w_kc/w_vc), and monkey-patch setup via -setup_deepseek_for_sglang(). +and kv_b_proj weight splitting (w_kc/w_vc). -This module is lazily imported from base_model_wrapper.py only when running in -sglang plugin mode (``is_sglang() == True``). Keeping all sglang-dependent -imports here avoids crashing when sglang is not installed. +This module lives under ``atom.plugin.sglang.models`` because the logic is +DeepSeek-model-specific rather than a generic SGLang attention backend. TODO: rewrite this file once sglang's attention flow is unified into ATOM's -attention layer — the MLA absorbed path and MHA dispatch will then be handled -natively by ATOM's attention ops, making this sglang-specific module unnecessary. +attention layer - the MLA absorbed path and MHA dispatch will then be handled +natively by ATOM's attention ops, making this sglang-specific module +unnecessary. """ from __future__ import annotations @@ -29,7 +31,6 @@ from atom.models.utils import maybe_prefix from atom.models.deepseek_v2 import _fuse_rmsnorm_quant -# sglang imports from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode @@ -59,7 +60,6 @@ from atom.models.deepseek_v2 import DeepseekV2MLAAttention -# bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) if _is_cuda: from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 from sglang.srt.utils.custom_op import register_custom_op @@ -90,7 +90,6 @@ def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") -# NamedTuple for prepare → core data flow class SglPrepareResult(NamedTuple): q_pe: torch.Tensor k_pe: torch.Tensor @@ -182,7 +181,6 @@ def _prepare_weight_for_bmm( ) -# Init helpers def init_sgl_attrs( attn: DeepseekV2MLAAttention, config, @@ -216,7 +214,6 @@ def init_sgl_attrs( attn.attn_mha.attn.kv_b_proj = None -# Absorbed batched-matmul (shared by prepare and core) def mla_absorbed_bmm( attn: DeepseekV2MLAAttention, inp: torch.Tensor, @@ -225,12 +222,7 @@ def mla_absorbed_bmm( weight_scale_k: Optional[torch.Tensor], out_dim: int, ) -> torch.Tensor: - """Batched matmul for MLA absorbed weights (w_kc / w_vc). - - Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. - inp: (num_tokens, num_heads, in_dim) — token-major - Returns: (num_tokens, num_heads, out_dim) — token-major - """ + """Batched matmul for MLA absorbed weights (w_kc / w_vc).""" effective_weight_scale = ( weight_scale_k if weight_scale_k is not None else weight_scale ) @@ -299,7 +291,6 @@ def mla_absorbed_bmm( ) return out.transpose(0, 1) - # CUDA fp8 path if weight.dtype == torch.float8_e4m3fn: val, scale = per_tensor_quant_mla_fp8( inp.transpose(0, 1), @@ -308,7 +299,6 @@ def mla_absorbed_bmm( out = bmm_fp8(val, weight, scale, effective_weight_scale, torch.bfloat16) return out.transpose(0, 1) - # bf16 fallback return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) @@ -352,14 +342,13 @@ def mla_v_up_proj( ).flatten(1, 2) -# Forward: prepare → core def forward_sgl_prepare( attn: DeepseekV2MLAAttention, positions: torch.Tensor, hidden_states: torch.Tensor, **model_kwargs, ) -> SglPrepareResult: - """Prepare QKV for sglang MLA attention (adapted from sglang forward_absorb_prepare).""" + """Prepare QKV for sglang MLA attention.""" hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states @@ -401,9 +390,6 @@ def forward_sgl_prepare( k_nope = latent_cache[..., : attn.kv_lora_rank] q_scale = None - # Reuse native ATOM gating for q/k RMSNorm fusion. Quant fusion is used - # when DeepSeek enables qknorm-quant; otherwise keep the non-quant fused - # path aligned with native ATOM before falling back to plain layernorm. if getattr(attn, "fuse_qknorm_quant", False): q, q_scale, q_lora, k_nope = _fuse_qk_rmsnorm_and_q_quant( attn, @@ -413,7 +399,6 @@ def forward_sgl_prepare( ) elif getattr(attn, "fuse_qknorm", False): q, k_nope = _fuse_qk_rmsnorm(attn, q, k_nope) - # Otherwise keep the original overlap path for unfused qk norm. elif attn.alt_stream is not None and get_is_capture_mode(): current_stream = torch.cuda.current_stream() attn.alt_stream.wait_stream(current_stream) @@ -425,11 +410,9 @@ def forward_sgl_prepare( q = attn.q_a_layernorm(q) k_nope = attn.kv_a_layernorm(k_nope) - if attn.use_nsa: - if q_lora is None: - q_lora = q + if attn.use_nsa and q_lora is None: + q_lora = q - # overlap q_b_proj and indexer during decode if ( attn.alt_stream is not None and get_is_capture_mode() @@ -506,7 +489,7 @@ def forward_sgl_core( attn: DeepseekV2MLAAttention, prepared: SglPrepareResult, ) -> torch.Tensor: - """Core MLA attention computation for sglang (adapted from sglang forward_absorb_core).""" + """Core MLA attention computation for sglang.""" save_kv_cache = True if attn.use_fused_qk_rope_concat_and_cache_mla: @@ -545,7 +528,6 @@ def forward_sgl_core( is_neox=attn.rotary_emb.is_neox_style, is_nope_first=True, ) - # Decode/speculative MLA consumes q plus packed MLA cache directly. k = None v = None save_kv_cache = False @@ -571,7 +553,6 @@ def forward_sgl_core( ) attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) - # up-proj by w_vc attn_bmm_output = mla_v_up_proj( attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim ) @@ -580,15 +561,7 @@ def forward_sgl_core( def _dispatch_sgl_plugin_attn_path(forward_batch) -> str: - """Decide the attention algorithm for this batch based on forward_mode. - - Returns "mha" for extend/prefill (uses standard Q×K×V with flash_attn) - or "mla" for decode (uses absorbed weights + mla_decode_fwd). - - This is the per-batch *routing* decision, distinct from - ``_can_run_sgl_mha_now`` which is a *capability* gate checking whether - the model configuration supports the MHA path at all. - """ + """Decide the attention algorithm for this batch based on forward_mode.""" if forward_batch.forward_mode.is_extend_without_speculative(): return "mha" return "mla" @@ -635,12 +608,7 @@ def _set_mla_kv_buffer_for_mha( def _can_run_sgl_mha_now(attn: DeepseekV2MLAAttention, forward_batch) -> bool: - """Check if the model configuration supports the MHA attention path. - - This is a *capability* gate — NSA models and MXFP4-quantised weights - (uint8) cannot use the MHA path. Distinct from - ``_dispatch_sgl_plugin_attn_path`` which routes each batch. - """ + """Check if the model configuration supports the MHA attention path.""" del forward_batch if attn.use_nsa: return False @@ -819,8 +787,6 @@ def prepare_qkv_latent( hidden_states, hidden_states_scale = hidden_states qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) - # Fallback: when communicator does not enable input_scattered gather, - # force qkv latent token dimension to align with positions. expected_tokens = 0 if hasattr(forward_batch, "positions") and forward_batch.positions is not None: expected_tokens = int(forward_batch.positions.shape[0]) @@ -843,7 +809,6 @@ def prepare_qkv_latent( return qkv_lora -# Top-level forward entry point def forward_sgl_plugin_mode( attn: DeepseekV2MLAAttention, positions: torch.Tensor, @@ -884,7 +849,6 @@ def forward_sgl_plugin_mode( raise ValueError(f"Unsupported plugin attention path: {attn_path}") -# Weight post-processing: decomposed into sub-functions def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" if hasattr(attn.kv_b_proj, "qweight"): @@ -901,8 +865,6 @@ def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: else: w = attn.kv_b_proj.weight - # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes. - # View-cast back to fn so the normalize path works correctly. if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: w = w.view(torch.float8_e4m3fn) @@ -926,10 +888,7 @@ def _process_fp8_weight( w: torch.Tensor, weight_block_size: Optional[list[int]], ) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: - """Process FP8 weights for kv_b_proj. - - Returns (w, use_deep_gemm_bmm, block_scale). - """ + """Process FP8 weights for kv_b_proj.""" from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import ( block_quant_dequant, @@ -1061,7 +1020,6 @@ def _split_and_assign_kc_vc( [attn.qk_nope_head_dim, attn.v_head_dim], dim=1 ) - # quark fp4 special path quant_method = getattr(attn.kv_b_proj, "quant_method", None) quant_config = getattr(quant_method, "quant_config", None) if ( @@ -1084,8 +1042,6 @@ def _split_and_assign_kc_vc( w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) w_vc = w_vc.contiguous().transpose(1, 2) - # Align bf16 kv_b_proj post-load handling with vLLM: split first, then - # quantize kc/vc independently for the fp8 BMM path. if w.dtype == torch.bfloat16 and (_is_hip or _is_cuda): w_kc, w_scale_k = dynamic_per_batched_tensor_quant(w_kc, dtype=dtypes.fp8) w_vc, w_scale_v = dynamic_per_batched_tensor_quant(w_vc, dtype=dtypes.fp8) @@ -1117,89 +1073,19 @@ def _split_and_assign_kc_vc( def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: - """Process kv_b_proj weights after loading for sglang MLA mode. - - Orchestrates reading, quantization handling, and splitting of - kv_b_proj into absorbed w_kc / w_vc weights. - """ + """Process kv_b_proj weights after loading for sglang MLA mode.""" w = _read_kv_b_proj_weight(attn) weight_block_size = _get_weight_block_size(attn) use_deep_gemm_bmm = False block_scale = None - # fp8 path if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( attn, w, weight_block_size ) - # int8 path if w.dtype == torch.int8: w = _process_int8_weight(attn, w, weight_block_size) - # split and assign kc/vc _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) - - -# One-time model setup (called from base_model_wrapper.py) -def setup_deepseek_for_sglang(model) -> None: - """Patch a DeepseekV2/V3 model for sglang plugin mode. - - - Initialises sglang TP context - - Patches each MLAAttention.forward to dispatch to the sglang MLA path - - Registers process_weights_after_loading hooks - - Stores atom_config on the model - """ - config = model.config - - # Store atom_config (needed by load_weights in the OOT wrapper) - if not hasattr(model, "atom_config"): - from atom.config import get_current_atom_config - - model.atom_config = get_current_atom_config() - - kv_cache_dtype = model.atom_config.kv_cache_dtype - - # Initialise sglang TP context for MLA gather/scatter - from sglang.srt.configs.model_config import is_deepseek_nsa - from sglang.srt.layers.communicator import get_attn_tp_context - - get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) - - # Patch each MLAAttention instance - from atom.models.deepseek_v2 import DeepseekV2MLAAttention - - for module in model.modules(): - if isinstance(module, DeepseekV2MLAAttention): - _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) - - -def _patch_mla_attention_for_sglang(attn, config, kv_cache_dtype: str = "bf16") -> None: - """Patch a single DeepseekV2MLAAttention for sglang plugin mode. - - We patch attn.forward (rather than relying solely on ops.Attention = - RadixAttention) because MLA's absorbed-weight forward path replaces the - *entire* forward method — including RoPE, and absorbed - BMM — not just the attention backend. ops.Attention = RadixAttention - handles the backend layer (flash_attn / paged_attn dispatch) and is - already set via set_attn_cls(); this patch sits above that layer. - """ - init_sgl_attrs(attn, config, kv_cache_dtype) - - def patched_forward( - positions: torch.Tensor, - hidden_states: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - from atom.plugin.sglang.models.base_model_wrapper import ( - get_current_forward_batch, - ) - - kwargs["forward_batch"] = get_current_forward_batch() - return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) - - attn.forward = patched_forward - attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( - attn - ) diff --git a/tests/plugin/test_sglang_model_wrapper.py b/tests/plugin/test_sglang_model_wrapper.py index e4015ed9d..20d0e0792 100644 --- a/tests/plugin/test_sglang_model_wrapper.py +++ b/tests/plugin/test_sglang_model_wrapper.py @@ -54,8 +54,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu forward_batch_mod.ForwardBatch = object forward_batch_mod.PPProxyTensors = object - attn_backend_pkg = _package("atom.plugin.sglang.attention_backend") - mla_mod = ModuleType("atom.plugin.sglang.attention_backend.sgl_attention_mla") + mla_mod = ModuleType("atom.plugin.sglang.models.deepseek_mla") mla_mod.setup_deepseek_for_sglang = setup_hook or (lambda model: None) return { @@ -68,8 +67,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu "sglang.srt.layers.quantization.base_config": quant_base_mod, "sglang.srt.model_executor": model_executor_pkg, "sglang.srt.model_executor.forward_batch_info": forward_batch_mod, - "atom.plugin.sglang.attention_backend": attn_backend_pkg, - "atom.plugin.sglang.attention_backend.sgl_attention_mla": mla_mod, + "atom.plugin.sglang.models.deepseek_mla": mla_mod, } diff --git a/tests/plugin/test_sglang_register.py b/tests/plugin/test_sglang_register.py index 562aaf8bc..1adb20fb4 100644 --- a/tests/plugin/test_sglang_register.py +++ b/tests/plugin/test_sglang_register.py @@ -324,10 +324,13 @@ def __init__(self, runner): "atom.models.qwen3_moe": ModuleType("atom.models.qwen3_moe"), "atom.models.glm4_moe": ModuleType("atom.models.glm4_moe"), "atom.models.deepseek_v2": ModuleType("atom.models.deepseek_v2"), + "atom.models.minimax_m2": ModuleType("atom.models.minimax_m2"), + "atom.models.qwen3_next": ModuleType("atom.models.qwen3_next"), + "atom.models.qwen3_5": ModuleType("atom.models.qwen3_5"), "atom.config": ModuleType("atom.config"), "atom.plugin.prepare": fake_prepare_mod, - "atom.plugin.sglang.attention_backend.sgl_attn_backend": ModuleType( - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend": ModuleType( + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ), } fake_modules["atom.models.qwen3"].Qwen3ForCausalLM = type( @@ -342,9 +345,21 @@ def __init__(self, runner): fake_modules["atom.models.deepseek_v2"].DeepseekV3ForCausalLM = type( "DeepseekV3ForCausalLM", (), {} ) + fake_modules["atom.models.minimax_m2"].MiniMaxM2ForCausalLM = type( + "MiniMaxM2ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_next"].Qwen3NextForCausalLM = type( + "Qwen3NextForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5ForCausalLM = type( + "Qwen3_5ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5MoeForCausalLM = type( + "Qwen3_5MoeForCausalLM", (), {} + ) fake_modules["atom.config"].Config = type("Config", (), {}) fake_modules[ - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ].ATOMAttnBackendForSgl = _FakeBackend with patch.dict(sys.modules, fake_modules):