Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions tools/mtmd/mtmd-helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk(
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data) {
GGML_ASSERT(n_batch > 0);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
Expand Down Expand Up @@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk(
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}

if (callback != nullptr) {
ret = callback(batch_embd_view, user_data);
if (ret != 0) {
LOG_ERR("post-decode callback failed\n");
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}
}

LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);

i_batch++;
Expand Down Expand Up @@ -379,7 +394,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);

float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
Expand Down
6 changes: 5 additions & 1 deletion tools/mtmd/mtmd-helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
bool logits_last,
llama_pos * new_n_past);

typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch batch, void * user_data);

// helper function to decode an image whose embeddings have already been calculated
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
Expand All @@ -101,7 +103,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past);
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data);

//
// video input helpers (requires ffmpeg/ffprobe installed on the system)
Expand Down
31 changes: 0 additions & 31 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,37 +539,6 @@ bool server_tokens::validate(const struct llama_context * ctx) const {
return true;
}

int32_t server_tokens::process_chunk(
llama_context * ctx,
mtmd_context * mctx,
size_t idx,
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const {
const auto & chunk = find_chunk(idx);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms();
llama_pos new_n_past; // unused for now
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
chunk.get(),
pos,
seq_id,
n_batch,
true, // logits last
&new_n_past);
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
n_tokens_out = 0;
return result;
}
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
return 0;
}

server_tokens server_tokens::clone() const {
server_tokens res;
res.has_mtmd = has_mtmd;
Expand Down
9 changes: 0 additions & 9 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,6 @@ struct server_tokens {
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;

// encode and decode the image chunk
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
size_t idx,
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;

server_tokens clone() const;
};

Expand Down
104 changes: 27 additions & 77 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
#include "mtmd.h"
#include "mtmd-helper.h"

#include "ggml-cpp.h"

// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
#include "../../src/llama-ext.h"

#include <algorithm>
#include <cstddef>
#include <cinttypes>
Expand Down Expand Up @@ -81,7 +76,6 @@ struct server_slot {
// multimodal
mtmd_context * mctx = nullptr;
mtmd::batch_ptr mbatch = nullptr;
std::array<llama_context *, 2> mtgt = {nullptr, nullptr}; // [0] for main context, [1] for optional draft context

// speculative decoding
common_speculative * spec;
Expand Down Expand Up @@ -244,15 +238,6 @@ struct server_slot {

// clear multimodal state
mbatch.reset();
mtgt[0] = ctx_tgt;
mtgt[1] = nullptr;
if (ctx_dft && llama_get_ctx_other(ctx_dft) != ctx_tgt) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we re-decode the same chunk in both ctx_tgt and ctx_dft
// maybe we simply need to call `common_speculative_process()` ?
// [TAG_MTMD_DRAFT_PROCESSING]
mtgt[1] = ctx_dft;
}
}

void init_sampler() const {
Expand Down Expand Up @@ -598,32 +583,38 @@ struct server_slot {
int process_mtmd_chunk(size_t idx, size_t & n_tokens_out) {
GGML_ASSERT(mctx);
const auto & input_tokens = task->tokens;
auto & chunk = input_tokens.find_chunk(idx);
const auto & chunk = input_tokens.find_chunk(idx);
int32_t res = 0;

auto try_decode = [&]() -> int32_t {
if (mbatch) {
float * embd = mtmd_batch_get_output_embd(mbatch.get(), chunk.get());
if (embd) {
for (auto * lctx : mtgt) {
if (lctx == nullptr) {
continue;
}
llama_pos new_n_past; // unused for now
res = mtmd_helper_decode_image_chunk(
mctx,
lctx,
chunk.get(),
embd,
prompt.tokens.pos_next(),
id,
llama_n_batch(lctx),
&new_n_past
);
if (res != 0) {
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
return -1;
void * cb_data = spec;
static auto cb = [](llama_batch batch, void * user_data) {
common_speculative * spec = static_cast<common_speculative *>(user_data);
if (!common_speculative_process(spec, batch)) {
return 1;
}
return 0;
};

llama_pos new_n_past; // unused for now
res = mtmd_helper_decode_image_chunk(
mctx,
ctx_tgt,
chunk.get(),
embd,
prompt.tokens.pos_next(),
id,
llama_n_batch(ctx_tgt),
&new_n_past,
cb,
cb_data
);
if (res != 0) {
SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res);
return -1;
}
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
return 0; // success
Expand All @@ -636,7 +627,8 @@ struct server_slot {
res = try_decode();
if (res == 0) {
return 0;
} else if (res < 0) {
}
if (res < 0) {
// fatal error
return res;
}
Expand Down Expand Up @@ -3350,48 +3342,6 @@ struct server_context_impl {
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
// for now, always re-evaluate for simplicity
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
//
// | spec type | need re-eval |
// | --- | --- |
// | draft model | no | because the draft model does not use embeddings from the target
// | MTP (std) | yes |
// | MTP Gemma4 | no | because the KV cache is shared
// | Eagle3 | yes |
// | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982
//
// note: this logic is now moved in `common_speculative_process()`
// keeping the sketch here until for a bit, until the logic is finalized
//
//if (ctx_dft) {
// // TODO: update as needed for MTP, Eagle3, etc.
// const bool need_tgt_embd = false;

// if (need_tgt_embd) {
// llama_synchronize(ctx_tgt);
// }

// // the logic here varies depending on the speculative decoding method
// // - some draft contexts require embeddings from the target context, others don't
// // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings
// // TODO: extract this in a function ?
// {
// // TODO: hook the embeddings from the last target batch here
// if (llama_model_has_encoder(model_dft.get())) {
// //llama_encode(ctx_dft, ...);

// GGML_ABORT("not implemented yet\n");
// }

// const int ret = llama_decode(ctx_dft.get(), batch_view);

// if (ret != 0) {
// SRV_ERR("failed to decode draft batch, ret = %d\n", ret);

// // TODO: handle error
// break;
// }
// }
//}
if (!common_speculative_process(spec.get(), batch_view)) {
SRV_ERR("%s", "failed to process speculative batch\n");

Expand Down