Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
47106bd
fork: drop pre-norm Qwen MTP (#149) ahead of upstream MTP lineage
TheTom Jun 8, 2026
19c7616
llama: avoid copying logits during prompt decode in MTP (#23198)
am17an May 17, 2026
672261e
llama : MTP clean-up (#23269)
ggerganov May 19, 2026
41aef76
model : clarify MTP layer comment in qwen35.cpp [no ci] (#23338)
danbev May 19, 2026
1c9f2df
common/speculative : fix nullptr crash in get_devices_str (#23386)
ggerganov May 20, 2026
697d19c
Move to backend sampling for MTP draft path (#23287)
gaugarg-nv May 20, 2026
57cffb2
llama-graph: fix null-buffer crash in llm_graph_input_attn_kv_iswa fo…
ssfdre38 May 21, 2026
676b3ed
mtp: use inp_out_ids for skipping logit computation (#23433)
am17an May 21, 2026
46cc11a
model : add NVFP4 MTP scale tensors (#23563)
michaelw9999 May 23, 2026
ab11a71
llama: add llm_graph_input_mtp (#23643)
am17an May 29, 2026
337e04c
speculative : fix n_outputs_max and remove draft-simple auto-enable (…
ggerganov Jun 1, 2026
7560617
tests : add support for qwen3 SSM archs (#24031)
ggerganov Jun 3, 2026
0c809f0
qwen35: use post-norm hidden state for MTP (#24025)
am17an Jun 3, 2026
d55c844
hparams : refactor `hparams.n_layer` (#24060)
ggerganov Jun 5, 2026
6d9a4a8
spec : fix vocab compatibility check (#24256)
CISC Jun 7, 2026
d1e70aa
llama : add Gemma4 MTP (#23398)
am17an Jun 7, 2026
2f756e6
fork: reconcile MTP lineage with TurboQuant+ KV cache
TheTom Jun 8, 2026
de389e0
vulkan: fix f16vec4 casts in cm1 FA quantized K/V decode paths
TheTom Jun 9, 2026
469c9c4
mtp: fix standalone all-nextn draft KV cache (gemma4-assistant)
TheTom Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
const bool skip = (arg == "--spec-type");

if (!skip) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
Expand Down Expand Up @@ -903,7 +907,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
const bool skip = (arg == "--spec-type");

if (!skip) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
}
auto opt = *arg_to_options[arg];
std::string val;
Expand Down Expand Up @@ -1037,11 +1045,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
// we define here to make sure it's included in llama-gen-docs
if (ex == LLAMA_EXAMPLE_COMPLETION) {
params.use_jinja = false; // disable jinja by default

} else if (ex == LLAMA_EXAMPLE_MTMD) {
params.use_jinja = false; // disable jinja by default
params.sampling.temp = 0.2; // lower temp by default for better quality

} else if (ex == LLAMA_EXAMPLE_SERVER) {
params.n_parallel = -1; // auto by default
}
Expand All @@ -1062,7 +1068,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
sampler_type_names.pop_back(); // remove last semicolon
}


/**
* filter options by example
* rules:
Expand All @@ -1076,7 +1081,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
};


add_opt(common_arg(
{"-h", "--help", "--usage"},
"print usage and exit",
Expand Down Expand Up @@ -3606,6 +3610,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-backend-sampling"},
{"--no-spec-draft-backend-sampling"},
string_format("offload draft sampling to the backend (default: %s)",
params.speculative.draft.backend_sampling ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.speculative.draft.backend_sampling = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
Expand Down Expand Up @@ -4141,6 +4154,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;

// TODO: not sure if this is a good config - explore more settings and potentially enable it
//params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
//params.speculative.ngram_map_k4v.size_n = 8;
//params.speculative.ngram_map_k4v.size_m = 24;
//params.speculative.ngram_map_k4v.min_hits = 2;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));

Expand Down
24 changes: 1 addition & 23 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1258,29 +1258,6 @@ common_init_result::common_init_result(common_params & params) :
cparams.n_samplers = pimpl->samplers_seq_config.size();
}

// [TAG_RS_STATE_ROLLBACK_SUPPORT]
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
// currently this is not supported. so we disable the partial rollback
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
auto & types = params.speculative.types;

for (int i = 0; i < (int) types.size(); i++) {
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
continue;
}
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
continue;
}

cparams.n_rs_seq = 0;

LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
common_speculative_type_to_str(types[i]).c_str());

break;
}
}

llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
Expand Down Expand Up @@ -1562,6 +1539,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &

cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_outputs_max = params.n_outputs_max;
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
Expand Down
11 changes: 7 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,13 @@ struct common_params_model {

// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding

float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.0f; // minimum speculative decoding probability (greedy)

bool backend_sampling = true; // offload draft sampling to the backend (default: on)

common_params_model mparams;

Expand Down Expand Up @@ -428,6 +430,7 @@ struct common_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_outputs_max = 0; // max outputs supported by the context (0 = derive)
int32_t n_sequences = 1; // number of sequences to decode
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
Expand Down
2 changes: 1 addition & 1 deletion common/ngram-map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map,
draft.push_back(inp[match_pos + n + i]);
}

LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
key_offset, slot_max,
curr_key.key_num, draft.size());

Expand Down
Loading
Loading