diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 822d97b829dc..2228e260f994 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1279,14 +1279,6 @@ bool llama_context::ensure_sched_mtp() { return false; } - llama_memory_context_ptr mctx = memory->init_full(); - if (!mctx) { - LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__); - sched_mtp.reset(); - gf_res_prev_mtp.reset(); - return false; - } - const uint32_t n_bb = model.mtp_assistant->hparams.n_embd_backbone; auto data = std::make_shared(); data->token.resize(1); @@ -1321,6 +1313,14 @@ bool llama_context::ensure_sched_mtp() { ub.output = data->output.data(); ub.data = data; + llama_memory_context_ptr mctx = kv_iswa->init_mtp(0, ub); + if (!mctx) { + LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__); + sched_mtp.reset(); + gf_res_prev_mtp.reset(); + return false; + } + const uint32_t save_n_outputs = n_outputs; n_outputs = 1; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e8db32545c75..1d985f3fce87 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -947,6 +947,7 @@ void llm_graph_result::set_params(const llm_graph_params & params) { llm_graph_context::llm_graph_context(const llm_graph_params & params) : arch (params.arch), + gtype (params.gtype), hparams (params.hparams), cparams (params.cparams), ubatch (params.ubatch), @@ -1899,7 +1900,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( const bool v_trans = v->nb[1] > v->nb[2]; // split the batch into streams if needed - const auto n_stream = k->ne[3]; + const auto n_stream = (gtype == LLM_GRAPH_TYPE_MTP) ? 1 : k->ne[3]; q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); @@ -1930,7 +1931,6 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (v->type == GGML_TYPE_F32) { v = ggml_cast(ctx0, v, GGML_TYPE_F16); } - cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); cb(cur, LLAMA_TENSOR_NAME_FATTN, il); diff --git a/src/llama-graph.h b/src/llama-graph.h index a77b2033a4de..dd007c8d3570 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -742,6 +742,7 @@ using llm_graph_get_rows_fn = std::function