[Plugin][MLA] Tolerate rotary_emb=None for NoPE-only MLA models (Kimi-Linear) #792

Open
ChuanLi1101 wants to merge 1 commit into main from chuali/mla-tolerate-rotary-emb-none

@ChuanLi1101
Collaborator

Summary

Plugin-mode MLA in atom/plugin/attention_mla.py::forward_impl_plugin_mode unconditionally accesses self.rotary_emb.{cos_cache, sin_cache, is_neox_style} and calls self.rotary_emb(...), which fails for NoPE-only MLA models that upstream vLLM correctly constructs with rotary_emb=None in MLAModules(...).

The visible failure surface today is Kimi-Linear-48B-A3B-Instruct, which crashes at cudagraph capture under TP=2 / TP=4 with:

File "/app/ATOM/atom/plugin/attention_mla.py", line 825, in forward_impl_plugin_mode
    self.rotary_emb.cos_cache,
AttributeError: 'NoneType' object has no attribute 'cos_cache'

_mla_plugin_mode_init already gates the precomputed rotary_emb_cos_sin_cache buffer registration on self.rotary_emb is not None (current main, around line 932), so handling rotary_emb=None at runtime is an extension of an existing pattern, not a new contract.

Change

Detect `nope_only_mla = self.rotary_emb is None` once at the top of forward_impl_plugin_mode and route around the three rotary_emb access sites, which fall into two branches:

  1. Mixed prefill+decode batch (not decode_only): if NoPE-only, write kv_cache via aiter.concat_and_cache_mla (no-rope path) and skip ops.concat_and_cache_mla_rope_fused / self.rotary_emb(...).
  2. Decode-only batch (decode_only): if NoPE-only, kv_cache is now written by branch (1)'s `not decode_only or nope_only_mla` predicate; fall through to the existing manual-concat / fp8-quant decode_q construction path that's already used today for mixed batches.
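
The routing in the two items above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the code in attention_mla.py: `plan_plugin_mode_steps` and the step labels are hypothetical names, and the real function launches aiter / ops kernels rather than returning strings.

```python
def plan_plugin_mode_steps(rotary_emb, decode_only):
    """Return the ordered step labels a batch would take (hypothetical helper)."""
    nope_only_mla = rotary_emb is None  # detected once, up front
    steps = []

    # Branch 1: kv_cache write. NoPE-only models always take this branch,
    # including for decode-only batches.
    if not decode_only or nope_only_mla:
        if nope_only_mla:
            steps.append("concat_and_cache_mla (no rope)")
        else:
            steps.append("rope + kv_cache write")

    # Branch 2: decode_q construction.
    if decode_only and not nope_only_mla:
        # Existing one-step fused op; it reads the rotary_emb caches.
        steps.append("fused_qk_rope_concat_and_cache_mla")
    else:
        steps.append("manual concat / fp8 quant decode_q")
    return steps
```

In this sketch, a model with a rotary embedding takes the same steps as before, while `plan_plugin_mode_steps(None, True)` takes the two-step path and never touches `rotary_emb`.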

For NoPE-only models qk_rope_head_dim == 0, so the splits/concats degenerate naturally (decode_q_pe is empty along the last dim) and no further shape adaptation is needed.
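
As a rough analogue of why no shape adaptation is needed, the following pure-Python sketch uses list slicing in place of tensor splits. `split_nope_pe` is a hypothetical helper introduced only for illustration; the real code operates on tensors along the last dimension.

```python
def split_nope_pe(q_row, qk_nope_head_dim, qk_rope_head_dim):
    """Split one query row into its no-rope and rope parts (hypothetical helper)."""
    assert len(q_row) == qk_nope_head_dim + qk_rope_head_dim
    q_nope = q_row[:qk_nope_head_dim]
    q_pe = q_row[qk_nope_head_dim:]  # empty when qk_rope_head_dim == 0
    return q_nope, q_pe


q_row = [0.1, 0.2, 0.3, 0.4]
q_nope, q_pe = split_nope_pe(q_row, qk_nope_head_dim=4, qk_rope_head_dim=0)
decode_q = q_nope + q_pe  # concatenating the empty rope part changes nothing
```

With a rope width of zero, the rope slice is empty and the concatenated result equals the no-rope part, so downstream code sees the shape it already expects.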

Why this is safe

  • Pure additive routing: every existing rope code path is byte-for-byte unchanged when self.rotary_emb is not None.
  • No new dependency: the kv_cache write used by the new branch (aiter.concat_and_cache_mla) is already exercised today by the non-fused-rope branch a few lines above in the same function.
  • Mirrors an existing guard: _mla_plugin_mode_init already checks `if self.rotary_emb is not None:` before registering the rotary_emb_cos_sin_cache buffer. This PR just extends the same pattern to the forward path.

Trade-off note

For Kimi-Linear, the NoPE-only decode_only batch now goes through a two-step "write kv_cache → build decode_q" path instead of the one-step aiter.fused_qk_rope_concat_and_cache_mla op. This costs one extra kernel launch on the decode hot path but does no rope work, and Kimi-Linear MLA is only 7 of 27 layers in the model, so the impact is small relative to unblocking the model entirely. A follow-up could add an aiter.concat_and_cache_mla_no_rope-style fused op if profiling shows the extra launch matters.

Test plan

Related

AI assistance disclosure

This PR was prepared with AI assistance (Cursor + Claude). Every changed line was reviewed by a human submitter and syntax-checked locally. Runtime validation against Kimi-Linear is queued as the next step.

Plugin-mode MLA in atom/plugin/attention_mla.py::forward_impl_plugin_mode
unconditionally accesses self.rotary_emb.{cos_cache, sin_cache,
is_neox_style} and calls self.rotary_emb(...), which fails for NoPE-only
MLA models that upstream vLLM correctly constructs with rotary_emb=None
in MLAModules(...).

The visible failure surface today is Kimi-Linear-48B-A3B-Instruct, which
crashes at cudagraph capture under TP=2/4 with:

  AttributeError: 'NoneType' object has no attribute 'cos_cache'

_mla_plugin_mode_init already gates the precomputed
rotary_emb_cos_sin_cache buffer registration on
`self.rotary_emb is not None`, so handling rotary_emb=None at runtime is
an extension of an existing pattern, not a new contract.

Detect `nope_only_mla = self.rotary_emb is None` once at the top of
forward_impl_plugin_mode and route around the three rotary_emb access
sites:

 1. Mixed prefill+decode batch (not decode_only): if NoPE-only, write
    kv_cache via aiter.concat_and_cache_mla (no-rope path) and skip
    concat_and_cache_mla_rope_fused / self.rotary_emb(...).
 2. Decode-only batch (decode_only): if NoPE-only, kv_cache is now
    written by branch (1)'s `not decode_only or nope_only_mla`
    predicate; fall through to the existing manual-concat / fp8-quant
    decode_q construction path used today for mixed batches.

For NoPE-only models qk_rope_head_dim is 0; splits/concats degenerate
naturally so no further shape adaptation is needed. The kv_cache write
path used by the new branch (aiter.concat_and_cache_mla) is already
exercised today by the non-fused-rope branch in the same function.

Discovery context: vllm-project/vllm#40697 (Kimi-Linear FlyDSL KDA
decode wiring) is currently blocked at cudagraph capture on this same
code path.

AI-assistance disclosure: prepared with AI assistance (Cursor + Claude);
every changed line was reviewed by a human submitter and syntax-checked.
Runtime validation against Kimi-Linear is the next step.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: chuali <chuali@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>