
Consolidate MLA kv_b_proj sanitize/shard into shared helpers #1324

Open
scyyh11 wants to merge 1 commit into ml-explore:main from scyyh11:mla-sanitize-shard-helpers


scyyh11 commented May 29, 2026

Summary

Several MLA models (DeepSeek-V3, DeepSeek-V3.2, GLM-4 MoE Lite, Kimi Linear, LongCat Flash) carry near-identical copies of two pieces of logic:

  • sanitize() — splitting the fused kv_b_proj weight into the absorbed embed_q / unembed_out (MultiLinear) projections, including the quantized-checkpoint path.
  • shard() — slicing those per-head projections across a distributed group.

This PR extracts both into shared helpers in mlx_lm/models/mla.py (split_kv_b_proj_weights and shard_mla_projections) and updates the five models to call them, removing ~230 lines of duplicated code.
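
For readers unfamiliar with the split, here is a minimal sketch in plain Python of the unquantized case. It is not the code from this PR: the function name `split_fused_kv_b`, the nested-list representation, and the parameter names are hypothetical, and the row layout (each head's K-side rows followed by its V-side rows) and the transposition of the K-side block are assumptions made for illustration.

```python
def split_fused_kv_b(weight, num_heads, qk_nope_head_dim, v_head_dim):
    """Split a fused kv_b_proj matrix into per-head K-side and V-side blocks.

    `weight` is a list of rows with shape
    (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
    Hypothetical helper for illustration only.
    """
    head_dim = qk_nope_head_dim + v_head_dim
    if len(weight) != num_heads * head_dim:
        raise ValueError(
            f"expected {num_heads * head_dim} rows, got {len(weight)}"
        )

    embed_q, unembed_out = [], []
    for h in range(num_heads):
        rows = weight[h * head_dim : (h + 1) * head_dim]
        k_rows = rows[:qk_nope_head_dim]
        v_rows = rows[qk_nope_head_dim:]
        # Assumption: the K-side block is stored transposed, giving a
        # (kv_lora_rank, qk_nope_head_dim) matrix per head.
        embed_q.append([list(col) for col in zip(*k_rows)])
        # The V-side block keeps its (v_head_dim, kv_lora_rank) layout.
        unembed_out.append([list(row) for row in v_rows])
    return embed_q, unembed_out


# Toy dimensions: 2 heads, 2 K-side rows and 1 V-side row per head, rank 3.
num_heads, qk_nope, v_dim, rank = 2, 2, 1, 3
weight = [
    [float(r * rank + c) for c in range(rank)]
    for r in range(num_heads * (qk_nope + v_dim))
]
embed_q, unembed_out = split_fused_kv_b(weight, num_heads, qk_nope, v_dim)
```

Each output is a list with one matrix per head, which is the shape a per-head (MultiLinear-style) projection would consume. The quantized path described under Behavior would wrap this same split between a dequantize step and a re-quantize step.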

Behavior

This is a behavior-preserving refactor: existing checkpoints load exactly as before.

  • The splitter dequantizes, splits, and re-quantizes the projections exactly as the per-model code did, so quantized checkpoints still load as QuantizedMultiLinear (no silent fallback to full precision).
  • It additionally tolerates affine checkpoints saved without explicit biases.
  • shard_mla_projections() raises a clear error when the head count is not divisible across ranks, instead of silently producing wrong shapes.
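
The divisibility check in the last bullet can be illustrated with a small sketch. This is not the body of shard_mla_projections(): the function `shard_per_head`, its arguments, and the error wording are hypothetical, and a plain list stands in for the per-head weight tensors.

```python
def shard_per_head(per_head, rank, world_size):
    """Return the contiguous slice of heads owned by `rank`.

    Hypothetical helper: `per_head` holds one entry per attention head.
    """
    num_heads = len(per_head)
    if num_heads % world_size != 0:
        # Fail early instead of handing some ranks a different head count.
        raise ValueError(
            f"num_heads ({num_heads}) must be divisible by "
            f"the group size ({world_size})"
        )
    heads_per_rank = num_heads // world_size
    start = rank * heads_per_rank
    return per_head[start : start + heads_per_rank]


heads = [f"head{i}" for i in range(8)]
local = shard_per_head(heads, rank=1, world_size=4)  # heads 2 and 3
```

With 8 heads and a group of 3 the same call raises ValueError, whereas plain integer-division slicing would quietly drop the remainder heads.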

Forward passes, the KV cache, LoRA/DoRA, AWQ and DeepSeek-V2 are intentionally left untouched.

Tests

  • Added test_mla_split_requantizes_affine_weight_without_biases (covers the quantized split round-trip and the no-biases case).
  • Added test_mla_sharding_requires_divisible_heads.
  • Existing MLA model tests (test_deepseek_v3, test_deepseek_v32, …) pass.
python -m unittest discover tests/

Extract the duplicated kv_b_proj-splitting (sanitize) and head-sharding
logic from deepseek_v3, deepseek_v32, glm4_moe_lite, kimi_linear and
longcat_flash into split_kv_b_proj_weights() and shard_mla_projections()
in models/mla.py.

Behavior-preserving: the splitter dequantizes, splits and re-quantizes the
projections exactly as the per-model code did, so quantized checkpoints
still load as QuantizedMultiLinear. It additionally tolerates affine
weights saved without biases. shard_mla_projections() raises a clear error
when num_heads is not divisible across ranks.

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
@scyyh11 scyyh11 force-pushed the mla-sanitize-shard-helpers branch from c584703 to fb1a981 Compare May 29, 2026 00:19

scyyh11 commented May 29, 2026

@angeloskath would you mind taking a look when you get a chance?

This is a small, behavior-preserving refactor that pulls the duplicated MLA kv_b_proj sanitize/shard logic into shared helpers. Thanks!
