Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions atom/plugin/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,26 +714,20 @@ def forward_impl_plugin_mode(

decode_only = has_decode and not has_prefill

if not decode_only:
if not hasattr(self, "_has_fused_rope_cache"):
self._has_fused_rope_cache = hasattr(
ops, "concat_and_cache_mla_rope_fused"
)
if kv_cache.numel() > 0 and self._has_fused_rope_cache:
ops.concat_and_cache_mla_rope_fused(
positions,
q[..., self.qk_nope_head_dim :],
k_pe.squeeze(1),
k_c_normed,
self.rotary_emb_cos_sin_cache,
self.rotary_emb.is_neox_style,
attn_metadata.plugin_metadata.slot_mapping,
kv_cache,
self.kv_cache_dtype,
layer._k_scale,
)
else:
self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe)
# NoPE-only MLA models (e.g., Kimi-Linear) construct ``MLAModules``
# with ``rotary_emb=None`` upstream because no rope is applied. Detect
# that case once here and route around every ``self.rotary_emb`` /
# ``rotary_emb_cos_sin_cache`` access below; the kv_cache write and
# decode_q construction otherwise follow the same code paths as the
# rope case, just without the rope step.
nope_only_mla = self.rotary_emb is None

if not decode_only or nope_only_mla:
if nope_only_mla:
# No rope to apply; write kv_cache plainly. ``k_pe`` may have
# zero size along the last dim for fully NoPE-only models, in
# which case ``aiter.concat_and_cache_mla`` still writes the
# ``k_c_normed`` (kv_lora) portion correctly.
if kv_cache.numel() > 0:
aiter.concat_and_cache_mla(
k_c_normed,
Expand All @@ -743,6 +737,35 @@ def forward_impl_plugin_mode(
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
else:
if not hasattr(self, "_has_fused_rope_cache"):
self._has_fused_rope_cache = hasattr(
ops, "concat_and_cache_mla_rope_fused"
)
if kv_cache.numel() > 0 and self._has_fused_rope_cache:
ops.concat_and_cache_mla_rope_fused(
positions,
q[..., self.qk_nope_head_dim :],
k_pe.squeeze(1),
k_c_normed,
self.rotary_emb_cos_sin_cache,
self.rotary_emb.is_neox_style,
attn_metadata.plugin_metadata.slot_mapping,
kv_cache,
self.kv_cache_dtype,
layer._k_scale,
)
else:
self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe)
if kv_cache.numel() > 0:
aiter.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.plugin_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)

if fp8_attention:
kv_cache = kv_cache.view(current_platform.fp8_dtype())
Expand Down Expand Up @@ -795,7 +818,7 @@ def forward_impl_plugin_mode(
transpose_bm=True,
)

if decode_only:
if decode_only and not nope_only_mla:
decode_q = torch.empty(
(
decode_ql_nope.shape[0],
Expand Down Expand Up @@ -828,6 +851,11 @@ def forward_impl_plugin_mode(
is_nope_first=True,
)
else:
# Either mixed prefill+decode batch (kv_cache already written
# above) OR NoPE-only decode-only batch (kv_cache also written
# above via the ``nope_only_mla`` branch). In both cases we
# only need to construct ``decode_q`` here; rope was either
# applied above or not needed at all.
if fp8_attention:
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
Expand Down
Loading