Skip to content

[ATOM SGLang] SGL plugin Attention Refractory#863

Open
ZhiweiYan-96 wants to merge 11 commits into
ROCm:mainfrom
zejunchen-zejun:zhiwei/attn_refrac_integrated
Open

[ATOM SGLang] SGL plugin Attention Refractory#863
ZhiweiYan-96 wants to merge 11 commits into
ROCm:mainfrom
zejunchen-zejun:zhiwei/attn_refrac_integrated

Conversation

@ZhiweiYan-96
Copy link
Copy Markdown
Contributor

@ZhiweiYan-96 ZhiweiYan-96 commented May 21, 2026

ATOM SGLang Attention Refactor

Status

Image

Summary

This RFC proposes a staged refactor of the ATOM SGLang plugin attention stack. The goal is to make SGLang-specific runtime, model adaptation, and attention backend responsibilities explicit.

The current direction is:

  • Decouple generic SGLang full-attention backend code from model-specific DeepSeek MLA code.
  • Route DeepSeek MLA through SGLangDeepseekMLAAttention as an explicit model-level attention adapter.
  • Extract SGLang runtime state and ForwardBatch -> ATOM forward_context bridging into scoped runtime utilities.
  • Introduce a first model adapter registry via SGLangModelAdapterSpec so existing special cases are declared instead of hard-coded.
  • Split ATOMAttnBackendForSgl by backend lifecycle responsibility.
  • Keep shared metadata / kernel-call reuse experimental until interfaces are proven stable.

Background

The existing SGLang plugin support grew through several overlapping concerns:

  • ATOMAttnBackendForSgl handles metadata construction, cache writes, CUDA graph metadata, MHA/MLA dispatch, speculative modes, and kernel calls.
  • DeepSeek MLA model-specific logic used to live close to generic SGLang attention backend code.
  • base_model_wrapper.py collected generic wrapper logic, runtime state, model-specific flags, and forward-context bridging.
  • Some model adaptations were hard-coded by architecture name, for example DeepSeek patching and Qwen3.5 prepare-time config remapping.

Recent branches split these concerns:

This PR holds all the change from :

Goals

  1. Make file ownership and runtime ownership obvious.
  2. Keep generic SGLang full-attention backend code free of model-specific DeepSeek MLA semantics.
  3. Provide a consistent way to express model adaptation needs.
  4. Preserve existing supported model behavior.
  5. Create extension points for future V3.2 sparse indexer and V4 hybrid attention work.

Target architecture

Model adapter layer
  Qwen3.5 outer wrapper
  DeepSeek MLA semantic adapter
  DeepSeek MTP draft wrapper
  future V3.2 / V4 model-specific adapters

SGLang runtime layer
  SGLangForwardBatchMetadata
  SGLangPluginRuntime
  plugin_runtime_scope
  model adapter registry

SGLang framework attention layer
  RadixAttention
  ATOMAttnBackendForSgl
  SGLang token/KV pools
  decode / extend / graph lifecycle

Kernel interface layer
  ForwardMetadata / AttentionMetaData-like fields
  KV indices / indptr / page table layout
  aiter / triton kernel call interfaces

Refactor Tracks

Track 1: Attention File and Responsibility Decoupling

This track has two parts: first, separate generic SGLang full-attention files from DeepSeek-specific MLA files; second, split the remaining full-attention backend by responsibility instead of keeping all backend lifecycle logic in ATOMAttnBackendForSgl.

The first problem was file ownership. Generic SGLang full-attention backend code and DeepSeek-specific MLA helpers lived too close together. The refactor moves them apart:

atom/plugin/sglang/attention_backend/full_attention/
  full_attention_backend.py
  radix_attention.py

atom/plugin/sglang/models/
  deepseek_mla.py
  deepseek_mla_attention.py
  deepseek_mla_forward.py

This track is represented by attn_model_decouple. Its purpose is not to change runtime behavior. Its purpose is to establish ownership:

  • full_attention/ owns SGLang framework backend behavior.
  • models/deepseek_mla*.py owns DeepSeek model-specific MLA behavior.
  • RadixAttention remains the SGLang framework adapter.

This is the foundation for every later PR. Without this move, DeepSeek-specific logic would continue to leak into generic backend files.

The second problem is that ATOMAttnBackendForSgl still owns too many backend responsibilities after the file move. The refactor starts splitting it into focused helpers:

full_attention/
  full_attention_backend.py  # backend orchestrator and dispatch
  metadata.py                # ForwardMetadata
  kv_cache.py                # cache layout shuffle helpers
  pa_metadata.py             # PA persistent metadata helpers

Future splits can continue along the same responsibility boundary:

metadata_builder.py          # decode / extend metadata construction
cuda_graph.py                # CUDA graph capture/replay metadata
decode.py                    # decode dispatch
extend.py                    # extend/prefill dispatch

This split should not be top-level MHA backend vs MLA backend. MHA and MLA are dispatch cases, but metadata construction, KV cache layout, CUDA graph, PA metadata, and speculative modes cut across both.

Track 2: SGLangDeepseekMLAAttention

DeepSeek MLA cannot be treated like Qwen-style q/k/v attention. Its model forward passes latent MLA state:

hidden_states_or_q_c
kv_c_normed
k_pe
positions
q_scale

These are model-level semantic inputs, not backend-ready attention inputs. The refactor introduces SGLangDeepseekMLAAttention to own this lowering:

DeepseekV2MLAAttention.forward()
  -> self.mla_attn(...)
  -> SGLangDeepseekMLAAttention
  -> RadixAttention
  -> ATOMAttnBackendForSgl

This track is represented by attn_refrac_share_model.

The important design choice is that the wrapper sits above RadixAttention. RadixAttention is a SGLang framework adapter: it expects attention-ready tensors and a ForwardBatch. DeepSeek MLA, however, calls self.mla_attn(...) with
model-specific latent state. The wrapper is the place where that semantic gap is closed.

SGLangDeepseekMLAAttention is responsible for:

  • resolving forward_batch from explicit kwargs or current runtime context,
  • gathering scattered runtime inputs when SGLang TP communication scatters them,
  • projecting q_c to final query when needed,
  • splitting and applying RoPE to q/k RoPE components,
  • choosing absorbed vs non-absorbed MLA path,
  • staging latent KV into SGLang's KV pool,
  • calling the underlying RadixAttention / SGLang backend,
  • applying DeepSeek MLA V up-projection and output projection.

The absorbed path roughly lowers:

q_input + kv_c_normed + k_pe
  -> q projection
  -> q_nope absorbed BMM
  -> latent KV attention
  -> V up-projection
  -> o_proj

The non-absorbed path roughly lowers:

q_input + kv_c_normed + k_pe
  -> q projection
  -> kv_b_proj expands latent KV into K/V
  -> standard q/k/v-shaped attention
  -> o_proj

The wrapper should not own generic backend concerns such as page table construction, CUDA graph replay, or PA metadata buffers. Those stay under the SGLang framework backend.

It solves several problems:

  • avoids monkey-patching the entire DeepSeek attention forward path,
  • keeps absorbed / non-absorbed MLA dispatch near the model semantic boundary,
  • prevents generic SGLang full-attention backend code from needing to understand DeepSeek latent tensors,
  • gives future DeepSeek variants a clear place to attach model-level semantic adapters.

Track 3: SGLang Runtime Bridge

The SGLang wrapper must translate framework runtime state into what ATOM model code expects. This includes:

  • current ForwardBatch,
  • PP proxy tensors,
  • dummy / idle batch handling,
  • ATOM plugin framework/config globals,
  • ATOM forward_context,
  • target/draft wrapper state for speculative decoding.

The refactor extracts this into atom/plugin/sglang/runtime:

runtime/context.py
  SGLangForwardBatchMetadata
  get_current_forward_batch
  plugin_runtime_scope

runtime/forward_context.py
  SGLangPluginRuntime

runtime/model_arch.py
  model adapter registry

This track is represented by attn_refractory_runtime.

There are three distinct runtime problems:

1. Current SGLang Forward State

Some model-level adapters need access to the current SGLang ForwardBatch without threading it through every intermediate ATOM model call. The runtime package provides SGLangForwardBatchMetadata for this:

SGLangForwardBatchMetadata
  forward_batch
  pp_proxy_tensors
  save_kv_cache

It also keeps get_current_forward_batch() as a narrow compatibility path for adapters such as RadixAttention fallback lookup and DeepSeek MLA wrapper input resolution.

2. ATOM Plugin Global State

ATOM still has process-global plugin state:

atom.plugin.prepare._CURRENT_FRAMEWORK
atom.config._current_atom_config

SGLang target/draft model wrappers can coexist, especially under speculative decoding. plugin_runtime_scope() scopes those globals around construction, load, patch, and forward sections so one wrapper does not leak runtime state into another.

3. SGLang ForwardBatch to ATOM forward_context

Many ATOM model ops read atom.utils.forward_context.get_forward_context() for information such as:

  • positions,
  • prefill/decode mode,
  • dummy/idle run status,
  • graph batch size,
  • DP token distribution,
  • attention metadata used by MoE padding or auxiliary ops.

SGLangPluginRuntime is a scoped adapter for model wrappers:

with SGLangPluginRuntime(
    atom_config=atom_config,
    forward_batch=forward_batch,
    positions=positions,
    input_ids=input_ids,
    input_embeds=input_embeds,
) as runtime:
    hidden_states = model(
        input_ids=runtime.input_ids,
        positions=runtime.positions,
        inputs_embeds=runtime.input_embeds,
    )
    hidden_states = runtime.trim_output(hidden_states)

It owns:

  • binding the current ForwardBatch,
  • materializing ATOM-compatible dummy inputs for SGLang idle batches,
  • setting and resetting ATOM forward context,
  • resolving DP token counts for ATOM-side metadata,
  • trimming ATOM dummy outputs back to SGLang-visible token counts.

The important boundary is:

SGLang model wrapper -> ATOM model body

The runtime bridge is not for ATOMAttnBackendForSgl kernel dispatch. The full-attention backend should use SGLang ForwardBatch and backend metadata directly.

This separation prevents a common failure mode: pushing model-wrapper runtime concerns into the attention backend simply because both happen to see ForwardBatch.

Track 4: Model Adapter Interface

The current code already has multiple model adaptation patterns:

  • Qwen3 / Qwen3Moe use the default base wrapper.
  • Qwen3Next needs GDN runtime context binding.
  • Qwen3.5 keeps the upstream SGLang outer wrapper and swaps in an ATOM language model stack.
  • DeepSeek V3 MLA needs install-time attention adaptation.
  • DeepSeek MTP needs a draft wrapper, config override, layer-id retagging, and embed/head sharing.
  • Future V3.2 needs sparse indexer side cache / top-k buffer handling.
  • Future V4 needs hybrid state/cache/metadata ownership.

Using more booleans in ModelArchSpec does not scale. The first implementation step is SGLangModelAdapterSpec:

@dataclass(frozen=True)
class SGLangModelAdapterSpec:
    wrapper_binds_gdn_context: bool = False
    prepare_config: Callable[[Any, str], None] | None = None
    install_adapters: Callable[[Any], None] | None = None

This is intentionally small. It replaces hard-coded special cases without claiming to be a complete future-proof framework.

Current uses:

  • DeepseekV3ForCausalLM uses install_adapters=setup_deepseek_for_sglang.

  • Qwen3NextForCausalLM keeps wrapper_binds_gdn_context=True.

  • Qwen3_5ForConditionalGeneration and Qwen3_5MoeForConditionalGeneration use prepare_config=apply_prepare_model_adaptations.

Future lifecycle hooks may include:

  • construct_model,
  • load_weights,
  • post_load,
  • runtime_policy,
  • output_policy,
  • cache owner registration,
  • metadata adapter registration.

The key point is that new models should declare adaptation needs through a registry instead of adding new one-off branches in the generic wrapper.

This track is represented by sglang_model_adapter. It is intentionally a small first step: it codifies existing DeepSeek and Qwen3.5 special cases without trying to solve every future model family in one PR.

The intended lifecycle for future adapters is:

prepare_config
  Patch or remap config before ATOM model construction.

construct_model
  Optional custom construction for outer wrappers, draft models, or hybrid runtimes.

install_adapters
  Patch/wrap submodules after construction, such as DeepSeek MLA attention wrappers.

load_weights / post_load
  Optional custom checkpoint mapping, shared-weight binding, or post-load transforms.

runtime_policy
  Declare whether the model needs default runtime, GDN context, context-only forward,
  or a custom runtime bridge.

output_policy
  Declare how hidden states become SGLang-visible outputs.

The first PR only implements the two hooks that are already needed by existing
code:

prepare_config
install_adapters

It deliberately leaves the rest as design direction. That keeps review scope small while still moving away from boolean flags.

Existing mappings:

Qwen3 / Qwen3Moe
  default adapter

Qwen3Next
  wrapper_binds_gdn_context=True

Qwen3.5 / Qwen3.5-MoE
  prepare_config=apply_prepare_model_adaptations

DeepSeek V3 MLA
  install_adapters=setup_deepseek_for_sglang

Future mappings should be additive:

DeepSeek MTP
  construct_model + load_weights + runtime policy + embed/head sharing

DeepSeek V3.2
  install indexer adapter + cache owner hook + sparse metadata adapter

DeepSeek V4
  custom construction + state cache owner + V4 metadata/runtime adapter

The adapter registry is therefore a coordination point, not a replacement for model-specific modules. Complex models should still keep their logic in focused files such as deepseek_mla_attention.py, deepseek_nextn_wrapper.py, or a future deepseek_v4_adapter.py.

@ZhiweiYan-96 ZhiweiYan-96 marked this pull request as ready for review May 21, 2026 09:18
Copilot AI review requested due to automatic review settings May 21, 2026 09:18
ZhiweiYan-96 and others added 8 commits May 21, 2026 09:22
Move the SGLang DeepSeek MLA runtime entry from legacy forward glue into
SGLangDeepseekMLAAttention while keeping RadixAttention and the full-attention
backend as the host/backend layers. Shrink deepseek_mla_forward.py into a
helper module and clarify absorbed vs non-absorbed path naming.
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the ATOM SGLang plugin attention stack to make SGLang runtime state, model-level adaptation (e.g., DeepSeek MLA), and full-attention backend responsibilities explicit and better separated. It introduces a small model-adapter registry, moves runtime/forward-context bridging into a dedicated runtime package, and splits the previously monolithic backend helpers into focused modules while keeping behavior aligned with existing supported models.

Changes:

  • Introduces atom.plugin.sglang.runtime (scoped runtime globals, forward-context bridge, and model adapter registry) and updates wrappers to use it.
  • Decouples DeepSeek MLA model adaptation into atom/plugin/sglang/models/deepseek_mla* and removes the old monolithic sgl_attention_mla.py.
  • Splits the SGLang full-attention backend into helper modules (metadata, kv_cache, pa_metadata) and updates import paths across plugin and core ops.

Reviewed changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
tests/plugin/test_sglang_register.py Updates mocks/imports for the renamed full-attention backend module and additional model imports.
tests/plugin/test_sglang_model_wrapper.py Updates DeepSeek MLA setup-hook import path to the new models.deepseek_mla module.
atom/plugin/sglang/runtime/model_arch.py Adds SGLangModelAdapterSpec + registry for prepare/install hooks and wrapper flags.
atom/plugin/sglang/runtime/forward_context.py Adds SGLangPluginRuntime to bridge ForwardBatch into ATOM forward_context and handle dummy/idle batches.
atom/plugin/sglang/runtime/context.py Adds scoped runtime utilities (plugin_runtime_scope, forward-batch ContextVars, metadata binding helpers).
atom/plugin/sglang/runtime/init.py Exposes the runtime utilities as a public package surface.
atom/plugin/sglang/models/qwen3_5.py Switches to runtime package import and updates comment to reference MODEL_ARCH_SPECS.
atom/plugin/sglang/models/deepseek_nextn_wrapper.py Migrates draft wrapper to SGLangPluginRuntime + plugin_runtime_scope.
atom/plugin/sglang/models/deepseek_mla.py Adds install-time DeepSeek MLA patch entrypoint (setup_deepseek_for_sglang) in a model-owned module.
atom/plugin/sglang/models/deepseek_mla_forward.py Extracts DeepSeek MLA shared helper functions (BMM paths, weight post-load processing, KV staging).
atom/plugin/sglang/models/deepseek_mla_attention.py Adds SGLangDeepseekMLAAttention model-level adapter to lower latent MLA inputs into backend-ready attention calls.
atom/plugin/sglang/models/base_model_wrapper.py Replaces embedded runtime/context logic with atom.plugin.sglang.runtime and adapter-driven hooks.
atom/plugin/sglang/attention_backend/sgl_attention_mla.py Removes the old monolithic DeepSeek MLA SGLang plugin module.
atom/plugin/sglang/attention_backend/full_attention/radix_attention.py Updates fallback get_current_forward_batch import to runtime package.
atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py Adds helper module for PA persistent metadata buffer allocation/build.
atom/plugin/sglang/attention_backend/full_attention/metadata.py Adds ForwardMetadata dataclass in its own module.
atom/plugin/sglang/attention_backend/full_attention/kv_cache.py Moves KV layout shuffle kernel + helper into a dedicated module.
atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py Refactors backend to use extracted helper modules and updates naming/imports.
atom/plugin/sglang/attention_backend/full_attention/init.py Adds package exports for full-attention backend components.
atom/plugin/sglang/attention_backend/attention_gdn.py Updates import path for SGLangForwardBatchMetadata to runtime package.
atom/plugin/register.py Updates custom attention backend import path to the new full-attention backend module.
atom/plugin/prepare.py Routes model-specific config preparation via the new model adapter spec (get_model_arch_spec).
atom/model_ops/attentions/aiter_attention.py Updates RadixAttention import path to the new full-attention location.
atom/model_ops/init.py Updates RadixAttention import path to the new full-attention location.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +160 to 172
with SGLangPluginRuntime(
atom_config=self.atom_config,
forward_batch=forward_batch,
positions=positions,
input_ids=input_ids,
input_embeds=input_embeds,
):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
hidden_states=forward_batch.spec_info.hidden_states,
inputs_embeds=input_embeds,
)
@ZhiweiYan-96 ZhiweiYan-96 force-pushed the zhiwei/attn_refrac_integrated branch from 86024e8 to 8de3516 Compare May 21, 2026 09:26
Co-authored-by: Cursor <cursoragent@cursor.com>
Copilot AI review requested due to automatic review settings May 21, 2026 09:31
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 2 comments.

Comment on lines +100 to +117
"""Fuse q/k RMSNorm and q quant using ATOM's DeepSeek-V2 path."""

(q_quantized, q_scale), q_normed, k_nope_normed, _ = _fuse_rmsnorm_quant(
q,
attn.q_a_layernorm.weight,
attn.q_a_layernorm.eps,
k_nope,
attn.kv_a_layernorm.weight,
attn.kv_a_layernorm.eps,
None,
dtype_quant=attn.quant_dtype,
shuffle=False,
scale_shuffle_padding=False,
group_size=128,
quant_type=_linear_quant_type_value(attn.q_b_proj),
output_unquantized_inp1=output_unquantized_q,
transpose_scale=True,
)
Comment on lines +160 to 172
with SGLangPluginRuntime(
atom_config=self.atom_config,
forward_batch=forward_batch,
positions=positions,
input_ids=input_ids,
input_embeds=input_embeds,
):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
hidden_states=forward_batch.spec_info.hidden_states,
inputs_embeds=input_embeds,
)
ZhiweiYan-96 and others added 2 commits May 21, 2026 09:41
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Copilot AI review requested due to automatic review settings May 21, 2026 10:21
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.

Comment on lines +166 to 172
):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
hidden_states=forward_batch.spec_info.hidden_states,
inputs_embeds=input_embeds,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants