Skip to content

feat(step3.7): support NextN/MTP heads for speculative decoding#1

Open
eauchs wants to merge 6 commits into
stepfun-ai:step3.7from
eauchs:step3.7-mtp
Open

feat(step3.7): support NextN/MTP heads for speculative decoding#1
eauchs wants to merge 6 commits into
stepfun-ai:step3.7from
eauchs:step3.7-mtp

Conversation

@eauchs

@eauchs eauchs commented May 31, 2026

Copy link
Copy Markdown

Summary

Wire up Step-3.7-Flash's num_nextn_predict_layers MTP heads end-to-end so that --spec-type draft-mtp becomes usable on step35 GGUFs. Today the converter silently drops layers ≥ num_hidden_layers and the runtime graph never declares an MTP draft head, so the 3 dense MTP blocks that ship with Step-3.7-Flash sit unused.

This patch follows the existing qwen35moe single-block MTP convention but adapts to Step-3.7 specifics:

  • MTP blocks use a dense SwiGLU MLP (mlp.{gate,up,down}_proj), not MoE.
  • Per-MTP-block shared LM head (transformer.shared_head.{norm,output}).
  • Step35 attention semantics (head-wise sigmoid g_proj, q/k RMSNorm, partial RoPE — n_rot = head_dim / 2 for full-attention).
  • Reuses the main tok_embd (Step-3.7 does not ship per-block embed_tokens).

Changes

File What
conversion/step3.py Extend block_count by num_nextn_predict_layers; stop filtering HF layers ≥ num_hidden_layers when MTP is enabled; emit step35.nextn_predict_layers; pad per-layer arrays (layer_types, partial_rotary_factors, swiglu_limits[_shared]) so the MTP blocks are marked full-attention with no SwiGLU clamp.
gguf-py/gguf/constants.py Register NEXTN_{EH_PROJ, ENORM, HNORM, SHARED_HEAD_HEAD, SHARED_HEAD_NORM} on MODEL_ARCH.STEP35.
gguf-py/gguf/tensor_mapping.py Map Step-3.7's transformer.shared_head.{norm,output} to NEXTN_SHARED_HEAD_{NORM,HEAD}.
src/models/step35.cpp Read nextn_predict_layers in load_arch_hparams, force trailing blocks to full-attention. Split tensor loading into trunk (MoE + shared expert + Step35 attn) for [0, n_main) and MTP heads (dense SwiGLU + nextn.* + per-block shared head) for [n_main, n_layer). Trim main forward to n_transformer_layers and expose res->t_h_pre_norm so the draft head can seed AR steps. Implement llama_model_step35::graph_mtp with Step35 attention + dense MLP + per-block shared head.
src/models/models.h Declare llama_model_step35::graph_mtp.

Architecture notes

Step-3.7-Flash ships:

  • Main stack: 45 layers (3 leading dense + 42 MoE, iSWA 3:1 pattern, window 512).
  • MTP heads: 3 dense blocks at HF layers 45, 46, 47, each with eh_proj, enorm, hnorm, full Step35 attention, dense SwiGLU MLP, and transformer.shared_head.{norm, output}.

The MTP graph implements the DeepSeek-V3 / Qwen3.5 NextN pattern:

h_norm  = RMSNorm_h(h_input)               # nextn.hnorm
e_norm  = RMSNorm_e(embed(prev_token))     # nextn.enorm
x       = nextn.eh_proj(concat(e_norm, h_norm, dim=0))   # [n_embd, n_tokens]
inpSA   = x
x       = input_layernorm(x)               # attn_norm
attn_out= step35_self_attn(x)              # g_proj sigmoid gate + q/k norm + partial RoPE
x       = inpSA + attn_out
x_fres  = x
x       = post_attention_layernorm(x)      # ffn_norm
x       = dense_swiglu_mlp(x)              # mlp.{gate,up,down}_proj
h_next  = x_fres + x
logits  = nextn.shared_head_head(nextn.shared_head_norm(h_next))

h_next is exposed as res->t_h_pre_norm so the AR draft loop can chain MTP steps.

For now graph_mtp only uses the first MTP block (lowest index), matching the existing qwen35moe single-block convention. The driver can iterate the draft graph multiple times manually to exploit all 3 heads; multi-block chained MTP can land in a follow-up.

Test plan

  • Build: cmake --build build -j clean on macOS arm64 (Metal + Accelerate + BLAS).
  • test-llama-archs passes for step35 on Apple M3 Max / Accelerate / Meta backends.
  • Tensor-mapping smoke test: all 51 MTP tensors from the published Step-3.7-Flash safetensors index (3 blocks × 17 tensors) resolve to blk.{45,46,47}.{nextn.*, attn_*, ffn_*} via the updated tensor map.
  • End-to-end conversion from the upstream HF Step-3.7-Flash safetensors (requires the original BF16 release; out of reach for this PR author).
  • Logit parity vs. StepFun's reference MTP forward (vLLM / SGLang). The HF modeling_step3p7.py does not expose an MTP forward, so the graph here follows the DeepSeek-V3 / Qwen3.5 NextN convention. Maintainer review against your internal reference is the validation that will close the loop.

Notes

  • No changes are needed in stepfun-ai/Step-3.7-Flash; the MTP weights are already published in the HF safetensors.
  • cache_reuse remains unsupported on step35 (per-layer RoPE dims break K-shift); that's a separate issue.

forforever73 and others added 6 commits May 27, 2026 11:28
Step-3.7-Flash ships num_nextn_predict_layers (3) dense MTP blocks after
the main transformer (HF model.layers.{N..N+K-1}). The current converter
silently drops them and the runtime arch graph never declares an MTP
draft head, so `--spec-type draft-mtp` is unavailable for step35 GGUFs.

This change wires the full chain end-to-end:

conversion/step3.py
  - Extend block_count by num_nextn_predict_layers.
  - Stop filtering HF layers >= num_hidden_layers when MTP is enabled.
  - Emit `step35.nextn_predict_layers` GGUF metadata.
  - Pad per-layer arrays (layer_types, partial_rotary_factors,
    swiglu_limits[_shared]) for the MTP blocks (full-attention, no clamp).

gguf-py/gguf/constants.py
  - Register the NEXTN_* tensors on MODEL_ARCH.STEP35.

gguf-py/gguf/tensor_mapping.py
  - Map Step-3.7's `transformer.shared_head.{norm,output}` to
    NEXTN_SHARED_HEAD_{NORM,HEAD}.

src/models/step35.cpp + src/models/models.h
  - Read `nextn_predict_layers` in load_arch_hparams; force the trailing
    blocks to full-attention.
  - Split tensor loading: trunk (MoE + shared expert + Step35 attn) for
    [0, n_main) and MTP heads (dense SwiGLU MLP + nextn.* + per-block
    shared head) for [n_main, n_layer).
  - Trim the main forward to n_transformer_layers and expose
    res->t_h_pre_norm so the draft head can seed AR steps.
  - Implement llama_model_step35::graph_mtp following the Qwen3.5
    single-block convention but with Step35 attention semantics
    (head-wise sigmoid gate, q/k norm, partial rotary) and a dense
    MLP (Step-3.7 MTP heads use mlp.{gate,up,down}_proj, not MoE).
@eauchs

eauchs commented May 31, 2026

Copy link
Copy Markdown
Author

Quick heads-up on follow-up validation: I'm capped at 128 GB unified RAM on an M3 Max, which isn't enough to do the next step myself (downloading the ~400 GB BF16 release, re-converting with this patch, and running the model with the 3 MTP heads loaded alongside the main 198B/IQ4_XS stack + KV cache + working set).

So I can confirm the conversion pipeline emits the expected tensor names and that the runtime graph builds + loads against a fresh GGUF, but:

  • Re-converting from the upstream HF safetensors
  • Loading the resulting GGUF and exercising --spec-type draft-mtp
  • Checking logit parity against your internal MTP reference (vLLM / SGLang)

…will need to happen on hardware with enough memory, ideally with someone who has the StepFun reference handy. Happy to iterate on review comments either way.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants