Custom MLX Primitive (Optional)

This document describes gather_qmm_swiglu, a custom C++/Metal primitive implemented in mlx_local/ as an extension to MLX. It is not part of released MLX and must be built locally.

What this is

mlx_local/ is a local checkout of upstream MLX (ml-explore/mlx) at the commit pinned by MLX_REF in integrations/mlx_local_integration/setup_mlx_local.sh (default: 185b06d9...), patched with ~800 lines of custom C++ and Metal shader code that add the GatherQMMSwiGLU primitive. The primitive fuses the gate projection, up projection, and SwiGLU activation for quantized MoE experts into a single GPU dispatch, eliminating multiple kernel launches per expert per layer during decode.

The primitive is exposed as mx.gather_qmm_swiglu() in Python when the custom build is active.

What it does

During MoE decode, each active expert normally requires separate kernel launches for:

  1. Dequantize + matmul (gate projection)
  2. Dequantize + matmul (up projection)
  3. SiLU activation
  4. Elementwise multiply (gate * up)

gather_qmm_swiglu fuses all four into a single Metal kernel launch per expert. At decode (M=1), where dispatch overhead dominates over compute, this reduces per-layer latency.
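As a rough reference for the arithmetic being fused (not the Metal kernel itself), the four steps can be sketched in plain Python. This sketch assumes already-dequantized weights and a single expert; the function names and toy shapes are illustrative and are not part of MLX or ZMLX.

```python
import math

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def matvec(weight, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def swiglu_unfused(x, w_gate, w_up):
    """Four separate passes, mirroring the four kernel launches."""
    gate = matvec(w_gate, x)                 # 1. gate projection
    up = matvec(w_up, x)                     # 2. up projection
    act = [silu(g) for g in gate]            # 3. SiLU activation
    return [a * u for a, u in zip(act, up)]  # 4. elementwise multiply

def swiglu_fused(x, w_gate, w_up):
    """One pass per output element: both dot products, SiLU, and the multiply."""
    out = []
    for gate_row, up_row in zip(w_gate, w_up):
        g = sum(w * v for w, v in zip(gate_row, x))
        u = sum(w * v for w, v in zip(up_row, x))
        out.append(silu(g) * u)
    return out

# Hypothetical toy shapes: K=3 inputs, N=2 outputs, M=1 token.
x = [1.0, -2.0, 0.5]
w_gate = [[0.2, 0.1, -0.3], [0.0, 0.5, 1.0]]
w_up = [[1.0, 0.0, 2.0], [-1.0, 0.25, 0.5]]
```

Both functions compute the same values; the fused form simply avoids materializing the intermediate gate, up, and activation vectors, which is the same idea the primitive applies on the GPU.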

When to use it

  • If you want MoE decode speedups on GLM-4.7-Flash or Qwen3-30B-A3B (models where ZMLX auto-skips the fused paths on stock MLX).
  • If you are prototyping fused MLX primitives for potential upstream contribution.

On stock MLX (pip install mlx), ZMLX auto-detects that gather_qmm_swiglu is unavailable and skips the fused paths. No action needed.
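A minimal sketch of this kind of feature detection, assuming the check is just an attribute lookup on mlx.core (the same hasattr test used in the build verification). The helper names are hypothetical and this is not ZMLX's actual detection code.

```python
import importlib

def has_fused_swiglu(module_name="mlx.core"):
    """Return True if the module imports and exposes gather_qmm_swiglu."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # MLX is not installed at all; treat the primitive as unavailable.
        return False
    return hasattr(module, "gather_qmm_swiglu")

def choose_swiglu_path(module_name="mlx.core"):
    """Pick a (hypothetical) path label based on primitive availability."""
    return "fused" if has_fused_swiglu(module_name) else "unfused"
```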

Set up mlx_local/

mlx_local/ is not shipped as part of ZMLX; it is intended as a local-only directory (gitignored) created by cloning MLX and applying a patch.

Recommended:

bash integrations/mlx_local_integration/setup_mlx_local.sh

Manual (equivalent):

git clone https://github.com/ml-explore/mlx.git mlx_local
cd mlx_local
git checkout 185b06d9efc1c869540eccfb5baff853fff3659d
git apply <REPO_ROOT>/integrations/mlx_local_integration/gather_qmm_swiglu.patch

Build

cd mlx_local
python3 setup.py build_ext --inplace
# Limit CPU usage during build if desired:
# CMAKE_BUILD_PARALLEL_LEVEL=4 python3 setup.py build_ext --inplace

Then make sure mlx_local/python comes before the stock MLX install on your Python path:

export PYTHONPATH=<REPO_ROOT>/mlx_local/python:<REPO_ROOT>/src:$PYTHONPATH
python3 -c "import mlx.core as mx; print(hasattr(mx, 'gather_qmm_swiglu'))"  # should print True
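If the check prints False, one possible cause is that Python is still resolving mlx from the stock install. The following sketch reports whether a module resolves from under a given directory; is_loaded_from and module_location are hypothetical helpers, and the directory you pass (for example your mlx_local/python path) is up to you.

```python
import importlib.util
from pathlib import Path

def module_location(name):
    """Best-effort filesystem location of a module, or None if unresolved."""
    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError):
        return None
    if spec is None:
        return None
    if spec.origin and spec.origin not in ("built-in", "frozen"):
        return Path(spec.origin).resolve()
    # Namespace packages have no origin; fall back to their search path.
    locations = list(spec.submodule_search_locations or [])
    return Path(locations[0]).resolve() if locations else None

def is_loaded_from(name, directory):
    """True if the module resolves to a path at or under `directory`."""
    location = module_location(name)
    if location is None:
        return False
    directory = Path(directory).resolve()
    return location == directory or directory in location.parents
```

For example, is_loaded_from("mlx", "<REPO_ROOT>/mlx_local/python") (with the placeholder replaced by your checkout path) would be expected to return True when the custom build takes precedence.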

Validate

python3 -m zmlx.validate mlx-community/GLM-4.7-Flash-4bit --max-tokens 128 --runs 5
python3 -m zmlx.validate mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit --max-tokens 128 --runs 5

Remove mlx_local/python from PYTHONPATH to revert to stock MLX.

Measured results (M4 Max 36 GB)

| Model | Decode (baseline -> patched) | Change | Fidelity | Capsule |
| --- | --- | --- | --- | --- |
| GLM-4.7-Flash-4bit | 76.8 -> 83.5 tok/s | +8.8% | 15/15 configs identical | benchmarks/repro_capsules/glm_stress_m4_20260204.json |
| Qwen3-30B-A3B-4bit | 106.6 -> 115.0 tok/s | +7.9% | 200/200 tokens identical | benchmarks/repro_capsules/qwen3_a3b_moe_mlp_m4max_20260205.json |

Note: the GLM row is a 15-config stress suite (5 prompts × 3 lengths); the 76.8/83.5 values are the mean of per-config median decode tok/s.
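The aggregation described in the note, and the percentage in the Change column, can be sketched as follows. The sample dictionary is made-up illustration data, not taken from the capsules, and the function names are hypothetical.

```python
from statistics import mean, median

def mean_of_medians(runs_by_config):
    """Median decode tok/s within each config, then the mean across configs."""
    return mean(median(runs) for runs in runs_by_config.values())

def percent_change(baseline, patched):
    """Relative change of patched over baseline, in percent."""
    return (patched - baseline) / baseline * 100.0

# Hypothetical tok/s samples for two configs (not measured data).
sample = {
    ("prompt_a", 128): [75.0, 77.0, 76.0],
    ("prompt_b", 256): [80.0, 78.0, 79.0],
}
```

Applied to the rounded Qwen row (106.6 -> 115.0), percent_change gives about +7.9%. Applied to the rounded GLM figures (76.8 -> 83.5) it gives about +8.7%, so the reported +8.8% presumably reflects unrounded inputs or per-config aggregation.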

Additional GLM capsule (shared_experts SwiGLU fusion, 200 tokens, 3 runs): benchmarks/repro_capsules/glm47_flash_shared_experts_swiglu_m4max_20260205_1d9ee0e.json.

Upstream plan

See UPSTREAM_PLAN.md. The intent is to contribute gather_qmm_swiglu to upstream MLX once it has been validated across more models and hardware.

Known constraints

  • N must be divisible by 8, K by 512.
  • Only transpose=True and mode='affine' are implemented.
  • A CPU fallback exists but is not optimized; only the Metal GPU path is tuned for performance.
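A caller might guard the fused path with a pre-check along these lines. This is a sketch of the constraints listed above only; the function name and argument layout are hypothetical and do not mirror the primitive's actual signature.

```python
def fused_swiglu_constraint_errors(n, k, transpose=True, mode="affine"):
    """Return a list of violated constraints (empty if the fused path applies)."""
    errors = []
    if n % 8 != 0:
        errors.append(f"N={n} is not divisible by 8")
    if k % 512 != 0:
        errors.append(f"K={k} is not divisible by 512")
    if transpose is not True:
        errors.append("only transpose=True is implemented")
    if mode != "affine":
        errors.append(f"mode={mode!r} is not implemented (only 'affine')")
    return errors
```

An empty list means the shapes and options fall within the supported set; anything else is a reason to fall back to the unfused path.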

Experimental Router Flags (Qwen)

These are off by default and intended for controlled benchmarks only.

  • ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1
    • uses the Qwen argpartition(logits) + top-k softmax routing path
  • ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK=1
    • enables fused Metal top-k softmax on top of the argpartition(logits) path
    • requires ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS=1
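The dependency between the two flags can be sketched as below. This assumes that only the value "1" enables a flag and that the top-k flag is ignored when the base flag is unset; both are assumptions for illustration, and router_flags is a hypothetical helper rather than ZMLX's actual parsing code.

```python
import os

BASE_FLAG = "ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS"
TOPK_FLAG = "ZMLX_QWEN_ROUTER_ARGPARTITION_LOGITS_TOPK"

def router_flags(env=None):
    """Resolve the two experimental flags from an environment mapping."""
    env = os.environ if env is None else env
    base = env.get(BASE_FLAG) == "1"
    topk_requested = env.get(TOPK_FLAG) == "1"
    return {
        "argpartition_logits": base,
        # The fused top-k path only makes sense on top of the base path.
        "fused_topk": base and topk_requested,
    }
```

With neither variable set, both entries are False, matching the off-by-default behavior described above.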