feat(step3.7): support NextN/MTP heads for speculative decoding#1
Open
eauchs wants to merge 6 commits into
Open
feat(step3.7): support NextN/MTP heads for speculative decoding#1eauchs wants to merge 6 commits into
eauchs wants to merge 6 commits into
Conversation
Step-3.7-Flash ships num_nextn_predict_layers (3) dense MTP blocks after
the main transformer (HF model.layers.{N..N+K-1}). The current converter
silently drops them and the runtime arch graph never declares an MTP
draft head, so `--spec-type draft-mtp` is unavailable for step35 GGUFs.
This change wires the full chain end-to-end:
conversion/step3.py
- Extend block_count by num_nextn_predict_layers.
- Stop filtering HF layers >= num_hidden_layers when MTP is enabled.
- Emit `step35.nextn_predict_layers` GGUF metadata.
- Pad per-layer arrays (layer_types, partial_rotary_factors,
swiglu_limits[_shared]) for the MTP blocks (full-attention, no clamp).
gguf-py/gguf/constants.py
- Register the NEXTN_* tensors on MODEL_ARCH.STEP35.
gguf-py/gguf/tensor_mapping.py
- Map Step-3.7's `transformer.shared_head.{norm,output}` to
NEXTN_SHARED_HEAD_{NORM,HEAD}.
src/models/step35.cpp + src/models/models.h
- Read `nextn_predict_layers` in load_arch_hparams; force the trailing
blocks to full-attention.
- Split tensor loading: trunk (MoE + shared expert + Step35 attn) for
[0, n_main) and MTP heads (dense SwiGLU MLP + nextn.* + per-block
shared head) for [n_main, n_layer).
- Trim the main forward to n_transformer_layers and expose
res->t_h_pre_norm so the draft head can seed AR steps.
- Implement llama_model_step35::graph_mtp following the Qwen3.5
single-block convention but with Step35 attention semantics
(head-wise sigmoid gate, q/k norm, partial rotary) and a dense
MLP (Step-3.7 MTP heads use mlp.{gate,up,down}_proj, not MoE).
Author
|
Quick heads-up on follow-up validation: I'm capped at 128 GB unified RAM on an M3 Max, which isn't enough to do the next step myself (downloading the ~400 GB BF16 release, re-converting with this patch, and running the model with the 3 MTP heads loaded alongside the main 198B/IQ4_XS stack + KV cache + working set). So I can confirm the conversion pipeline emits the expected tensor names and that the runtime graph builds + loads against a fresh GGUF, but:
…will need to happen on hardware with enough memory, ideally with someone who has the StepFun reference handy. Happy to iterate on review comments either way. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Wire up Step-3.7-Flash's
num_nextn_predict_layersMTP heads end-to-end so that--spec-type draft-mtpbecomes usable onstep35GGUFs. Today the converter silently drops layers ≥num_hidden_layersand the runtime graph never declares an MTP draft head, so the 3 dense MTP blocks that ship with Step-3.7-Flash sit unused.This patch follows the existing
qwen35moesingle-block MTP convention but adapts to Step-3.7 specifics:mlp.{gate,up,down}_proj), not MoE.transformer.shared_head.{norm,output}).g_proj, q/k RMSNorm, partial RoPE —n_rot = head_dim / 2for full-attention).tok_embd(Step-3.7 does not ship per-blockembed_tokens).Changes
conversion/step3.pyblock_countbynum_nextn_predict_layers; stop filtering HF layers ≥num_hidden_layerswhen MTP is enabled; emitstep35.nextn_predict_layers; pad per-layer arrays (layer_types,partial_rotary_factors,swiglu_limits[_shared]) so the MTP blocks are marked full-attention with no SwiGLU clamp.gguf-py/gguf/constants.pyNEXTN_{EH_PROJ, ENORM, HNORM, SHARED_HEAD_HEAD, SHARED_HEAD_NORM}onMODEL_ARCH.STEP35.gguf-py/gguf/tensor_mapping.pytransformer.shared_head.{norm,output}toNEXTN_SHARED_HEAD_{NORM,HEAD}.src/models/step35.cppnextn_predict_layersinload_arch_hparams, force trailing blocks to full-attention. Split tensor loading into trunk (MoE + shared expert + Step35 attn) for[0, n_main)and MTP heads (dense SwiGLU +nextn.*+ per-block shared head) for[n_main, n_layer). Trim main forward ton_transformer_layersand exposeres->t_h_pre_normso the draft head can seed AR steps. Implementllama_model_step35::graph_mtpwith Step35 attention + dense MLP + per-block shared head.src/models/models.hllama_model_step35::graph_mtp.Architecture notes
Step-3.7-Flashships:eh_proj,enorm,hnorm, full Step35 attention, dense SwiGLU MLP, andtransformer.shared_head.{norm, output}.The MTP graph implements the DeepSeek-V3 / Qwen3.5 NextN pattern:
h_nextis exposed asres->t_h_pre_normso the AR draft loop can chain MTP steps.For now
graph_mtponly uses the first MTP block (lowest index), matching the existingqwen35moesingle-block convention. The driver can iterate the draft graph multiple times manually to exploit all 3 heads; multi-block chained MTP can land in a follow-up.Test plan
cmake --build build -jclean on macOS arm64 (Metal + Accelerate + BLAS).test-llama-archspasses forstep35on Apple M3 Max / Accelerate / Meta backends.blk.{45,46,47}.{nextn.*, attn_*, ffn_*}via the updated tensor map.modeling_step3p7.pydoes not expose an MTP forward, so the graph here follows the DeepSeek-V3 / Qwen3.5 NextN convention. Maintainer review against your internal reference is the validation that will close the loop.Notes
stepfun-ai/Step-3.7-Flash; the MTP weights are already published in the HF safetensors.cache_reuseremains unsupported onstep35(per-layer RoPE dims breakK-shift); that's a separate issue.