diff --git a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py b/atom/plugin/sglang/attention_backend/sgl_attention_mla.py index a13ce5a6a..8b190acdb 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py +++ b/atom/plugin/sglang/attention_backend/sgl_attention_mla.py @@ -236,6 +236,43 @@ def mla_absorbed_bmm( return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) +def mla_v_up_proj( + attn: DeepseekV2MLAAttention, + inp: torch.Tensor, + weight: torch.Tensor, + weight_scale: Optional[torch.Tensor], + weight_scale_k: Optional[torch.Tensor], + out_dim: int, +) -> torch.Tensor: + """Project MLA decode output to a flat o_proj input.""" + if _is_hip and ( + (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) + or (get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz) + ): + x = inp.transpose(0, 1) + out = torch.empty( + (inp.shape[0], attn.num_local_heads * out_dim), + device=inp.device, + dtype=torch.bfloat16, + ) + out_3d = out.view(inp.shape[0], attn.num_local_heads, out_dim) + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=x, + WQ=weight.transpose(-1, -2), + w_scale=weight_scale, + group_size=128, + YQ=out_3d, + transpose_bm=True, + transpose_bm_in=False, + dtype=torch.bfloat16, + ) + return out + + return mla_absorbed_bmm( + attn, inp, weight, weight_scale, weight_scale_k, out_dim + ).flatten(1, 2) + + # Forward: prepare → core def forward_sgl_prepare( attn: DeepseekV2MLAAttention, @@ -420,9 +457,9 @@ def forward_sgl_core( attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) # up-proj by w_vc - attn_bmm_output = mla_absorbed_bmm( + attn_bmm_output = mla_v_up_proj( attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim - ).flatten(1, 2) + ) return attn.o_proj(attn_bmm_output)