From ca97dde5c96fbdc2743c2045b570987fc05454c6 Mon Sep 17 00:00:00 2001 From: tdeburca Date: Sun, 31 May 2026 23:25:36 +0100 Subject: [PATCH 1/2] Add speculative draft Prometheus metrics --- common/speculative.cpp | 20 +++++++++++++ common/speculative.h | 16 ++++++++++ tools/server/README.md | 34 +++++++++++++++++++++ tools/server/server-context.cpp | 53 +++++++++++++++++++++++++++++++++ tools/server/server-task.h | 3 ++ 5 files changed, 126 insertions(+) diff --git a/common/speculative.cpp b/common/speculative.cpp index e786cd63ab24..1aae7aa4dc47 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -2099,3 +2099,23 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf.c_str()); } } + +std::vector common_speculative_get_stats(const common_speculative * spec) { + std::vector result; + if (spec == nullptr) { + return result; + } + + result.reserve(spec->impls.size()); + for (const auto & impl : spec->impls) { + result.push_back({ + common_speculative_type_to_str(impl->type), + (uint64_t) impl->n_gen_drafts, + (uint64_t) impl->n_acc_drafts, + (uint64_t) impl->n_gen_tokens, + (uint64_t) impl->n_acc_tokens, + }); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h index 02fba8877f39..00fca4c430e1 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -3,8 +3,21 @@ #include "llama.h" #include "common.h" +#include +#include +#include + struct common_speculative; +struct common_speculative_stats { + std::string spec_type; + + uint64_t n_gen_drafts = 0; + uint64_t n_acc_drafts = 0; + uint64_t n_gen_tokens = 0; + uint64_t n_acc_tokens = 0; +}; + // comma separated list of all types std::string common_speculative_type_name_str(); @@ -67,3 +80,6 @@ void common_speculative_cancel(common_speculative * spec); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); + +// snapshot statistics about the speculative decoding +std::vector common_speculative_get_stats(const common_speculative * spec); diff --git a/tools/server/README.md b/tools/server/README.md index b924225a0fd0..082ce28523e2 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1029,6 +1029,40 @@ Available metrics: - `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_deferred`: Number of requests deferred. - `llamacpp:n_tokens_max`: High watermark of the context size observed. +- `llamacpp:speculative_drafts_generated_total{spec_type="..."}`: Number of speculative draft batches generated. +- `llamacpp:speculative_drafts_accepted_total{spec_type="..."}`: Number of speculative draft batches accepted at least partially. +- `llamacpp:speculative_draft_tokens_generated_total{spec_type="..."}`: Number of speculative draft tokens generated. +- `llamacpp:speculative_draft_tokens_accepted_total{spec_type="..."}`: Number of speculative draft tokens accepted by the target model. + +The speculative counters use the same source counters as the server's `statistics ` log line and are aggregated across slots. The `spec_type` label is the speculative implementation name, such as `mtp`, `nextn`, `draft`, `eagle3`, or an n-gram type. A server with no configured speculative implementation exports the metric metadata but no speculative series. + +Example Grafana/Prometheus expressions: + +```promql +rate(llamacpp:speculative_drafts_accepted_total[5m]) +/ +rate(llamacpp:speculative_drafts_generated_total[5m]) +``` + +```promql +rate(llamacpp:speculative_draft_tokens_accepted_total[5m]) +/ +rate(llamacpp:speculative_draft_tokens_generated_total[5m]) +``` + +To graph all speculative modes together, aggregate before dividing: + +```promql +sum(rate(llamacpp:speculative_drafts_accepted_total[5m])) +/ +sum(rate(llamacpp:speculative_drafts_generated_total[5m])) +``` + +Verify locally with: + +```bash +curl -s http://localhost:8080/metrics | rg 'speculative|draft' +``` ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 65703c056106..b330275331d5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1919,6 +1919,7 @@ struct server_context_impl { int n_idle_slots = 0; int n_processing_slots = 0; + std::map speculative_stats_by_type; for (server_slot & slot : slots) { json slot_data = slot.to_json(slots_debug == 0); @@ -1929,6 +1930,15 @@ struct server_context_impl { n_idle_slots++; } + for (const auto & stats : common_speculative_get_stats(slot.spec)) { + auto & agg = speculative_stats_by_type[stats.spec_type]; + agg.spec_type = stats.spec_type; + agg.n_gen_drafts += stats.n_gen_drafts; + agg.n_acc_drafts += stats.n_acc_drafts; + agg.n_gen_tokens += stats.n_gen_tokens; + agg.n_acc_tokens += stats.n_acc_tokens; + } + slots_data.push_back(slot_data); } SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); @@ -1955,6 +1965,9 @@ struct server_context_impl { res->n_decode_total = metrics.n_decode_total; res->n_busy_slots_total = metrics.n_busy_slots_total; + for (const auto & el : speculative_stats_by_type) { + res->speculative_stats.push_back(el.second); + } if (task.metrics_reset_bucket) { metrics.reset_bucket(); @@ -3645,6 +3658,46 @@ void server_routes::init_routes() { } } + struct speculative_metric_def { + const char * name; + const char * help; + uint64_t common_speculative_stats::* value; + }; + + static const speculative_metric_def speculative_metrics_def[] = { + { + "speculative_drafts_generated_total", + "Number of speculative draft batches generated.", + &common_speculative_stats::n_gen_drafts, + }, + { + "speculative_drafts_accepted_total", + "Number of speculative draft batches accepted at least partially.", + &common_speculative_stats::n_acc_drafts, + }, + { + "speculative_draft_tokens_generated_total", + "Number of speculative draft tokens generated.", + &common_speculative_stats::n_gen_tokens, + }, + { + "speculative_draft_tokens_accepted_total", + "Number of speculative draft tokens accepted by the target model.", + &common_speculative_stats::n_acc_tokens, + }, + }; + + for (const auto & metric_def : speculative_metrics_def) { + prometheus << "# HELP llamacpp:" << metric_def.name << " " << metric_def.help << "\n" + << "# TYPE llamacpp:" << metric_def.name << " counter\n"; + + for (const auto & stats : res_task->speculative_stats) { + prometheus << "llamacpp:" << metric_def.name + << "{spec_type=\"" << stats.spec_type << "\"} " + << stats.*(metric_def.value) << "\n"; + } + } + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); res->content_type = "text/plain; version=0.0.4"; res->status = 200; diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 95f39207b18c..1aecb0a7aaaa 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -2,6 +2,7 @@ #include "common.h" #include "llama.h" +#include "speculative.h" #include #include @@ -526,6 +527,8 @@ struct server_task_result_metrics : server_task_result { uint64_t n_decode_total = 0; uint64_t n_busy_slots_total = 0; + std::vector speculative_stats; + // while we can also use std::vector this requires copying the slot object which can be quite messy // therefore, we use json to temporarily store the slot.to_json() result json slots_data = json::array(); From b7b5a3bf920a5c52aa62bf75c5868f4957ea28f6 Mon Sep 17 00:00:00 2001 From: tdeburca Date: Mon, 1 Jun 2026 18:30:33 +0100 Subject: [PATCH 2/2] Add a reserve-only Gemma4 MTP context --- docs/gemma4-mtp-multislot-crash-worklog.md | 167 +++++++++++++++++++++ src/llama-context.cpp | 21 +-- src/llama-kv-cache-iswa.cpp | 32 +++- src/llama-kv-cache-iswa.h | 10 ++ src/llama-kv-cache.h | 4 +- 5 files changed, 221 insertions(+), 13 deletions(-) create mode 100644 docs/gemma4-mtp-multislot-crash-worklog.md diff --git a/docs/gemma4-mtp-multislot-crash-worklog.md b/docs/gemma4-mtp-multislot-crash-worklog.md new file mode 100644 index 000000000000..63c54c7fb2a2 --- /dev/null +++ b/docs/gemma4-mtp-multislot-crash-worklog.md @@ -0,0 +1,167 @@ +# Gemma4 A4B MTP Multislot Crash Worklog + +## Scope + +Goal: fix the deterministic Gemma4 A4B MTP crash when `llama-server` runs with multiple slots. + +## Minimal Reproducer + +Patched branch was reproduced first with a smaller setup to reduce iteration time: + +```bash +CTX=4096 \ +PARALLEL=2 \ +BATCH=128 \ +UBATCH=64 \ +SPLIT_MODE=layer \ +KV_K=turbo4 \ +KV_V=turbo4 \ +REASONING_BUDGET=1024 \ +ENABLE_MTP=1 \ +PORT=8084 \ +NO_WARMUP=1 \ +~/scripts/local-opencode-llama/scripts/run-gemma4-26b-a4b-mtp.sh +``` + +Request: + +```bash +curl -sS -H 'Content-Type: application/json' \ + http://127.0.0.1:8084/v1/messages \ + -d '{ + "model": "gemma4-26b-a4b-mtp", + "max_tokens": 16, + "messages": [ + {"role": "user", "content": "hi"} + ] + }' +``` + +Before the fix this returned `curl: (52) Empty reply from server`. + +## Failing Backtrace + +The failure reproduced deterministically on the first `/v1/messages` request: + +```text +slot get_availabl: id 1 | task -1 | selected slot by LRU, t_last = -1 +... +/home/tdeburca/git/model-learning/atomic-llama-cpp-turboquant/ggml/src/ggml.c:3665: GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2) failed +... +#6 ggml_reshape_3d +#7 llm_build_gemma4_mtp::llm_build_gemma4_mtp(...) +#8 llama_model::build_graph(...) +#9 llama_context::ensure_sched_mtp() +#10 llama_context::decode_mtp_async(...) +#11 common_speculative_state_mtp::draft(...) +#12 common_speculative_draft(...) +#13 server_context_impl::update_slots() +``` + +## Root Cause + +The crash was in the MTP scheduler reserve path, not in the real worker compute path. + +`llama_context::ensure_sched_mtp()` reserved the MTP graph with: + +- a single-token `ubatch` +- but a **full KV context** from `memory->init_full()` + +For `PARALLEL=2`, the full KV context builds dummy `slot_info` spanning **all streams**: + +- `src/llama-kv-cache.cpp`, `llama_kv_cache_context(llama_kv_cache * kv)` +- `s0 = 0` +- `s1 = n_stream - 1` + +That changes the reserve graph topology from the single-stream shape expected by Gemma4 MTP into a multi-stream attention shape. The Gemma4 MTP builder later performs single-stream reshapes in `src/models/gemma4-assistant.cpp`, which then trip the `ggml_reshape_3d()` element-count assert during reserve. + +This is why: + +- `PARALLEL=1 ENABLE_MTP=1` worked +- `PARALLEL=2 ENABLE_MTP=0` worked +- `PARALLEL=2 ENABLE_MTP=1` crashed + +The issue is fundamentally the reserve context shape, not prompt size or memory pressure. + +## Patch + +Initial fix: + +- stop using `memory->init_full()` for MTP reserve +- reserve against a single-sequence / single-stream MTP topology instead + +Follow-up hardening: + +- added a dedicated reserve-only API on `llama_kv_cache_iswa`: + - `init_mtp_reserve(llama_ubatch ubatch)` +- updated `llama_context::ensure_sched_mtp()` to use `kv_iswa->init_mtp_reserve(ub)` +- kept real decode on `kv_iswa->init_mtp(seq_id, ub)` + +The reserve helper constructs a shape-only MTP memory context: + +- one stream +- one index +- one ubatch + +It does not depend on user `seq_id 0` existing or having KV state, so reserve no +longer borrows real decode semantics just to obtain the correct graph shape. + +## Files Changed + +- `src/llama-context.cpp` +- `src/llama-kv-cache-iswa.h` +- `src/llama-kv-cache-iswa.cpp` +- `src/llama-kv-cache.h` +- `docs/gemma4-mtp-multislot-crash-worklog.md` + +## Validation + +Build: + +```bash +cmake --build build-hip-rocwmma --target llama-server -j "$(nproc)" +``` + +Validated combinations: + +1. `PARALLEL=1 ENABLE_MTP=1 KV_K=turbo4 KV_V=turbo4` + - `/v1/messages` returned `200` + - generation completed + +2. `PARALLEL=2 ENABLE_MTP=0 KV_K=turbo4 KV_V=turbo4` + - `/v1/messages` returned `200` + - generation completed + +3. `PARALLEL=2 ENABLE_MTP=1 KV_K=turbo4 KV_V=turbo4` + - tiny `/v1/messages` request returned `200` + - normal Claude-style `/v1/messages` request with `system` + `messages` returned `200` + - generation completed + - no crash + +4. Spot check: `PARALLEL=2 ENABLE_MTP=1 KV_K=f16 KV_V=f16` + - `/v1/messages` returned `200` + - generation completed + +Metrics check with MTP enabled: + +```bash +curl -sS http://127.0.0.1:8084/metrics | rg 'speculative|draft' +``` + +Observed Prometheus speculative draft metrics including: + +- `llamacpp:speculative_drafts_generated_total{spec_type="mtp"}` +- `llamacpp:speculative_drafts_accepted_total{spec_type="mtp"}` +- `llamacpp:speculative_draft_tokens_generated_total{spec_type="mtp"}` +- `llamacpp:speculative_draft_tokens_accepted_total{spec_type="mtp"}` + +## Result + +`PARALLEL=2 ENABLE_MTP=1` now works and generates normally. + +Reserve-time MTP setup no longer uses `seq_id 0` as a placeholder decode path. +It now uses a dedicated shape-only memory context. + +## Remaining Risks + +- I did not change non-Gemma speculative paths. diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 822d97b829dc..6259c543ec6c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1253,7 +1253,6 @@ bool llama_context::ensure_sched_mtp() { return false; } - const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const size_t max_nodes = this->graph_max_nodes(n_tokens); @@ -1279,14 +1278,6 @@ bool llama_context::ensure_sched_mtp() { return false; } - llama_memory_context_ptr mctx = memory->init_full(); - if (!mctx) { - LLAMA_LOG_ERROR("%s: failed to init memory context for MTP reserve\n", __func__); - sched_mtp.reset(); - gf_res_prev_mtp.reset(); - return false; - } - const uint32_t n_bb = model.mtp_assistant->hparams.n_embd_backbone; auto data = std::make_shared(); data->token.resize(1); @@ -1321,6 +1312,18 @@ bool llama_context::ensure_sched_mtp() { ub.output = data->output.data(); ub.data = data; + // Reserve the MTP graph against a dedicated shape-only KV view. Using + // init_full() here would build dummy slot_info spanning every server stream; + // with n_seq_max > 1 that changes the MTP attention output topology during + // reserve and Gemma4's single-stream reshape path later asserts. + llama_memory_context_ptr mctx = kv_iswa->init_mtp_reserve(ub); + if (!mctx) { + LLAMA_LOG_ERROR("%s: failed to init MTP memory context for reserve\n", __func__); + sched_mtp.reset(); + gf_res_prev_mtp.reset(); + return false; + } + const uint32_t save_n_outputs = n_outputs; n_outputs = 1; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index b69897173727..53a95e91e802 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -217,12 +217,15 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, return std::make_unique(this, lctx, optimize); } -llama_memory_context_ptr llama_kv_cache_iswa::init_mtp(llama_seq_id seq_id, llama_ubatch ubatch) { +llama_memory_context_ptr llama_kv_cache_iswa::init_mtp_with_slot_info( + llama_kv_cache::slot_info sinfo_base, + llama_kv_cache::slot_info sinfo_swa, + llama_ubatch ubatch) { llama_kv_cache::slot_info_vec_t sinfos_base; llama_kv_cache::slot_info_vec_t sinfos_swa; - sinfos_base.push_back(kv_base->mtp_slot_info(seq_id)); - sinfos_swa.push_back(kv_swa->mtp_slot_info(seq_id)); + sinfos_base.push_back(std::move(sinfo_base)); + sinfos_swa.push_back(std::move(sinfo_swa)); std::vector ubatches; ubatches.push_back(std::move(ubatch)); @@ -231,6 +234,29 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_mtp(llama_seq_id seq_id, llam this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } +llama_memory_context_ptr llama_kv_cache_iswa::init_mtp(llama_seq_id seq_id, llama_ubatch ubatch) { + return init_mtp_with_slot_info( + kv_base->mtp_slot_info(seq_id), + kv_swa->mtp_slot_info(seq_id), + std::move(ubatch)); +} + +llama_memory_context_ptr llama_kv_cache_iswa::init_mtp_reserve(llama_ubatch ubatch) { + // Shape-only reserve context: one stream, one index, one ubatch. We intentionally + // avoid seq_id-dependent helpers here so the reserve path cannot accidentally claim + // to represent a real user sequence or read its KV placement. + llama_kv_cache::slot_info sinfo_base; + sinfo_base.s0 = 0; + sinfo_base.s1 = 0; + sinfo_base.strm = { 0 }; + sinfo_base.idxs.resize(1); + sinfo_base.idxs[0] = { 0 }; + + llama_kv_cache::slot_info sinfo_swa = sinfo_base; + + return init_mtp_with_slot_info(std::move(sinfo_base), std::move(sinfo_swa), std::move(ubatch)); +} + bool llama_kv_cache_iswa::get_can_shift() const { return kv_base->get_can_shift() && kv_swa->get_can_shift(); diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index 7803aaa29f82..7f6d565a9a79 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -81,7 +81,17 @@ class llama_kv_cache_iswa : public llama_memory_i { // for the same seq_id does not trigger eviction during an in-flight MTP request. llama_memory_context_ptr init_mtp(llama_seq_id seq_id, llama_ubatch ubatch); + // Reserve-only MTP context used to size/allocate the single-token MTP graph. This + // is shape-only: it does not correspond to a real sequence and must never be used + // for actual MTP decode or for reading user-visible KV state. + llama_memory_context_ptr init_mtp_reserve(llama_ubatch ubatch); + private: + llama_memory_context_ptr init_mtp_with_slot_info( + llama_kv_cache::slot_info sinfo_base, + llama_kv_cache::slot_info sinfo_swa, + llama_ubatch ubatch); + const llama_hparams & hparams; const bool unified; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 747e16720952..3961083800f7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -193,7 +193,9 @@ class llama_kv_cache : public llama_memory_i { // return empty slot_info on failure slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; - // Gemma4 MTP: one-token slot_info pointing at the last populated cell for seq_id (read-only graphs). + // Gemma4 MTP real-decode path: one-token slot_info pointing at the last populated + // cell for seq_id (read-only graphs). Reserve-only callers should use the + // dedicated shape-only context in llama_kv_cache_iswa instead. slot_info mtp_slot_info(llama_seq_id seq_id) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]