From b1297e7f4bc41534fbc085a7a597be5d763451e1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jun 2026 17:18:25 +0300 Subject: [PATCH 1/3] server : unify mtmd image processing with post-decode callback Add mtmd_helper_post_decode_callback to mtmd_helper_eval_chunk_single and mtmd_helper_decode_image_chunk, invoked after each successful llama_decode(). The server uses this callback to run common_speculative_process on the batch, keeping the draft context in sync with the target context during multi-modal prompt processing. This eliminates the need for a second process_chunk call on ctx_dft and removes the llama-ext.h workaround include. Assisted-by: pi:llama.cpp/Qwen3.6-27B --- tools/mtmd/mtmd-helper.cpp | 35 ++++++++++++++++++++++++++++----- tools/mtmd/mtmd-helper.h | 12 +++++++++-- tools/server/server-common.cpp | 8 ++++++-- tools/server/server-common.h | 8 ++++++-- tools/server/server-context.cpp | 29 ++++++++++++++------------- 5 files changed, 67 insertions(+), 25 deletions(-) diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 2d11a33804a3..6e0f45a7d69f 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk( llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past) { + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data) { GGML_ASSERT(n_batch > 0); auto chunk_type = mtmd_input_chunk_get_type(chunk); const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; @@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk( int32_t ret = llama_decode(lctx, batch_embd_view); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); - llama_set_causal_attn(lctx, true); // restore causal attn + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } return ret; } + if (callback != nullptr) { + ret = callback(&batch_embd_view, user_data); + if (ret != 0) { + LOG_ERR("post-decode callback failed\n"); + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } + return ret; + } + } + LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1); i_batch++; @@ -327,7 +342,9 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, llama_seq_id seq_id, int32_t n_batch, bool logits_last, - llama_pos * new_n_past) { + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data) { GGML_ASSERT(n_batch > 0); int32_t ret; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); @@ -360,6 +377,14 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, llama_batch_free(text_batch); return ret; } + if (callback != nullptr) { + ret = callback(&text_batch, user_data); + if (ret != 0) { + LOG_ERR("post-decode callback failed\n"); + llama_batch_free(text_batch); + return ret; + } + } *new_n_past += text_batch.n_tokens; } @@ -379,7 +404,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, callback, user_data); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); llama_batch_free(text_batch); @@ -411,7 +436,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, bool chunk_logits_last = (i == n_chunks - 1) && logits_last; auto chunk = mtmd_input_chunks_get(chunks, i); - int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past); + int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past, nullptr, nullptr); if (res != 0) { LOG_ERR("failed to eval chunk %zu\n", i); return res; diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index 164b7c6689d9..5c557e4d14f6 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -80,8 +80,12 @@ MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, bool logits_last, llama_pos * new_n_past); +typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch * batch, void * user_data); + // works like mtmd_helper_eval_chunks(), but only for a single chunk // this function is NOT thread-safe +// callback invoked after each successful llama_decode() call within the mtmd helpers. +// returns 0 on success, non-zero to signal failure (which will abort the eval). MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, struct llama_context * lctx, const mtmd_input_chunk * chunk, @@ -89,7 +93,9 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, llama_seq_id seq_id, int32_t n_batch, bool logits_last, - llama_pos * new_n_past); + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data); // helper function to decode an image whose embeddings have already been calculated // this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) @@ -101,7 +107,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past); + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data); // // video input helpers (requires ffmpeg/ffprobe installed on the system) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 9f3caac8f723..78d3cfb5affd 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -537,7 +537,9 @@ int32_t server_tokens::process_chunk( size_t idx, llama_pos pos, int32_t seq_id, - size_t & n_tokens_out) const { + size_t & n_tokens_out, + mtmd_helper_post_decode_callback callback, + void * user_data) const { const auto & chunk = find_chunk(idx); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; @@ -551,7 +553,9 @@ int32_t server_tokens::process_chunk( seq_id, n_batch, true, // logits last - &new_n_past); + &new_n_past, + callback, + user_data); SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); if (result != 0) { LOG_ERR("mtmd_helper_eval failed with status %d", result); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 249b97c2fadb..e3012d98f160 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -5,6 +5,7 @@ #include "llama.h" #include "chat.h" #include "mtmd.h" +#include "mtmd-helper.h" #define JSON_ASSERT GGML_ASSERT #include @@ -217,14 +218,17 @@ struct server_tokens { // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; - // encode and decode the image chunk + // encode and decode the image chunk. + // if callback is non-NULL, it is invoked after each successful llama_decode() call. int32_t process_chunk( llama_context * ctx, mtmd_context * mctx, size_t idx, llama_pos pos, int32_t seq_id, - size_t & n_tokens_out) const; + size_t & n_tokens_out, + mtmd_helper_post_decode_callback callback, + void * user_data) const; server_tokens clone() const; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 95a2a6ed9956..b2fe18cdfc85 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -17,7 +17,7 @@ #include "ggml-cpp.h" -// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING] +// TODO: tmp until the common_get_device_memory_data is refactored #include "../../src/llama-ext.h" #include @@ -2984,9 +2984,21 @@ struct server_context_impl { // check if we should process the image while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image + // process the image; the post-decode callback keeps the draft context in sync + static auto mtmd_post_decode_callback = [](struct llama_batch * batch, void * user_data) -> int32_t { + server_context_impl * impl = static_cast(user_data); + if (impl->spec && !common_speculative_process(impl->spec.get(), *batch)) { + return -1; + } + return 0; + }; + size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk( + ctx_tgt, mctx, + slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, + n_tokens_out, + mtmd_post_decode_callback, this); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2994,17 +3006,6 @@ struct server_context_impl { continue; } - if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) { - // TODO: in the future, figure out how to infuse target embeddings to the images - // for now, we skip this for simplicity - // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? - // [TAG_MTMD_DRAFT_PROCESSING] - res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); - if (res != 0) { - GGML_ABORT("failed to process multi-modal data on draft context\n"); - } - } - slot.n_prompt_tokens_processed += n_tokens_out; // add the image chunk to cache From 30d49c022e2dfe36d4bbfb62e11c2825c1c11c4c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jun 2026 17:28:24 +0300 Subject: [PATCH 2/3] mtmd : narrow-down batch scope --- tools/mtmd/mtmd-helper.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 6e0f45a7d69f..91af1448f0c5 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -347,13 +347,15 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, void * user_data) { GGML_ASSERT(n_batch > 0); int32_t ret; - llama_batch text_batch = llama_batch_init(n_batch, 0, 1); auto chunk_type = mtmd_input_chunk_get_type(chunk); if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens); + llama_batch text_batch = llama_batch_init(n_batch, 0, 1); + size_t i = 0; while (i < n_tokens) { // split into batches text_batch.n_tokens = 0; // clear the batch @@ -388,6 +390,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, *new_n_past += text_batch.n_tokens; } + llama_batch_free(text_batch); } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; int64_t t0 = ggml_time_ms(); @@ -397,7 +400,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, ret = mtmd_encode_chunk(ctx, chunk); if (ret != 0) { LOG_ERR("failed to encode %s slice\n", name); - llama_batch_free(text_batch); return ret; } @@ -407,14 +409,12 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, callback, user_data); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); - llama_batch_free(text_batch); return ret; } } else { GGML_ABORT("chunk type not supported"); } - llama_batch_free(text_batch); return 0; } From f4f142a293e135549f1bed299bf2972ef9315c95 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jun 2026 17:42:28 +0300 Subject: [PATCH 3/3] server : move mtmd post-decode callback into process_chunk Move the post-decode callback construction inside server_tokens::process_chunk() so that server-common.h no longer depends on mtmd-helper.h. The caller now passes ctx_tgt and ctx_dft directly. Assisted-by: pi:llama.cpp/Qwen3.6-27B --- tools/server/server-common.cpp | 21 +++++++++++++-------- tools/server/server-common.h | 10 ++++------ tools/server/server-context.cpp | 15 +++------------ 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 78d3cfb5affd..81641dcb041d 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -531,31 +531,36 @@ bool server_tokens::validate(const struct llama_context * ctx) const { return true; } +// post-decode callback: decode the batch on the draft context +static int32_t server_tokens_process_chunk_callback(struct llama_batch * batch, void * user_data) { + llama_context * ctx_dft = static_cast(user_data); + return llama_decode(ctx_dft, *batch); +} + int32_t server_tokens::process_chunk( - llama_context * ctx, mtmd_context * mctx, + llama_context * ctx_tgt, + llama_context * ctx_dft, size_t idx, llama_pos pos, int32_t seq_id, - size_t & n_tokens_out, - mtmd_helper_post_decode_callback callback, - void * user_data) const { + size_t & n_tokens_out) const { const auto & chunk = find_chunk(idx); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); + int32_t n_batch = llama_n_batch(ctx_tgt); int64_t t0 = ggml_time_ms(); llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx_tgt, chunk.get(), pos, seq_id, n_batch, true, // logits last &new_n_past, - callback, - user_data); + ctx_dft ? server_tokens_process_chunk_callback : nullptr, + ctx_dft); SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); if (result != 0) { LOG_ERR("mtmd_helper_eval failed with status %d", result); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index e3012d98f160..bc4e0b5a3be9 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -5,7 +5,6 @@ #include "llama.h" #include "chat.h" #include "mtmd.h" -#include "mtmd-helper.h" #define JSON_ASSERT GGML_ASSERT #include @@ -219,16 +218,15 @@ struct server_tokens { bool validate(const struct llama_context * ctx) const; // encode and decode the image chunk. - // if callback is non-NULL, it is invoked after each successful llama_decode() call. + // if ctx_dft is non-NULL, the batch is also decoded on the draft context. int32_t process_chunk( - llama_context * ctx, mtmd_context * mctx, + llama_context * ctx_tgt, + llama_context * ctx_dft, size_t idx, llama_pos pos, int32_t seq_id, - size_t & n_tokens_out, - mtmd_helper_post_decode_callback callback, - void * user_data) const; + size_t & n_tokens_out) const; server_tokens clone() const; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b2fe18cdfc85..2238b950c47c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2984,21 +2984,12 @@ struct server_context_impl { // check if we should process the image while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image; the post-decode callback keeps the draft context in sync - static auto mtmd_post_decode_callback = [](struct llama_batch * batch, void * user_data) -> int32_t { - server_context_impl * impl = static_cast(user_data); - if (impl->spec && !common_speculative_process(impl->spec.get(), *batch)) { - return -1; - } - return 0; - }; - + // process the image; ctx_dft keeps the draft context in sync size_t n_tokens_out = 0; int32_t res = input_tokens.process_chunk( - ctx_tgt, mctx, + mctx, ctx_tgt, ctx_dft.get(), slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, - n_tokens_out, - mtmd_post_decode_callback, this); + n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER);