diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 84dedf069..d4d0e596c 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -714,26 +714,20 @@ def forward_impl_plugin_mode( decode_only = has_decode and not has_prefill - if not decode_only: - if not hasattr(self, "_has_fused_rope_cache"): - self._has_fused_rope_cache = hasattr( - ops, "concat_and_cache_mla_rope_fused" - ) - if kv_cache.numel() > 0 and self._has_fused_rope_cache: - ops.concat_and_cache_mla_rope_fused( - positions, - q[..., self.qk_nope_head_dim :], - k_pe.squeeze(1), - k_c_normed, - self.rotary_emb_cos_sin_cache, - self.rotary_emb.is_neox_style, - attn_metadata.plugin_metadata.slot_mapping, - kv_cache, - self.kv_cache_dtype, - layer._k_scale, - ) - else: - self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + # NoPE-only MLA models (e.g., Kimi-Linear) construct ``MLAModules`` + # with ``rotary_emb=None`` upstream because no rope is applied. Detect + # that case once here and route around every ``self.rotary_emb`` / + # ``rotary_emb_cos_sin_cache`` access below; the kv_cache write and + # decode_q construction otherwise follow the same code paths as the + # rope case, just without the rope step. + nope_only_mla = self.rotary_emb is None + + if not decode_only or nope_only_mla: + if nope_only_mla: + # No rope to apply; write kv_cache plainly. ``k_pe`` may have + # zero size along the last dim for fully NoPE-only models, in + # which case ``aiter.concat_and_cache_mla`` still writes the + # ``k_c_normed`` (kv_lora) portion correctly. if kv_cache.numel() > 0: aiter.concat_and_cache_mla( k_c_normed, @@ -743,6 +737,35 @@ def forward_impl_plugin_mode( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) + else: + if not hasattr(self, "_has_fused_rope_cache"): + self._has_fused_rope_cache = hasattr( + ops, "concat_and_cache_mla_rope_fused" + ) + if kv_cache.numel() > 0 and self._has_fused_rope_cache: + ops.concat_and_cache_mla_rope_fused( + positions, + q[..., self.qk_nope_head_dim :], + k_pe.squeeze(1), + k_c_normed, + self.rotary_emb_cos_sin_cache, + self.rotary_emb.is_neox_style, + attn_metadata.plugin_metadata.slot_mapping, + kv_cache, + self.kv_cache_dtype, + layer._k_scale, + ) + else: + self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + if kv_cache.numel() > 0: + aiter.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.plugin_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) @@ -795,7 +818,7 @@ def forward_impl_plugin_mode( transpose_bm=True, ) - if decode_only: + if decode_only and not nope_only_mla: decode_q = torch.empty( ( decode_ql_nope.shape[0], @@ -828,6 +851,11 @@ def forward_impl_plugin_mode( is_nope_first=True, ) else: + # Either mixed prefill+decode batch (kv_cache already written + # above) OR NoPE-only decode-only batch (kv_cache also written + # above via the ``nope_only_mla`` branch). In both cases we + # only need to construct ``decode_q`` here; rope was either + # applied above or not needed at all. if fp8_attention: assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]