Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
16 changes: 8 additions & 8 deletions — src/llama-context.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1279,14 +1279,6 @@ bool llama_context::ensure_sched_mtp() {
return false;
}

llama_memory_context_ptr mctx = memory->init_full();
if (!mctx) {
LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__);
sched_mtp.reset();
gf_res_prev_mtp.reset();
return false;
}

const uint32_t n_bb = model.mtp_assistant->hparams.n_embd_backbone;
auto data = std::make_shared<llama_ubatch::data_t>();
data->token.resize(1);
Expand Down Expand Up @@ -1321,6 +1313,14 @@ bool llama_context::ensure_sched_mtp() {
ub.output = data->output.data();
ub.data = data;

llama_memory_context_ptr mctx = kv_iswa->init_mtp(0, ub);
if (!mctx) {
LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__);
sched_mtp.reset();
gf_res_prev_mtp.reset();
return false;
}

const uint32_t save_n_outputs = n_outputs;
n_outputs = 1;

Expand Down
4 changes: 2 additions & 2 deletions — src/llama-graph.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -947,6 +947,7 @@ void llm_graph_result::set_params(const llm_graph_params & params) {

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
arch (params.arch),
gtype (params.gtype),
hparams (params.hparams),
cparams (params.cparams),
ubatch (params.ubatch),
Expand Down Expand Up @@ -1899,7 +1900,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
const bool v_trans = v->nb[1] > v->nb[2];

// split the batch into streams if needed
const auto n_stream = k->ne[3];
const auto n_stream = (gtype == LLM_GRAPH_TYPE_MTP) ? 1 : k->ne[3];

q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);

Expand Down Expand Up @@ -1930,7 +1931,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
if (v->type == GGML_TYPE_F32) {
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
}

cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
Expand Down
1 change: 1 addition & 0 deletions — src/llama-graph.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -742,6 +742,7 @@ using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_

struct llm_graph_context {
const llm_arch arch;
const llm_graph_type gtype;

const llama_hparams & hparams;
const llama_cparams & cparams;
Expand Down
2 changes: 1 addition & 1 deletion — src/models/gemma4-assistant.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -106,7 +106,7 @@ static void gemma4_mtp_build_one_step(
ggml_tensor * Qcur = gctx.build_lora_mm(mtp.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, 1);

Qcur = gctx.build_norm(Qcur, mtp.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Expand Down