[Plugin][MLA] Tolerate rotary_emb=None for NoPE-only MLA models (Kimi-Linear) #792
Open
ChuanLi1101 wants to merge 1 commit into
Plugin-mode MLA in atom/plugin/attention_mla.py::forward_impl_plugin_mode
unconditionally accesses self.rotary_emb.{cos_cache, sin_cache,
is_neox_style} and calls self.rotary_emb(...), which fails for NoPE-only
MLA models that upstream vLLM correctly constructs with rotary_emb=None
in MLAModules(...).
The visible failure surface today is Kimi-Linear-48B-A3B-Instruct, which
crashes at cudagraph capture under TP=2/4 with:
AttributeError: 'NoneType' object has no attribute 'cos_cache'
_mla_plugin_mode_init already gates the precomputed
rotary_emb_cos_sin_cache buffer registration on
`self.rotary_emb is not None`, so handling rotary_emb=None at runtime is
an extension of an existing pattern, not a new contract.
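As a minimal sketch of the mismatch described above, the snippet below uses a hypothetical stand-in class (`FakeMLA`, not the real plugin-mode layer) whose constructor guards on `self.rotary_emb is not None` while its forward path does not. The class and method names are assumptions made for illustration only.

```python
class FakeMLA:
    """Hypothetical stand-in for the plugin-mode MLA layer (illustration only)."""

    def __init__(self, rotary_emb=None):
        self.rotary_emb = rotary_emb
        self.buffers = {}
        # Init-time guard, mirroring the pattern the description attributes
        # to _mla_plugin_mode_init: skip the cache buffer when there is no rope.
        if self.rotary_emb is not None:
            self.buffers["rotary_emb_cos_sin_cache"] = (
                self.rotary_emb.cos_cache,
                self.rotary_emb.sin_cache,
            )

    def forward_unguarded(self):
        # Unconditional attribute access: raises when rotary_emb is None.
        return self.rotary_emb.cos_cache


layer = FakeMLA(rotary_emb=None)  # construction succeeds thanks to the guard
try:
    layer.forward_unguarded()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

Construction succeeds because of the init-time guard; only the unguarded forward access raises, which matches the shape of the reported failure.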
Detect `nope_only_mla = self.rotary_emb is None` once at the top of
forward_impl_plugin_mode and route around the three rotary_emb access
sites:
1. Mixed prefill+decode batch (not decode_only): if NoPE-only, write
kv_cache via aiter.concat_and_cache_mla (no-rope path) and skip
concat_and_cache_mla_rope_fused / self.rotary_emb(...).
2. Decode-only batch (decode_only): if NoPE-only, kv_cache is now
written by branch (1)'s `not decode_only or nope_only_mla`
predicate; fall through to the existing manual-concat / fp8-quant
decode_q construction path used today for mixed batches.
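The routing in the two items above can be sketched as a small planner that returns labels for the ops each batch type would hit. This is an illustration under assumptions, not the actual function body: `plan_mla_ops` and the `"manual_concat_decode_q"` label are hypothetical names, and the real code issues kernel calls rather than returning strings.

```python
def plan_mla_ops(rotary_emb, decode_only):
    """Return illustrative labels for the ops one forward pass would use.

    Hypothetical helper: it only mirrors the branch structure described in
    the PR text and never calls any real aiter or vLLM kernel.
    """
    nope_only_mla = rotary_emb is None  # detected once, up front
    ops = []

    # Branch (1): kv_cache write. NoPE-only models always take this branch,
    # including for decode-only batches.
    if not decode_only or nope_only_mla:
        if nope_only_mla:
            ops.append("aiter.concat_and_cache_mla")  # no-rope write
        else:
            # Rope path: fused write or the self.rotary_emb(...) variant.
            ops.append("concat_and_cache_mla_rope_fused")

    # Branch (2): decode query construction.
    if decode_only and not nope_only_mla:
        # One-step fused op that needs the rotary caches.
        ops.append("aiter.fused_qk_rope_concat_and_cache_mla")
    else:
        # Existing manual-concat / fp8-quant path (hypothetical label).
        ops.append("manual_concat_decode_q")
    return ops


nope_decode = plan_mla_ops(rotary_emb=None, decode_only=True)
rope_decode = plan_mla_ops(rotary_emb=object(), decode_only=True)
nope_mixed = plan_mla_ops(rotary_emb=None, decode_only=False)
```

In this sketch no rope-dependent label is ever produced when `rotary_emb` is `None`, which is the property the change is meant to guarantee.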
For NoPE-only models qk_rope_head_dim is 0; splits/concats degenerate
naturally so no further shape adaptation is needed. The kv_cache write
path used by the new branch (aiter.concat_and_cache_mla) is already
exercised today by the non-fused-rope branch in the same function.
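To illustrate the "degenerate naturally" claim, the sketch below splits one query row into its NoPE and rope parts when the rope width is 0. Plain Python lists stand in for tensors here, and `split_last_dim` is a hypothetical helper; whether the real tensor ops behave the same way is the PR's claim and is not verified by this snippet.

```python
def split_last_dim(row, nope_dim, rope_dim):
    """Split one row into (nope_part, rope_part) along its last dimension."""
    if len(row) != nope_dim + rope_dim:
        raise ValueError("row width must equal nope_dim + rope_dim")
    return row[:nope_dim], row[nope_dim:]


q_row = [0.1, 0.2, 0.3, 0.4]

# NoPE-only model: the rope width is 0, so the rope slice is empty.
q_nope, q_pe = split_last_dim(q_row, nope_dim=4, rope_dim=0)

# Concatenating the two parts back together is then a no-op.
rebuilt = q_nope + q_pe
```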
Discovery context: vllm-project/vllm#40697 (Kimi-Linear FlyDSL KDA
decode wiring) is currently blocked at cudagraph capture on this same
code path.
AI-assistance disclosure: prepared with AI assistance (Cursor + Claude);
every changed line was reviewed by a human submitter and syntax-checked.
Runtime validation against Kimi-Linear is the next step.
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: chuali <chuali@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Summary
Plugin-mode MLA in `atom/plugin/attention_mla.py::forward_impl_plugin_mode` unconditionally accesses `self.rotary_emb.{cos_cache, sin_cache, is_neox_style}` and calls `self.rotary_emb(...)`, which fails for NoPE-only MLA models that upstream vLLM correctly constructs with `rotary_emb=None` in `MLAModules(...)`.
The visible failure surface today is Kimi-Linear-48B-A3B-Instruct, which crashes at cudagraph capture under TP=2 / TP=4 with `AttributeError: 'NoneType' object has no attribute 'cos_cache'`.
`_mla_plugin_mode_init` already gates the precomputed `rotary_emb_cos_sin_cache` buffer registration on `self.rotary_emb is not None` (current `main`, around line 932), so handling `rotary_emb=None` at runtime is an extension of an existing pattern, not a new contract.
Change
Detect `nope_only_mla = self.rotary_emb is None` once at the top of `forward_impl_plugin_mode` and route around the three `rotary_emb` access sites:
1. Mixed prefill+decode batch (`not decode_only`): if NoPE-only, write `kv_cache` via `aiter.concat_and_cache_mla` (no-rope path) and skip `ops.concat_and_cache_mla_rope_fused` / `self.rotary_emb(...)`.
2. Decode-only batch (`decode_only`): if NoPE-only, `kv_cache` is now written by branch (1)'s `not decode_only or nope_only_mla` predicate; fall through to the existing manual-concat / fp8-quant `decode_q` construction path that is already used today for mixed batches.
For NoPE-only models `qk_rope_head_dim == 0`, so the splits/concats degenerate naturally (`decode_q_pe` is empty along the last dim) and no further shape adaptation is needed.
Why this is safe
- […] `self.rotary_emb is not None`.
- The kv_cache write path used by the new branch (`aiter.concat_and_cache_mla`) is already exercised today by the non-fused-rope branch a few lines above in the same function.
- `_mla_plugin_mode_init` already does `if self.rotary_emb is not None:` for the `rotary_emb_cos_sin_cache` buffer registration. This PR just extends the same pattern to the forward path.
Trade-off note
For Kimi-Linear, the NoPE-only `decode_only` batch now goes through a two-step "write kv_cache → build decode_q" path instead of the one-step `aiter.fused_qk_rope_concat_and_cache_mla` op. This costs one extra kernel launch on the decode hot path but does no rope work, and Kimi-Linear MLA is only 7 of 27 layers in the model, so the impact is small relative to unblocking the model entirely. A follow-up could add an `aiter.concat_and_cache_mla_no_rope`-style fused op if profiling shows the extra launch matters.
Test plan
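- […] (`python -c 'import ast; ast.parse(open(...).read())'`).
- […] `rocm/atom-dev:vllm-v0.19.0-nightly_*` image built with this fix.
- […] `rotary_emb is not None`.
Related
- Discovery context: vllm-project/vllm#40697 (Kimi-Linear FlyDSL KDA decode wiring) is currently blocked at cudagraph capture on this same code path.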
AI assistance disclosure
This PR was prepared with AI assistance (Cursor + Claude). Every changed line was reviewed by a human submitter and syntax-checked locally. Runtime validation against Kimi-Linear is queued as the next step.
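The local syntax check mentioned above can be sketched as a small helper around `ast.parse`. The function names are hypothetical, and the file path is left as a parameter because the original command elides it. This confirms only that the file parses; it says nothing about imports or runtime behavior, which is why runtime validation remains a separate step.

```python
import ast
from pathlib import Path


def parses_ok(source, filename="<string>"):
    """Return True if `source` is syntactically valid Python."""
    try:
        ast.parse(source, filename=filename)
    except SyntaxError:
        return False
    return True


def syntax_check(path):
    """Read a file from disk and report whether it parses as Python."""
    text = Path(path).read_text(encoding="utf-8")
    return parses_ok(text, filename=str(path))


good = parses_ok("nope_only_mla = rotary_emb is None\n")
bad = parses_ok("def broken(:\n    pass\n")
```

Calling `syntax_check` on the edited module would return `True` when the file parses and `False` on a `SyntaxError`.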