From ac63658fc90c93df674d8af91f61232f8388828f Mon Sep 17 00:00:00 2001 From: chuali Date: Thu, 14 May 2026 11:53:30 -0700 Subject: [PATCH] [Plugin][MLA] Tolerate rotary_emb=None for NoPE-only MLA models Plugin-mode MLA in atom/plugin/attention_mla.py::forward_impl_plugin_mode unconditionally accesses self.rotary_emb.{cos_cache, sin_cache, is_neox_style} and calls self.rotary_emb(...), which fails for NoPE-only MLA models that upstream vLLM correctly constructs with rotary_emb=None in MLAModules(...). The visible failure surface today is Kimi-Linear-48B-A3B-Instruct, which crashes at cudagraph capture under TP=2/4 with: AttributeError: 'NoneType' object has no attribute 'cos_cache' _mla_plugin_mode_init already gates the precomputed rotary_emb_cos_sin_cache buffer registration on `self.rotary_emb is not None`, so handling rotary_emb=None at runtime is an extension of an existing pattern, not a new contract. Detect `nope_only_mla = self.rotary_emb is None` once at the top of forward_impl_plugin_mode and route around the three rotary_emb access sites: 1. Mixed prefill+decode batch (not decode_only): if NoPE-only, write kv_cache via aiter.concat_and_cache_mla (no-rope path) and skip concat_and_cache_mla_rope_fused / self.rotary_emb(...). 2. Decode-only batch (decode_only): if NoPE-only, kv_cache is now written by branch (1)'s `not decode_only or nope_only_mla` predicate; fall through to the existing manual-concat / fp8-quant decode_q construction path used today for mixed batches. For NoPE-only models qk_rope_head_dim is 0; splits/concats degenerate naturally so no further shape adaptation is needed. The kv_cache write path used by the new branch (aiter.concat_and_cache_mla) is already exercised today by the non-fused-rope branch in the same function. Discovery context: vllm-project/vllm#40697 (Kimi-Linear FlyDSL KDA decode wiring) is currently blocked at cudagraph capture on this same code path. AI-assistance disclosure: prepared with AI assistance (Cursor + Claude); every changed line was reviewed by a human submitter and syntax-checked. Runtime validation against Kimi-Linear is the next step. Co-authored-by: Claude Signed-off-by: chuali Co-authored-by: Cursor --- atom/plugin/attention_mla.py | 70 +++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 84dedf069..d4d0e596c 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -714,26 +714,20 @@ def forward_impl_plugin_mode( decode_only = has_decode and not has_prefill - if not decode_only: - if not hasattr(self, "_has_fused_rope_cache"): - self._has_fused_rope_cache = hasattr( - ops, "concat_and_cache_mla_rope_fused" - ) - if kv_cache.numel() > 0 and self._has_fused_rope_cache: - ops.concat_and_cache_mla_rope_fused( - positions, - q[..., self.qk_nope_head_dim :], - k_pe.squeeze(1), - k_c_normed, - self.rotary_emb_cos_sin_cache, - self.rotary_emb.is_neox_style, - attn_metadata.plugin_metadata.slot_mapping, - kv_cache, - self.kv_cache_dtype, - layer._k_scale, - ) - else: - self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + # NoPE-only MLA models (e.g., Kimi-Linear) construct ``MLAModules`` + # with ``rotary_emb=None`` upstream because no rope is applied. Detect + # that case once here and route around every ``self.rotary_emb`` / + # ``rotary_emb_cos_sin_cache`` access below; the kv_cache write and + # decode_q construction otherwise follow the same code paths as the + # rope case, just without the rope step. + nope_only_mla = self.rotary_emb is None + + if not decode_only or nope_only_mla: + if nope_only_mla: + # No rope to apply; write kv_cache plainly. ``k_pe`` may have + # zero size along the last dim for fully NoPE-only models, in + # which case ``aiter.concat_and_cache_mla`` still writes the + # ``k_c_normed`` (kv_lora) portion correctly. if kv_cache.numel() > 0: aiter.concat_and_cache_mla( k_c_normed, @@ -743,6 +737,35 @@ def forward_impl_plugin_mode( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) + else: + if not hasattr(self, "_has_fused_rope_cache"): + self._has_fused_rope_cache = hasattr( + ops, "concat_and_cache_mla_rope_fused" + ) + if kv_cache.numel() > 0 and self._has_fused_rope_cache: + ops.concat_and_cache_mla_rope_fused( + positions, + q[..., self.qk_nope_head_dim :], + k_pe.squeeze(1), + k_c_normed, + self.rotary_emb_cos_sin_cache, + self.rotary_emb.is_neox_style, + attn_metadata.plugin_metadata.slot_mapping, + kv_cache, + self.kv_cache_dtype, + layer._k_scale, + ) + else: + self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + if kv_cache.numel() > 0: + aiter.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.plugin_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) @@ -795,7 +818,7 @@ def forward_impl_plugin_mode( transpose_bm=True, ) - if decode_only: + if decode_only and not nope_only_mla: decode_q = torch.empty( ( decode_ql_nope.shape[0], @@ -828,6 +851,11 @@ def forward_impl_plugin_mode( is_nope_first=True, ) else: + # Either mixed prefill+decode batch (kv_cache already written + # above) OR NoPE-only decode-only batch (kv_cache also written + # above via the ``nope_only_mla`` branch). In both cases we + # only need to construct ``decode_q`` here; rope was either + # applied above or not needed at all. if fp8_attention: assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]