Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions tools/mtmd/mtmd-helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk(
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data) {
GGML_ASSERT(n_batch > 0);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
Expand Down Expand Up @@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk(
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}

if (callback != nullptr) {
ret = callback(&batch_embd_view, user_data);
if (ret != 0) {
LOG_ERR("post-decode callback failed\n");
if (use_non_causal) {
llama_set_causal_attn(lctx, true);
}
return ret;
}
}

LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);

i_batch++;
Expand All @@ -327,16 +342,20 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
llama_seq_id seq_id,
int32_t n_batch,
bool logits_last,
llama_pos * new_n_past) {
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data) {
GGML_ASSERT(n_batch > 0);
int32_t ret;
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
auto chunk_type = mtmd_input_chunk_get_type(chunk);

if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);

// LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);

size_t i = 0;
while (i < n_tokens) { // split into batches
text_batch.n_tokens = 0; // clear the batch
Expand All @@ -360,9 +379,18 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
llama_batch_free(text_batch);
return ret;
}
if (callback != nullptr) {
ret = callback(&text_batch, user_data);
if (ret != 0) {
LOG_ERR("post-decode callback failed\n");
llama_batch_free(text_batch);
return ret;
}
}
*new_n_past += text_batch.n_tokens;
}

llama_batch_free(text_batch);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
int64_t t0 = ggml_time_ms();
Expand All @@ -372,24 +400,21 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
ret = mtmd_encode_chunk(ctx, chunk);
if (ret != 0) {
LOG_ERR("failed to encode %s slice\n", name);
llama_batch_free(text_batch);
return ret;
}

LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);

float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, callback, user_data);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
return ret;
}
} else {
GGML_ABORT("chunk type not supported");
}

llama_batch_free(text_batch);
return 0;
}

Expand All @@ -411,7 +436,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
auto chunk = mtmd_input_chunks_get(chunks, i);

int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past, nullptr, nullptr);
if (res != 0) {
LOG_ERR("failed to eval chunk %zu\n", i);
return res;
Expand Down
12 changes: 10 additions & 2 deletions tools/mtmd/mtmd-helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,22 @@ MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
bool logits_last,
llama_pos * new_n_past);

typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch * batch, void * user_data);

// works like mtmd_helper_eval_chunks(), but only for a single chunk
// this function is NOT thread-safe
// callback invoked after each successful llama_decode() call within the mtmd helpers.
// returns 0 on success, non-zero to signal failure (which will abort the eval).
MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
struct llama_context * lctx,
const mtmd_input_chunk * chunk,
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
bool logits_last,
llama_pos * new_n_past);
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data);

// helper function to decode an image whose embeddings have already been calculated
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
Expand All @@ -101,7 +107,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past);
llama_pos * new_n_past,
mtmd_helper_post_decode_callback callback,
void * user_data);

//
// video input helpers (requires ffmpeg/ffprobe installed on the system)
Expand Down
17 changes: 13 additions & 4 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,16 @@ bool server_tokens::validate(const struct llama_context * ctx) const {
return true;
}

// post-decode callback: decode the batch on the draft context
static int32_t server_tokens_process_chunk_callback(struct llama_batch * batch, void * user_data) {
llama_context * ctx_dft = static_cast<llama_context *>(user_data);
return llama_decode(ctx_dft, *batch);
}

int32_t server_tokens::process_chunk(
llama_context * ctx,
mtmd_context * mctx,
llama_context * ctx_tgt,
llama_context * ctx_dft,
size_t idx,
llama_pos pos,
int32_t seq_id,
Expand All @@ -542,16 +549,18 @@ int32_t server_tokens::process_chunk(
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx);
int32_t n_batch = llama_n_batch(ctx_tgt);
int64_t t0 = ggml_time_ms();
llama_pos new_n_past; // unused for now
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx_tgt,
chunk.get(),
pos,
seq_id,
n_batch,
true, // logits last
&new_n_past);
&new_n_past,
ctx_dft ? server_tokens_process_chunk_callback : nullptr,
ctx_dft);
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
Expand Down
6 changes: 4 additions & 2 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,12 @@ struct server_tokens {
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;

// encode and decode the image chunk
// encode and decode the image chunk.
// if ctx_dft is non-NULL, the batch is also decoded on the draft context.
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
llama_context * ctx_tgt,
llama_context * ctx_dft,
size_t idx,
llama_pos pos,
int32_t seq_id,
Expand Down
20 changes: 6 additions & 14 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#include "ggml-cpp.h"

// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
// TODO: tmp until the common_get_device_memory_data is refactored
#include "../../src/llama-ext.h"

#include <algorithm>
Expand Down Expand Up @@ -2984,27 +2984,19 @@ struct server_context_impl {

// check if we should process the image
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
// process the image; ctx_dft keeps the draft context in sync
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
int32_t res = input_tokens.process_chunk(
mctx, ctx_tgt, ctx_dft.get(),
slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id,
n_tokens_out);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
slot.release();
continue;
}

if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we skip this for simplicity
// maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above?
// [TAG_MTMD_DRAFT_PROCESSING]
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
GGML_ABORT("failed to process multi-modal data on draft context\n");
}
}

slot.n_prompt_tokens_processed += n_tokens_out;

// add the image chunk to cache
Expand Down
Loading