Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
"Starcoder2ForCausalLM": "starcoder",
"Step3p5ForCausalLM": "step3",
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"T5EncoderModel": "t5",
"T5ForConditionalGeneration": "t5",
"T5WithLMHeadModel": "t5",
Expand Down Expand Up @@ -279,6 +280,7 @@
"Sarashina2VisionForCausalLM": "sarashina2",
"SmolVLMForConditionalGeneration": "smolvlm",
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox",
"VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl",
Expand Down
4 changes: 2 additions & 2 deletions conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca
name, gen = item

# Skip multimodal tensors
if name.startswith(("mlp", "vit.", "vpm.", "siglip2.", "conformer.", "merger.", "resampler.", "sound_encoder.", "sound_projection.", "speech_embeddings.")) \
if name.startswith(("mlp", "vit.", "vit_large_projector.", "vpm.", "siglip2.", "conformer.", "merger.", "resampler.", "sound_encoder.", "sound_projection.", "speech_embeddings.")) \
or "visual." in name or "vision." in name or "audio." in name or "talker." in name \
or "vision_" in name or "audio_" in name or "sam_model" in name \
or "token2wav." in name or "code2wav." in name \
Expand Down Expand Up @@ -2552,7 +2552,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
# For text conversion we route to a dedicated text-only class.
# TODO: refactor this later to avoid adding exception here
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM"):
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM", "Step3p7ForConditionalGeneration"):
return arch

# if "architectures" is found in the sub-config, use that instead
Expand Down
85 changes: 68 additions & 17 deletions conversion/step3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .qwen import Qwen3Model


@ModelBase.register("StepVLForConditionalGeneration")
@ModelBase.register("StepVLForConditionalGeneration", "Step3p7ForConditionalGeneration")
class Step3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -95,10 +95,23 @@ class Step3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3


@ModelBase.register("Step3p5ForCausalLM")
@ModelBase.register("Step3p5ForCausalLM", "Step3p7ForConditionalGeneration")
class Step35Model(TextModel):
model_arch = gguf.MODEL_ARCH.STEP35

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Step-3.7 ships NextN/MTP heads (num_nextn_predict_layers > 0) after the
# main transformer stack. We expose them as extra blocks (blk.N..blk.N+K-1)
# so the model loader can find their tensors under blk.%d.nextn.* and the
# final dense MLP / shared head tensors.
nextn = int(self.hparams.get("num_nextn_predict_layers", 0))
self._nextn_predict_layers = nextn
self._n_main_layers = int(self.hparams["num_hidden_layers"])
if nextn > 0:
self.block_count = self._n_main_layers + nextn
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_gguf_parameters(self):
rope_theta = self.hparams.get("rope_theta")
if isinstance(rope_theta, list):
Expand All @@ -109,8 +122,11 @@ def set_gguf_parameters(self):

super().set_gguf_parameters()

layer_types = self.hparams.get("layer_types") or []
partial_rotary_factors = self.hparams.get("partial_rotary_factors") or []
nextn = self._nextn_predict_layers
n_main = self._n_main_layers

layer_types = list(self.hparams.get("layer_types") or [])
partial_rotary_factors = list(self.hparams.get("partial_rotary_factors") or [])
attn_other = self.hparams.get("attention_other_setting") or {}

n_head_base = self.hparams["num_attention_heads"]
Expand All @@ -119,9 +135,19 @@ def set_gguf_parameters(self):
n_head_swa = attn_other.get("num_attention_heads", n_head_base)
n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)

layer_types = layer_types[: self.block_count]
partial_rotary_factors = partial_rotary_factors[: self.block_count]
# Trim the HF lists to the main transformer length first; the upstream
# config sometimes includes entries for the MTP heads, sometimes not.
layer_types = layer_types[:n_main]
partial_rotary_factors = partial_rotary_factors[:n_main]
assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors

# MTP heads are full-attention only and use the full-attention rope branch
# (half rope dims, base rope_theta). Extend per-layer arrays accordingly so
# the GGUF carries one entry per block.
if nextn > 0:
layer_types += ["full_attention"] * nextn
partial_rotary_factors += [0.5] * nextn

head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [lt == "sliding_attention" for lt in layer_types]
Expand Down Expand Up @@ -157,30 +183,39 @@ def set_gguf_parameters(self):

self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))

# Optional per-layer SwiGLU clamps.
# NextN/MTP heads — Step-3.7 ships num_nextn_predict_layers dense MTP
# blocks after the main transformer (model.layers.N..N+K-1 in HF).
if self._nextn_predict_layers > 0:
self.gguf_writer.add_nextn_predict_layers(self._nextn_predict_layers)

# Optional per-layer SwiGLU clamps. Pad with 0.0 for the MTP blocks
# (MTP heads use a dense MLP without clamping), so the array length
# matches block_count.
if (limits := self.hparams.get("swiglu_limits")) is not None:
limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
limits_f = [0.0 if v is None else float(v) for v in limits[: self._n_main_layers]]
limits_f += [0.0] * self._nextn_predict_layers
self.gguf_writer.add_swiglu_clamp_exp(limits_f)
if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self._n_main_layers]]
limits_shared_f += [0.0] * self._nextn_predict_layers
self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)

@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item

# Map router bias (expert selection bias) to a GGUF bias tensor
if name.endswith(".moe.router_bias"):
name += ".bias"

return super().filter_tensors((name, gen))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
# remove mtp layers
# Step-3.7 MTP heads live at model.layers.{N..N+K-1}.{eh_proj,enorm,hnorm,...}
# We keep them when nextn_predict_layers > 0 (mapped via NEXTN_* tensors)
# and drop them otherwise to preserve backward compatibility with text-only conversion.
if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
il = int(m.group(1))
n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
if il >= n_main:
if il >= self._n_main_layers and self._nextn_predict_layers == 0:
return
if name.endswith("norm.weight"):
data_torch += 1.0
Expand All @@ -203,11 +238,23 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if isinstance(rope_theta, list):
rope_theta = rope_theta[0]
base = float(rope_theta)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
dim = int(dim)

freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if (storage_dim := self.hparams.get("head_dim")) is None:
storage_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
storage_dim = int(storage_dim)

# Llama 3 factors apply only to the rotary dims used by full_attention layers
# (partial_rotary_factor * head_dim). Remaining slots are padded with 1.0 so
# sliding_attention layers remain unaffected. set_gguf_parameters already
# guarantees at least one full_attention layer.
layer_types = (self.hparams.get("layer_types") or [])[: self.block_count]
partial_rotary_factors = (self.hparams.get("partial_rotary_factors") or [])[: self.block_count]
full_attention_factor = next(
float(f) for lt, f in zip(layer_types, partial_rotary_factors) if lt == "full_attention"
)
rotary_dim = int(storage_dim * full_attention_factor)

freqs = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))

factor = float(rope_params.get("factor", 8.0))
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
Expand All @@ -228,4 +275,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))

# Pad to head_dim/2 with 1.0 so non-scaled layers remain neutral.
if len(rope_factors) < storage_dim // 2:
rope_factors.extend([1.0] * (storage_dim // 2 - len(rope_factors)))

yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
6 changes: 6 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3937,6 +3937,12 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
# NextN/MTP heads (Step-3.7 num_nextn_predict_layers > 0)
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.LLAMA_EMBED: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
2 changes: 2 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2273,10 +2273,12 @@ class TensorNameMap:

MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
"model.layers.{bid}.shared_head.head",
"model.layers.{bid}.transformer.shared_head.output", # step3.7
),

MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
"model.layers.{bid}.shared_head.norm",
"model.layers.{bid}.transformer.shared_head.norm", # step3.7
),
}

Expand Down
7 changes: 7 additions & 0 deletions src/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -1900,5 +1900,12 @@ struct llama_model_step35 : public llama_model_base {
graph(const llama_model & model, const llm_graph_params & params);
};

// NextN/MTP draft head used by --spec-type draft-mtp.
// Steps the AR draft loop one position ahead using the pre-norm hidden
// state from the trunk and the embedding of the previous draft token.
struct graph_mtp : public llm_graph_context {
graph_mtp(const llama_model & model, const llm_graph_params & params);
};

std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
Loading