Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion atom/model_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .paged_attention import PagedAttention
from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention
from atom.plugin.sglang.attention_backend.full_attention.radix_attention import (
RadixAttention,
)

# This global class is used to construct the attention op in model,
# it can be assigned to different attention ops.
Expand Down
4 changes: 3 additions & 1 deletion atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import atom.model_ops as ops
from atom.model_ops.paged_attention import PagedAttention
from atom.model_ops.attention_mha import PagedAttentionImpl
from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention
from atom.plugin.sglang.attention_backend.full_attention.radix_attention import (
RadixAttention,
)
from atom.utils.forward_context import AttentionMetaData, Context

from .backends import AttentionBackend, CommonAttentionBuilder
Expand Down
2 changes: 1 addition & 1 deletion atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _register_custom_attention_to_sglang() -> None:
from sglang.srt.layers.attention.attention_registry import (
register_attention_backend,
)
from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
from atom.plugin.sglang.attention_backend.full_attention.full_attention_backend import (
ATOMAttnBackendForSgl,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .radix_attention import RadixAttention
from .full_attention_backend import ATOMAttnBackendForSgl, ForwardMetadata

# Public names re-exported from this package's submodules
# (`.radix_attention` and `.full_attention_backend`).
__all__ = [
"RadixAttention",
"ATOMAttnBackendForSgl",
"ForwardMetadata",
]
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

# sglang-specific attention backend replacing sglang's built-in AiterAttnBackend.
# Shared by ALL models (DeepSeek, Qwen3, etc.) — handles KV cache writes,
# page-table fixup, pa_persistent_fwd decode path, and MLA prefill kernels.
# Sits at the lowest layer of the attention stack: sglang's RadixAttention
# delegates the actual kernel dispatch here.
# SGLang full-attention backend replacing sglang's built-in AiterAttnBackend.
# Shared by ALL full-attention models (DeepSeek, Qwen3, etc.) — handles KV
# cache writes, page-table fixup, pa_persistent_fwd decode path, and MLA
# prefill kernels. Sits at the lowest layer of the attention stack:
# sglang's RadixAttention delegates the actual kernel dispatch here.
#
# TODO: rewrite this file once sglang's attention flow is unified into ATOM's
# attention layer — KV cache management and attention kernel dispatch will then
Expand Down Expand Up @@ -47,7 +47,7 @@
except ImportError as e:
raise ImportError(
"Failed to import 'aiter', which provides AMD-specific attention kernels "
"required by sgl_attn_backend. Please ensure 'aiter' is installed and "
"required by full_attention_backend. Please ensure 'aiter' is installed and "
f"available on your AMD system. Original import error: {e}"
) from e

Expand Down
2 changes: 1 addition & 1 deletion atom/plugin/sglang/models/base_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def __init__(
# Apply ds model-specific sglang patches (attn dispatch, weight hooks, etc.)
# TODO: remove this once sglang supports the atom attention backend
if self.model_arch_spec.apply_deepseek_patch:
from atom.plugin.sglang.attention_backend.sgl_attention_mla import (
from atom.plugin.sglang.models.deepseek_mla import (
setup_deepseek_for_sglang,
)

Expand Down
75 changes: 75 additions & 0 deletions atom/plugin/sglang/models/deepseek_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

"""Model-level DeepSeek MLA patching for SGLang plugin mode.

This module owns the monkey-patch entrypoints that adapt DeepSeek MLA models to
SGLang plugin mode. The heavy DeepSeek-specific forward and weight helpers live
in `atom.plugin.sglang.models.deepseek_mla_forward`.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch

from atom.plugin.sglang.models.deepseek_mla_forward import (
forward_sgl_plugin_mode,
init_sgl_attrs,
process_mla_kv_b_proj_after_loading,
)

if TYPE_CHECKING:
from atom.models.deepseek_v2 import DeepseekV2MLAAttention


def setup_deepseek_for_sglang(model) -> None:
    """Adapt a DeepSeek V2/V3 model so it can run in SGLang plugin mode.

    The steps run in a fixed order:

    1. Attach ``atom_config`` to ``model`` when it does not carry one yet.
    2. Initialise SGLang's attention TP context from ``model.config``.
    3. Patch every ``DeepseekV2MLAAttention`` submodule in place.

    Args:
        model: The DeepSeek model instance to patch. It must expose
            ``config`` and ``modules()``.
    """
    model_cfg = model.config

    # The OOT wrapper expects atom_config on the model, and it has to be
    # there before any install-time hook runs.
    if not hasattr(model, "atom_config"):
        from atom.config import get_current_atom_config

        model.atom_config = get_current_atom_config()

    cache_dtype = model.atom_config.kv_cache_dtype

    # SGLang's MLA TP context must be ready before the per-layer forwards
    # are swapped out below.
    from sglang.srt.configs.model_config import is_deepseek_nsa
    from sglang.srt.layers.communicator import get_attn_tp_context

    tp_context = get_attn_tp_context()
    tp_context.init_context(model_cfg.q_lora_rank, is_deepseek_nsa(model_cfg))

    from atom.models.deepseek_v2 import DeepseekV2MLAAttention

    # Lazy generator: layers are patched while the module walk is in
    # progress, one at a time.
    mla_layers = (
        submodule
        for submodule in model.modules()
        if isinstance(submodule, DeepseekV2MLAAttention)
    )
    for mla_layer in mla_layers:
        _patch_mla_attention_for_sglang(mla_layer, model_cfg, cache_dtype)


def _patch_mla_attention_for_sglang(
    attn: "DeepseekV2MLAAttention",
    config: Any,
    kv_cache_dtype: str = "bf16",
) -> None:
    """Rewire a single DeepSeek MLA attention layer for SGLang plugin mode.

    After ``init_sgl_attrs`` has prepared the layer, two attributes are
    replaced on the instance:

    * ``forward`` -- delegates to ``forward_sgl_plugin_mode`` and always
      passes the current forward batch as ``forward_batch``.
    * ``process_weights_after_loading`` -- a zero-argument callable that
      delegates to ``process_mla_kv_b_proj_after_loading`` for this layer.

    Args:
        attn: The MLA attention layer to patch in place.
        config: Model config handed through to ``init_sgl_attrs``.
        kv_cache_dtype: KV-cache dtype name handed through to
            ``init_sgl_attrs``.
    """
    init_sgl_attrs(attn, config, kv_cache_dtype)

    def patched_forward(
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        # Imported at call time, as in the original implementation.
        from atom.plugin.sglang.models.base_model_wrapper import (
            get_current_forward_batch,
        )

        # Any caller-supplied forward_batch is overwritten with the current one.
        kwargs.update(forward_batch=get_current_forward_batch())
        return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs)

    def _process_weights_after_loading():
        return process_mla_kv_b_proj_after_loading(attn)

    attn.forward = patched_forward
    attn.process_weights_after_loading = _process_weights_after_loading
Loading
Loading