diff --git a/common/speculative.cpp b/common/speculative.cpp index 349d23dcd201..76a724146303 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -201,6 +201,12 @@ struct common_speculative_state { // Optional hook: drain any in-flight async work (prepare_next) and discard. virtual void cancel() {} + + // Phase C.2.1 — cold-restart hook (foundational, no behavior change here). + // Stronger than cancel(): clears all per-iteration state accumulated during a generation. + // Default is cancel(); MTP overrides to also zero h_idx + adaptive-skip counters + + // cached spec params. + virtual void reset() { cancel(); } }; struct common_speculative_state_draft : public common_speculative_state { @@ -845,6 +851,26 @@ struct common_speculative_state_mtp : public common_speculative_state { mtp_drain_pending_discard(); } + // Phase C.2.1 — cold-restart MTP state at a known boundary (e.g. image-encoding → text continuation). + // Drains any in-flight draft (like cancel) AND zeroes h_idx + adaptive-skip counters + + // cached spec params. Post-condition: next begin()/draft() pair behaves as if MTP was + // just constructed. KV memory and embeddings setting on the target are untouched — + // the host owns those. + void reset() override { + // 1. drain in-flight async draft and clear the one-shot skip flag (cancel semantics) + skip_streak_last_draft = false; + mtp_drain_pending_discard(); + + // 2. zero per-iteration h_prev pointer + adaptive-skip tracking + h_idx = -1; + prev_n_acc_drafts = 0; + zero_accept_streak = 0; + + // 3. forget cached spec params from prior draft() call so the next draft re-computes + // n_steps from scratch when the host passes fresh params. + last_spec_params = common_params_speculative{}; + } + void prepare_next(llama_token id_last) override { // Kill switch for A/B testing depth-2 vs sync. static const bool depth2_disabled = []() { @@ -1569,6 +1595,15 @@ void common_speculative_cancel(common_speculative * spec) { } } +void common_speculative_reset(common_speculative * spec) { + if (spec == nullptr) { + return; + } + for (auto & impl : spec->impls) { + impl->reset(); + } +} + void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 839237f19d4d..5cea8d3ca4b7 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -58,5 +58,22 @@ void common_speculative_prepare_next(common_speculative * spec, llama_token id_l // snapshot (e.g. slot stop / release / new request seq_rm). Safe no-op when nothing is pending. void common_speculative_cancel(common_speculative * spec); +// Phase C.2.1 — Cold-restart the speculative state machine (foundational API, no behavior change here). +// +// Stronger than cancel(): in addition to draining any in-flight draft, this clears all +// per-iteration state accumulated during a generation — h_idx is reset to its default +// (-1 = "last output"), draft-history counters used by adaptive skip (prev_n_acc_drafts, +// zero_accept_streak, skip_streak_last_draft) are zeroed, and any cached spec params from +// the previous draft() call are forgotten. After reset(), the implementation behaves as +// if begin() had just been called on a fresh prompt. +// +// Intended use: at known state-boundaries that are NOT prompt boundaries but DO invalidate +// the assistant's hidden-state assumptions — e.g. when a slot transitions from image-encoding +// (where MTP was gated off) back to text continuation (where MTP should re-engage from a clean +// slate). The next few text tokens incur the usual warmup cost but state desync is avoided. +// +// Safe no-op for non-MTP implementations. +void common_speculative_reset(common_speculative * spec); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 51a47b209bba..44106df9261b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -151,6 +151,10 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries) llama_build_and_test(test-sampling.cpp) llama_build_and_test(test-speculative-mtp.cpp) + # Phase C.2.0 — server_tokens coexistence APIs unit tests + llama_build_and_test(test-server-tokens.cpp) + target_include_directories(test-server-tokens PRIVATE ${PROJECT_SOURCE_DIR}/tools/server ${PROJECT_SOURCE_DIR}/tools/mtmd) + target_link_libraries(test-server-tokens PRIVATE server-context) llama_build_and_test(test-reasoning-budget.cpp) llama_build_and_test(test-grammar-parser.cpp) llama_build_and_test(test-grammar-integration.cpp) diff --git a/tests/test-server-tokens.cpp b/tests/test-server-tokens.cpp new file mode 100644 index 000000000000..bf17943e86b7 --- /dev/null +++ b/tests/test-server-tokens.cpp @@ -0,0 +1,121 @@ +// Phase C.2.0 — unit tests for server_tokens coexistence APIs introduced for MTP+mmproj dispatch. +// +// Scope: +// - is_pure_text_continuation(from_idx) +// - last_image_end_idx() +// - get_text_tokens_post_media() +// +// These APIs are foundational and do not change runtime behavior; they expose information that +// future per-batch dispatch will use. This file covers what is testable WITHOUT loading a model +// or running the full mtmd pipeline: +// - non-multimodal (has_mtmd=false) buffers +// - empty multimodal (has_mtmd=true, no chunks) buffers +// - empty buffer (size==0) edge cases +// +// The WITH-image cases require a real mtmd_input_chunk (which goes through the mtmd public API +// requiring an image file + clip model). Those are covered by integration tests in C.2.4 once +// the dispatch behavior is wired up. + +#include "server-common.h" + +#include +#include + +#define CHECK(cond, msg) \ + do { \ + if (!(cond)) { \ + std::fprintf(stderr, "FAIL %s:%d %s (cond: %s)\n", __FILE__, __LINE__, msg, #cond); \ + std::exit(1); \ + } \ + } while (0) + +static void test_non_mtmd_empty_buffer() { + llama_tokens t; + server_tokens st(t, /*has_mtmd*/ false); + + CHECK(st.size() == 0, "empty size"); + CHECK(st.empty(), "empty()"); + CHECK(st.last_image_end_idx() == 0, "last_image_end_idx empty"); + CHECK(st.is_pure_text_continuation(0), "pure-text @ 0 (empty)"); + CHECK(st.is_pure_text_continuation(100), "pure-text @ 100 (empty/past-end)"); + + llama_tokens out = st.get_text_tokens_post_media(); + CHECK(out.empty(), "post-media tail empty for empty buffer"); +} + +static void test_non_mtmd_text_only() { + llama_tokens t = {1, 2, 3, 4, 5}; + server_tokens st(t, /*has_mtmd*/ false); + + CHECK(st.size() == 5, "size==5"); + CHECK(!st.empty(), "!empty"); + CHECK(st.last_image_end_idx() == 0, "last_image_end_idx text-only -> 0"); + + // is_pure_text_continuation always true when !has_mtmd + CHECK(st.is_pure_text_continuation(0), "pure-text @ 0"); + CHECK(st.is_pure_text_continuation(3), "pure-text @ 3"); + CHECK(st.is_pure_text_continuation(5), "pure-text @ 5 (at end)"); + CHECK(st.is_pure_text_continuation(999), "pure-text @ 999 (past end)"); + + // For non-mtmd, get_text_tokens_post_media returns all tokens (no NULL stripped because none present). + llama_tokens out = st.get_text_tokens_post_media(); + CHECK(out.size() == 5, "post-media tail size matches buffer"); + for (size_t i = 0; i < out.size(); ++i) { + CHECK(out[i] == t[i], "post-media tail token matches"); + } + + // get_text_tokens() must still return the canonical reference for non-mtmd path. + const llama_tokens & ref = st.get_text_tokens(); + CHECK(ref.size() == 5, "get_text_tokens() size"); + CHECK(ref.data() != out.data(), "post-media tail is a distinct copy"); +} + +static void test_mtmd_empty_chunks() { + // server_tokens with has_mtmd=true but no media chunks added: same observable behavior as non-mtmd + // for the new APIs (per-API contract: empty map → return as text-only). + // We construct via the llama_tokens ctor + force has_mtmd=true via the public mutable field + // (server_tokens exposes has_mtmd as public — see server-common.h:126). + llama_tokens t = {10, 20, 30}; + server_tokens st(t, /*has_mtmd*/ false); + st.has_mtmd = true; // simulate mtmd-enabled buffer with no chunks yet + + CHECK(st.last_image_end_idx() == 0, "mtmd+empty-map: last_image_end_idx==0"); + CHECK(st.is_pure_text_continuation(0), "mtmd+empty-map: pure @ 0"); + CHECK(st.is_pure_text_continuation(3), "mtmd+empty-map: pure @ 3"); + CHECK(st.is_pure_text_continuation(999), "mtmd+empty-map: pure @ past-end"); + + llama_tokens out = st.get_text_tokens_post_media(); + CHECK(out.size() == 3, "mtmd+empty-map: tail returns all text"); + CHECK(out[0] == 10 && out[1] == 20 && out[2] == 30, "mtmd+empty-map: tail content matches"); +} + +static void test_pure_text_continuation_semantics() { + // The contract: is_pure_text_continuation(from_idx) returns true iff there is NO image chunk + // extending past from_idx. We can verify the non-mtmd / empty-mtmd branches here (the + // with-image branch is exercised by integration tests once mtmd is wired up). + llama_tokens t = {7, 8, 9}; + server_tokens st(t, false); + + CHECK(st.is_pure_text_continuation(0), "from_idxsize: true (past end)"); + CHECK(st.is_pure_text_continuation(SIZE_MAX), "from_idx=SIZE_MAX: true (past end)"); +} + +int main() { + test_non_mtmd_empty_buffer(); + std::printf("[server_tokens] non_mtmd_empty_buffer OK\n"); + + test_non_mtmd_text_only(); + std::printf("[server_tokens] non_mtmd_text_only OK\n"); + + test_mtmd_empty_chunks(); + std::printf("[server_tokens] mtmd_empty_chunks OK\n"); + + test_pure_text_continuation_semantics(); + std::printf("[server_tokens] pure_text_continuation_semantics OK\n"); + + std::printf("ALL PASS — 4 test groups, server_tokens C.2.0 foundational API\n"); + return 0; +} diff --git a/tests/test-speculative-mtp.cpp b/tests/test-speculative-mtp.cpp index 54f4c34051a3..a50a83e09769 100644 --- a/tests/test-speculative-mtp.cpp +++ b/tests/test-speculative-mtp.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include "speculative.h" #include #include @@ -10,6 +11,13 @@ // Set env vars to run non-skip paths; otherwise exits 0. int main() { + // Phase C.2.1 — contract smoke: common_speculative_reset / common_speculative_cancel + // must be safe no-ops on a null spec (matches the documented contract in speculative.h). + // Runs unconditionally — no model files required. + common_speculative_cancel(nullptr); + common_speculative_reset(nullptr); + std::cout << "[common_speculative] null-spec cancel + reset OK\n"; + const char * path_tgt = std::getenv("LLAMA_MTP_TEST_TARGET"); const char * path_head = std::getenv("LLAMA_MTP_TEST_HEAD"); const char * path_bad = std::getenv("LLAMA_MTP_TEST_BAD_ARCH"); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e3f24390233b..d53ca16a4c3f 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -388,6 +388,54 @@ const llama_tokens & server_tokens::get_text_tokens() const { return tokens; } +// Phase C.2.0 — coexistence APIs (see header for contract). + +size_t server_tokens::last_image_end_idx() const { + if (!has_mtmd || map_idx_to_media.empty()) { + return 0; + } + // map_idx_to_media is std::map sorted by start idx; rbegin() is O(1). + auto last = map_idx_to_media.rbegin(); + const size_t start_idx = last->first; + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(last->second.get()); + return start_idx + n_tokens; +} + +bool server_tokens::is_pure_text_continuation(size_t from_idx) const { + if (!has_mtmd || map_idx_to_media.empty()) { + return true; + } + return from_idx >= last_image_end_idx(); +} + +llama_tokens server_tokens::get_text_tokens_post_media() const { + if (!has_mtmd || map_idx_to_media.empty()) { + // Defensive: even in pure-text mode the buffer should not contain LLAMA_TOKEN_NULL, + // but strip just in case to keep the post-condition invariant uniform. + llama_tokens out; + out.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + out.push_back(t); + } + } + return out; + } + const size_t start = last_image_end_idx(); + llama_tokens out; + if (start >= tokens.size()) { + return out; + } + out.reserve(tokens.size() - start); + for (size_t i = start; i < tokens.size(); ++i) { + const llama_token t = tokens[i]; + if (t != LLAMA_TOKEN_NULL) { + out.push_back(t); + } + } + return out; +} + void server_tokens::set_token(llama_pos pos, llama_token id) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens[pos] = id; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 440ebc597af7..5470083129d7 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -189,6 +189,30 @@ struct server_tokens { // for compatibility with speculative decoding, ctx shift, slot save/load const llama_tokens & get_text_tokens() const; + // Phase C.2.0 — coexistence APIs for MTP + mmproj dispatch (foundational, no behavior change here). + // + // is_pure_text_continuation(from_idx) — O(log n) oracle: + // "if a caller decodes starting at position from_idx, will all tokens through end-of-buffer + // be pure text (no image chunks remaining)?" + // Used by the server to gate per-batch MTP draft dispatch when mmproj is also loaded. + // - !has_mtmd → always true + // - map empty → always true + // - from_idx >= last_image_end_idx() → true (we're past every image chunk) + // - otherwise → false (an image chunk still extends past from_idx) + bool is_pure_text_continuation(size_t from_idx) const; + + // End-exclusive idx of the last image/audio chunk in the buffer (start + n_tokens). + // Returns 0 if there are no media chunks. !has_mtmd → 0. + size_t last_image_end_idx() const; + + // Returns the suffix of text tokens after the last media chunk. + // - !has_mtmd → returns a copy of all tokens + // - map empty → returns a copy of all tokens + // - otherwise → tokens[last_image_end_idx() ..] with any LLAMA_TOKEN_NULL stripped + // Returned by value because the underlying buffer may interleave images and the suffix is + // not a contiguous slice. Callers typically bind to a const ref of the temporary. + llama_tokens get_text_tokens_post_media() const; + // for compatibility with speculative decoding void set_token(llama_pos pos, llama_token id);