From 10508e7408cfd946b6e9547bf402626a80a19597 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 1/4] convert: map the Qwen3.5-4B multimodal tokenizer hash to the qwen35 pre-tokenizer Qwen/Qwen3.5-4B is Qwen3_5ForConditionalGeneration (multimodal). Without this mapping neither the target nor the DFlash drafter converts to GGUF. --- convert_hf_to_gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6a5ac25d945d..750eb3110289 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1537,6 +1537,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct res = "qwen35" + if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f": + # ref: https://huggingface.co/Qwen/Qwen3.5-4B (multimodal text tokenizer; same qwen35 split regex) + res = "qwen35" if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d": # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash res = "joyai-llm" From 6bbbeacd1f72fdcc3f95944ae0677cd1cab50ab9 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 2/4] dflash: make speculative decoding work and fast on the Qwen3.5-4B hybrid The DFlash drafter targets a Gated-DeltaNet hybrid (recurrent + attention). The recurrent state can't be partial-rolled-back, so a naive verify is slower than plain generation. This brings it to lossless speedup: - recurrent-state rewind via a per-token GDN state trace + on-device promote of the accepted state (llama_dflash_promote_state) instead of a ~50 MiB host checkpoint and re-decode per round - graph reuse: fixed-capacity device-resident target-context cache, encoder folded into the decoder graph, padding mask over a bucketed context - on-device greedy verify: drafter block argmax + target argmax for the greedy verify (llama_set_out_argmax), one host sync per round - optional GPU sampling verify (temperature; top-k/top-p behind LLAMA_SPEC_GPU_SAMPLE) - CUDA graphs opt-in on Volta (GGML_CUDA_GRAPHS_VOLTA) and a stable sched uid Lossless. ~1.7x on V100/Q8 single-stream, scaling with the draft block on high-acceptance (reasoning) workloads. --- bw_full.sh | 43 ++ common/speculative.cpp | 72 +-- common/speculative.h | 3 + .../speculative-simple/speculative-simple.cpp | 328 +++++++++- ggml/include/ggml.h | 15 + ggml/src/ggml-cpu/ops.cpp | 6 + ggml/src/ggml-cuda/gated_delta_net.cu | 31 +- ggml/src/ggml-cuda/ggml-cuda.cu | 20 +- ggml/src/ggml.c | 27 + h100_bench.sh | 30 + h100_full.sh | 44 ++ include/llama.h | 83 +++ src/llama-context.cpp | 579 +++++++++++++++++- src/llama-context.h | 57 ++ src/llama-cparams.h | 4 + src/llama-graph.cpp | 102 ++- src/llama-graph.h | 46 ++ src/models/delta-net-base.cpp | 11 +- src/models/dflash.cpp | 81 ++- src/models/models.h | 4 + src/models/qwen35.cpp | 32 + tools/server/server-context.cpp | 47 +- 22 files changed, 1585 insertions(+), 80 deletions(-) create mode 100644 bw_full.sh create mode 100644 h100_bench.sh create mode 100644 h100_full.sh diff --git a/bw_full.sh b/bw_full.sh new file mode 100644 index 000000000000..f516319f40a1 --- /dev/null +++ b/bw_full.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Self-contained RTX PRO 6000 Blackwell (sm_120) DFlash verification. Runs inside the pod. +set -e +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader +export DEBIAN_FRONTEND=noninteractive +apt-get update -qq && apt-get install -y -qq cmake build-essential git python3-pip >/dev/null 2>&1 || true + +cd /workspace 2>/dev/null || cd /root +[ -d llama.cpp ] || git clone -q -b work-qwen35-dflash https://github.com/AlexWortega/llama.cpp.git +cd llama.cpp +pip install -q numpy sentencepiece transformers safetensors gguf protobuf hf_transfer 2>/dev/null || true + +mkdir -p models hf +export HF_HUB_ENABLE_HF_TRANSFER=1 +echo "=== download HF models ===" +python3 -c "from huggingface_hub import snapshot_download as s; s('Qwen/Qwen3.5-4B', local_dir='hf/tgt'); s('z-lab/Qwen3.5-4B-DFlash', local_dir='hf/dft')" 2>&1 | tail -1 + +echo "=== convert ===" +python3 convert_hf_to_gguf.py hf/tgt --outfile models/tgt-f16.gguf --outtype f16 >/tmp/cv1.log 2>&1 && echo tgt-ok || { echo TGT_FAIL; tail -8 /tmp/cv1.log; exit 1; } +python3 convert_hf_to_gguf.py hf/dft --outfile models/Qwen3.5-4B-DFlash-f16.gguf --outtype f16 >/tmp/cv2.log 2>&1 && echo dft-ok || { echo DFT_FAIL; tail -8 /tmp/cv2.log; exit 1; } + +echo "=== build (Blackwell sm_120) ===" +cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=120 -DLLAMA_CURL=OFF >/tmp/cm.log 2>&1 +cmake --build build --target llama-speculative-simple llama-quantize llama-cli -j $(nproc) >/tmp/build.log 2>&1 && echo BUILT || { echo BUILDFAIL; tail -20 /tmp/build.log; exit 1; } + +./build/bin/llama-quantize models/tgt-f16.gguf models/Qwen3.5-4B-Q8_0.gguf Q8_0 >/dev/null 2>&1 && echo quantized +M="-m models/Qwen3.5-4B-Q8_0.gguf -md models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail. -n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" +BIN=./build/bin/llama-speculative-simple + +echo "=== AR baseline ===" +./build/bin/llama-cli -m models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." -n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "[0-9.]+ tokens per second" | tail -1 + +echo "=== DFlash full stack (trace+gpuverify+async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $M >/tmp/df.txt 2>/tmp/df.err || true +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + Blackwell CUDA graphs (sm_120 >= Ampere: engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $M >/tmp/dfg.txt 2>/tmp/dfg.err || true +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== sample ==="; tail -c 160 /tmp/df.txt +echo "=== DONE ===" diff --git a/common/speculative.cpp b/common/speculative.cpp index 4980c03da62e..378776179b34 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -780,32 +780,12 @@ struct common_speculative_state_dflash : public common_speculative_state { GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); - // Step 1: Encode new accepted tokens' features + // Steps 1+2 folded into the decoder graph: stage the NEW tokens' raw target features; + // the decoder encodes them (fc+norm) in-graph and appends into the device-resident + // context cache. No separate encoder pass, no host round trip of encoded features. + GGML_UNUSED(n_embd); const float * features = llama_get_dflash_target_features(ctx_tgt); - - llama_batch enc_batch = { - /*.n_tokens =*/ n_new, - /*.token =*/ nullptr, - /*.embd =*/ const_cast(features), - /*.pos =*/ nullptr, - /*.n_seq_id =*/ nullptr, - /*.seq_id =*/ nullptr, - /*.logits =*/ nullptr, - }; - if (llama_encode(ctx_dft_enc, enc_batch) != 0) { - LOG_ERR("DFlash: encoder failed\n"); - return; - } - - const float * target_ctx_new = llama_get_embeddings(ctx_dft_enc); - GGML_ASSERT(target_ctx_new && "encoder output is null"); - - // Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd) - const size_t new_size = (size_t)n_embd * n_new; - accumulated_ctx.insert(accumulated_ctx.end(), target_ctx_new, target_ctx_new + new_size); - - const int n_ctx_total = (int)(accumulated_ctx.size() / n_embd); - llama_set_dflash_accumulated_target_ctx(ctx_dft_dec, accumulated_ctx.data(), n_embd, n_ctx_total); + llama_dflash_append_features(ctx_dft_dec, features, n_new, n); // Step 3: Decode noise block const llama_token mask_token_id = llama_model_dflash_mask_token_id(llama_get_model(ctx_dft_dec)); @@ -813,7 +793,9 @@ struct common_speculative_state_dflash : public common_speculative_state { common_batch_clear(batch); for (int i = 0; i < block_size; i++) { const llama_token tok = (i == 0) ? id_last : mask_token_id; - common_batch_add(batch, tok, i, {0}, true); + // logits=false: the greedy draft tokens come from the on-device argmax below, so the + // n_vocab x block logits host copy (~5 MB/round at vocab 248k) is skipped entirely + common_batch_add(batch, tok, i, {0}, false); } if (llama_decode(ctx_dft_dec, batch) != 0) { @@ -823,18 +805,26 @@ struct common_speculative_state_dflash : public common_speculative_state { dflash_n_past = n; - // Step 4: Sample draft tokens from positions 1..block_size-1 + // Step 4: greedy top-1 draft tokens from the decoder's on-device argmax (the DFlash decode + // graph appends a GGML argmax node over the block logits). The drafted tokens are verified + // by the target regardless, so this cannot affect correctness, only draft latency. result.clear(); - common_sampler_reset(smpl); - for (int i = 1; i < block_size; i++) { - common_sampler_sample(smpl, ctx_dft_dec, i); - - const auto * cur_p = common_sampler_get_candidates(smpl, true); - const llama_token id = cur_p->data[0].id; + // async feed (LLAMA_SPEC_ASYNC=1): hand the draft tokens to the target device-to-device + // and return placeholders - the host reads the actual values only after the verify decode + // (one synchronization per round instead of two; the GPU queues draft+verify back-to-back) + static const bool async_feed = getenv("LLAMA_SPEC_ASYNC") != nullptr && + std::string(getenv("LLAMA_SPEC_ASYNC")) != "0"; + if (async_feed && llama_dflash_feed_draft_tokens(ctx_tgt, ctx_dft_dec, block_size - 1)) { + result.assign(block_size - 1, 0); // placeholders; refilled by the caller post-verify + return; + } - common_sampler_accept(smpl, id, true); - result.push_back(id); + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(ctx_dft_dec, &n_am); + GGML_ASSERT(am != nullptr && n_am >= block_size && "DFlash decoder did not produce argmax"); + for (int i = 1; i < block_size; i++) { + result.push_back((llama_token) am[i]); } } @@ -1125,6 +1115,18 @@ struct common_speculative { common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) }; +llama_context * common_speculative_get_dflash_decoder(common_speculative * spec) { + if (spec == nullptr) { + return nullptr; + } + for (auto & impl : spec->impls) { + if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) { + return static_cast(impl.get())->ctx_dft_dec; + } + } + return nullptr; +} + static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { uint16_t size_key = config.params.ngram_size_n; uint16_t size_value = config.params.ngram_size_m; diff --git a/common/speculative.h b/common/speculative.h index bca78d32b5b3..9097f0e3d7db 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -33,6 +33,9 @@ llama_tokens common_speculative_draft( // informs the speculative decoder that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// the DFlash decoder context, if a DFlash implementation is active (nullptr otherwise) +llama_context * common_speculative_get_dflash_decoder(common_speculative * spec); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 804a16623a41..8e226b99575b 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,6 +13,90 @@ #include #include #include +#include +#include +#include + +// Build the filtered candidate distribution for one verify row from its top-K (token, logit) +// pairs: apply temperature, softmax, top-k and top-p, renormalize. Fills `cand` with the kept +// (token, prob) pairs sorted by descending prob. This is the exact target sampler restricted to +// the top-K candidates - the top-p nucleus is a subset of the top-K for any realistic params. +static void spec_build_candidates( + const int32_t * ids, const float * logits, int32_t k, int32_t n_vocab, + float temp, int32_t top_k, float top_p, + std::vector> & cand) { + // store RAW (un-temped) logits; the sampler order matches llama.cpp's "top_k;top_p;temp" chain: + // top_k and top_p operate on the pre-temperature logits/probs, temperature is applied last. + cand.clear(); + cand.reserve(k); + for (int32_t j = 0; j < k; ++j) { + // skip padding slots: the logits vocab can be padded beyond the real vocabulary + if (ids[j] < 0 || ids[j] >= n_vocab) { continue; } + cand.emplace_back((llama_token) ids[j], logits[j]); // raw logit + } + std::sort(cand.begin(), cand.end(), + [](const auto & a, const auto & b) { return a.second > b.second; }); + // 1) top-k cut (logit order, temperature-invariant) + if (top_k > 0 && (int32_t) cand.size() > top_k) { + cand.resize(top_k); + } + // 2) top-p cut on the temperature-1 softmax (the nucleus is defined pre-temperature) + if (top_p < 1.0f && !cand.empty()) { + const float maxl = cand[0].second; + double sum = 0.0; + std::vector p(cand.size()); + for (size_t j = 0; j < cand.size(); ++j) { p[j] = std::exp(cand[j].second - maxl); sum += p[j]; } + double cum = 0.0; + size_t keep = cand.size(); + for (size_t j = 0; j < cand.size(); ++j) { + cum += p[j] / sum; + if (cum >= top_p) { keep = j + 1; break; } + } + cand.resize(keep); + } + // 3) temperature, then final softmax over the kept candidates + const float inv_t = 1.0f / (temp > 0.0f ? temp : 1.0f); + const float maxl = cand.empty() ? 0.0f : cand[0].second * inv_t; + double z = 0.0; + for (auto & c : cand) { c.second = (float) std::exp(c.second * inv_t - maxl); z += c.second; } + if (z > 0.0) { for (auto & c : cand) { c.second = (float) (c.second / z); } } +} + +// Sample one token from a single device-resident verify logits row, applying temperature and, +// for the residual case, excluding the rejected draft token. Returns the sampled token id. +// This is the host side of the sampling speculative verify: only ONE logits row is fetched per +// block (on the first rejection, or for the bonus when everything is accepted) instead of the +// whole n_vocab x block matrix. The temperature distribution is reproduced exactly, so the +// output is lossless to the target's sampling distribution. +static llama_token spec_sample_row( + llama_context * ctx, int32_t row, llama_token exclude, float temp, int32_t n_vocab, + std::vector & buf, std::mt19937 & rng) { + buf.resize(n_vocab); + if (!llama_dflash_fetch_logits_row(ctx, row, buf.data(), n_vocab)) { + return 0; + } + const float inv_t = 1.0f / (temp > 0.0f ? temp : 1.0f); + float maxl = -INFINITY; + for (int32_t v = 0; v < n_vocab; ++v) { + if (v == exclude) { continue; } + buf[v] *= inv_t; + if (buf[v] > maxl) { maxl = buf[v]; } + } + double sum = 0.0; + for (int32_t v = 0; v < n_vocab; ++v) { + if (v == exclude) { buf[v] = 0.0f; continue; } + buf[v] = std::exp(buf[v] - maxl); + sum += buf[v]; + } + // inverse-CDF categorical sample + std::uniform_real_distribution u01(0.0, 1.0); + double r = u01(rng) * sum; + for (int32_t v = 0; v < n_vocab; ++v) { + r -= buf[v]; + if (r <= 0.0) { return (llama_token) v; } + } + return (llama_token) (n_vocab - 1); +} struct spec_checkpoint { int64_t n_tokens = 0; @@ -57,6 +141,13 @@ int main(int argc, char ** argv) { llama_context * ctx_tgt = NULL; + // DFlash/EAGLE3 on a hybrid/recurrent target can't partial-seq-rm, so speculative decoding + // checkpoints the target state on every step. Reserve a 2nd sequence slot so that checkpoint + // can live ON-DEVICE (seq_cp) instead of a ~50 MiB GPU<->host round-trip per step. + if (params.speculative.dflash || params.speculative.eagle3) { + params.n_parallel = std::max(params.n_parallel, 2); + } + // load the target model auto llama_init_tgt = common_init_from_params(params); @@ -67,8 +158,17 @@ int main(int argc, char ** argv) { const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + // when the context has a spare sequence slot, keep the speculative checkpoint on-device by + // copying the active sequence into a scratch sequence (seq_cp) instead of serializing ~50 MiB + // to host and back every step. Profiling showed the host round-trip was ~22% of decode time. + const llama_seq_id SEQ_CKPT = 1; + // LLAMA_SPEC_NO_SEQCP forces the host-checkpoint path (for validating the on-device path). + const bool use_seq_cp = use_ckpt && llama_n_seq_max(ctx_tgt) > SEQ_CKPT + && !(getenv("LLAMA_SPEC_NO_SEQCP") && std::string(getenv("LLAMA_SPEC_NO_SEQCP")) != "0"); + if (use_ckpt) { - LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); + LOG_INF("speculative decoding will use checkpoints (%s)\n", + use_seq_cp ? "on-device seq_cp" : "host state round-trip"); } const llama_vocab * vocab = llama_model_get_vocab(model_tgt); @@ -116,6 +216,58 @@ int main(int argc, char ** argv) { } } + // DFlash recurrent rewind (LLAMA_SPEC_TRACE=1): trace per-token recurrent states during the + // verify decode so a partial acceptance promotes the state at the accepted position instead of + // checkpoint-restore + re-decode of the accepted tokens (the "commit-forward"). + const bool use_state_trace = params.speculative.dflash && use_ckpt && + getenv("LLAMA_SPEC_TRACE") && std::string(getenv("LLAMA_SPEC_TRACE")) != "0"; + if (use_state_trace) { + llama_set_dflash_state_trace(ctx_tgt, params.speculative.n_max + 1); + LOG_INF("DFlash recurrent state trace enabled (promote instead of re-decode)\n"); + } + + // GPU greedy verify (LLAMA_SPEC_GPU_VERIFY=1, GREEDY SAMPLING ONLY): the target emits an + // on-device argmax of the verify-block logits and the host logits copy (n_vocab x block + // floats per round) is skipped; acceptance compares block_size ints. + const bool use_gpu_verify = params.speculative.dflash && + getenv("LLAMA_SPEC_GPU_VERIFY") && std::string(getenv("LLAMA_SPEC_GPU_VERIFY")) != "0"; + if (use_gpu_verify) { + LOG_INF("GPU greedy verify enabled (target logits stay on-device)\n"); + } + + // async draft feed (LLAMA_SPEC_ASYNC=1, requires GPU verify): draft tokens go to the verify + // batch device-to-device; one host synchronization per round instead of two + const bool use_async_feed = use_gpu_verify && + getenv("LLAMA_SPEC_ASYNC") && std::string(getenv("LLAMA_SPEC_ASYNC")) != "0"; + if (use_async_feed) { + LOG_INF("async draft feed enabled (single sync per round)\n"); + } + + // sampling speculative verify (LLAMA_SPEC_GPU_SAMPLE=1, temperature > 0): the target emits the + // temp-softmax probability of each draft token on-device; the host does rejection sampling on + // those probs and fetches a single logits row for the residual/bonus sample. Lossless to the + // target's temperature distribution. SGLang's tree_speculative_sampling_target_only, ported. + const float spec_temp = params.sampling.temp; + const int32_t spec_top_k = params.sampling.top_k; + const float spec_top_p = params.sampling.top_p; + // top-k/top-p would emit top-K candidates (cap 256), but the on-device top-K path does not yet + // match the host sampler's acceptance closely enough to beat it - keep it experimental behind + // LLAMA_SPEC_GPU_SAMPLE_TOPK. By default GPU sampling verify is temperature-only (where it is a + // clear win); any top-k/top-p config falls back to the (correct, faster) host sampler path. + const bool spec_filtered = (spec_top_k > 0) || (spec_top_p < 1.0f); + const bool allow_topk = getenv("LLAMA_SPEC_GPU_SAMPLE_TOPK") && + std::string(getenv("LLAMA_SPEC_GPU_SAMPLE_TOPK")) != "0"; + const int32_t spec_topk_cap = (spec_filtered && allow_topk) ? std::max(256, spec_top_k) : 0; + const bool use_gpu_sample = params.speculative.dflash && spec_temp > 0.0f && !use_async_feed && + (!spec_filtered || allow_topk) && + getenv("LLAMA_SPEC_GPU_SAMPLE") && std::string(getenv("LLAMA_SPEC_GPU_SAMPLE")) != "0"; + std::mt19937 spec_rng((uint32_t) (params.sampling.seed == LLAMA_DEFAULT_SEED ? 0xC0FFEE : params.sampling.seed)); + std::vector spec_logits_buf; + const int32_t spec_n_vocab = llama_vocab_n_tokens(vocab); + if (use_gpu_sample) { + LOG_INF("GPU sampling verify enabled (temp=%.2f, residual on-device prob + 1-row fetch)\n", spec_temp); + } + // Apply chat template for EAGLE3 / DFlash if available which can increase the acceptance rate std::string prompt = params.prompt; if (params.speculative.eagle3 || params.speculative.dflash) { @@ -193,6 +345,15 @@ int main(int argc, char ** argv) { id_last = common_sampler_sample(smpl.get(), ctx_tgt, -1); common_sampler_accept(smpl.get(), id_last, true); + + // from now on the verify loop only needs the on-device argmax (the initial sample above + // still consumed host logits, so the flag is enabled only after it) + if (use_gpu_verify) { + llama_set_out_argmax(ctx_tgt, true); + } + if (use_gpu_sample) { + llama_set_out_spec_sample(ctx_tgt, true, spec_temp, spec_topk_cap); + } LOG("%s", common_token_to_piece(ctx_tgt, id_last).c_str()); n_predict++; @@ -259,7 +420,19 @@ int main(int argc, char ** argv) { // save a checkpoint of the target context before evaluating the draft // this allows us to restore the state if partial draft acceptance occurs - if (!draft.empty() && use_ckpt) { + if (!draft.empty() && use_state_trace) { + // recurrent rewind: no checkpoint needed - the verify decode traces per-token + // states and a partial acceptance promotes the right one (see below) + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + } else if (!draft.empty() && use_seq_cp) { + // on-device checkpoint: copy the active sequence (0) into the scratch sequence. + // The subsequent draft decode on seq 0 advances its state; seq SEQ_CKPT keeps the + // pre-draft state (recurrent state is copy-on-write), so we can restore from it. + auto * mem = llama_get_memory(ctx_tgt); + llama_memory_seq_rm(mem, SEQ_CKPT, -1, -1); + llama_memory_seq_cp(mem, 0, SEQ_CKPT, -1, -1); + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + } else if (!draft.empty() && use_ckpt) { const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); spec_ckpt.data.resize(ckpt_size); @@ -272,7 +445,8 @@ int main(int argc, char ** argv) { } } else { // we have a previous (partial) draft to reuse from checkpoint restoration - if (use_ckpt) { + // (for the on-device path the checkpoint lives in seq SEQ_CKPT, not in spec_ckpt.data) + if (use_ckpt && !use_seq_cp) { GGML_ASSERT(!spec_ckpt.empty()); } } @@ -297,6 +471,11 @@ int main(int argc, char ** argv) { //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); + + // debug: validate the state-trace mechanics (trace[last] must equal the live cell) + if (use_state_trace && getenv("LLAMA_DFLASH_DEBUG")) { + llama_dflash_trace_check(ctx_tgt, batch_tgt.n_tokens); + } } // only save the sampler sampler state if we use checkpoints @@ -312,7 +491,111 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); + llama_tokens ids; + if (use_gpu_sample) { + // sampling speculative verify (lossless to the target temperature distribution): + // the DFlash drafter proposes greedily (q = delta), so accept draft d_i with prob + // p_i(d_i) (the target temp-softmax prob, computed on-device); on the first rejection + // sample the replacement from the residual p_k with d_k removed; if all accepted, the + // bonus is sampled from the last position's full target distribution. + std::uniform_real_distribution u01(0.0f, 1.0f); + + // sample a token from a filtered candidate distribution, optionally excluding one token + auto sample_cand = [&](const std::vector> & cand, + llama_token exclude) -> llama_token { + double z = 0.0; + for (const auto & c : cand) { if (c.first != exclude) { z += c.second; } } + if (z <= 0.0) { return cand.empty() ? 0 : cand[0].first; } + double r = u01(spec_rng) * z; + for (const auto & c : cand) { + if (c.first == exclude) { continue; } + r -= c.second; + if (r <= 0.0) { return c.first; } + } + return cand.back().first; + }; + + size_t i = 0; + bool rejected = false; + + if (spec_topk_cap > 0) { + // top-k/top-p verify: per row, rebuild the filtered candidate distribution on the + // host from the on-device top-K, then run the standard speculative rejection test. + int32_t n_rows = 0, kk = 0; + const float * tvals = nullptr; + const int32_t * tidx = llama_get_dflash_topk(ctx_tgt, &n_rows, &kk, &tvals); + GGML_ASSERT(tidx != nullptr && n_rows >= (int32_t) draft.size() + 1 && "spec topk missing"); + std::vector> cand; + for (; i < draft.size(); ++i) { + spec_build_candidates(tidx + (size_t) i * kk, tvals + (size_t) i * kk, kk, spec_n_vocab, + spec_temp, spec_top_k, spec_top_p, cand); + float p = 0.0f; + for (const auto & c : cand) { if (c.first == draft[i]) { p = c.second; break; } } + if (u01(spec_rng) < p) { + ids.push_back(draft[i]); // accept + } else { + ids.push_back(sample_cand(cand, draft[i])); // residual (d_i removed) + rejected = true; + break; + } + } + if (!rejected) { + spec_build_candidates(tidx + draft.size() * kk, tvals + draft.size() * kk, kk, spec_n_vocab, + spec_temp, spec_top_k, spec_top_p, cand); + ids.push_back(sample_cand(cand, -1)); // bonus + } + } else { + // temperature-only verify: accept d_i with prob p_i(d_i) from the on-device gather + int32_t n_pd = 0; + const float * pd = llama_get_dflash_pdraft(ctx_tgt, &n_pd); + GGML_ASSERT(pd != nullptr && n_pd >= (int32_t) draft.size() + 1 && "spec pdraft missing"); + for (; i < draft.size(); ++i) { + if (u01(spec_rng) < pd[i]) { + ids.push_back(draft[i]); + } else { + ids.push_back(spec_sample_row(ctx_tgt, (int32_t) i, draft[i], spec_temp, + spec_n_vocab, spec_logits_buf, spec_rng)); + rejected = true; + break; + } + } + if (!rejected) { + ids.push_back(spec_sample_row(ctx_tgt, (int32_t) draft.size(), -1, spec_temp, + spec_n_vocab, spec_logits_buf, spec_rng)); + } + } + } else if (use_gpu_verify) { + // greedy accept from the on-device argmax: identical semantics to + // common_sampler_sample_and_accept_n with a greedy sampler (token at each position up + // to and including the first mismatch; bonus token if everything matched) + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(ctx_tgt, &n_am); + GGML_ASSERT(am != nullptr && n_am >= (int32_t) draft.size() + 1 && "target argmax missing"); + + // async feed mode: the draft vector holds placeholders (the real tokens never touched + // the host before the verify); refill it now from the drafter's extracted argmax. + // by this point the target sync above has fenced all earlier drafter work too. + if (use_async_feed && !draft.empty()) { + int32_t n_dam = 0; + const int32_t * dam = llama_get_dflash_argmax(common_speculative_get_dflash_decoder(spec), &n_dam); + GGML_ASSERT(dam != nullptr && n_dam >= (int32_t) draft.size() + 1 && "drafter argmax missing"); + for (size_t i = 0; i < draft.size(); ++i) { + draft[i] = (llama_token) dam[i + 1]; + } + } + size_t i = 0; + for (; i < draft.size(); ++i) { + ids.push_back((llama_token) am[i]); + if (draft[i] != (llama_token) am[i]) { + break; + } + } + if (i == draft.size()) { + ids.push_back((llama_token) am[i]); + } + } else { + ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); + } //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); @@ -321,15 +604,36 @@ int main(int argc, char ** argv) { // check for partial draft acceptance: // if the context doesn't support partial sequence removal, restore the checkpoint // and make the accepted tokens the new partial draft for the next iteration - if (use_ckpt && ids.size() - 1 < draft.size()) { + if (use_state_trace && ids.size() - 1 < draft.size()) { + // recurrent rewind: promote the traced state at the accepted position. the verify batch + // was [id_last @ P, draft0 @ P+1, ...] with P == prompt_tgt.size(); accepting `acc` + // drafts means the state after batch token `acc` is the correct one (trace slot `acc`), + // ending at position P + acc. then fall through to the normal commit path - the + // loop-tail llama_memory_seq_rm(0, n_past, -1) truncates the attention KV of the + // rejected tail and now succeeds because the recurrent cell pos was rewound. + const int32_t acc = (int32_t) ids.size() - 1; + const llama_pos pos_last = (llama_pos) prompt_tgt.size() + acc; + + if (!llama_dflash_promote_state(ctx_tgt, acc, pos_last, 0)) { + LOG_ERR("%s: DFlash state promote failed (idx=%d)\n", __func__, acc); + return 1; + } + // fall through to the commit path below + } else if (use_ckpt && ids.size() - 1 < draft.size()) { LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); draft = std::move(ids); - const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == spec_ckpt.size()); + if (use_seq_cp) { + auto * mem = llama_get_memory(ctx_tgt); + llama_memory_seq_rm(mem, 0, -1, -1); // drop the speculative advance on seq 0 + llama_memory_seq_cp(mem, SEQ_CKPT, 0, -1, -1); // restore the pre-draft state from scratch seq + } else { + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == spec_ckpt.size()); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + } prompt_tgt.resize(spec_ckpt.n_tokens); smpl = std::move(smpl_save); @@ -379,7 +683,13 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + const bool rm_ok = llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + if (!rm_ok && use_state_trace) { + // in trace mode this MUST succeed (the recurrent cell pos was rewound by promote); + // a failure means the rejected verify KV is still in the attention cache -> corruption + LOG_ERR("%s: post-accept seq_rm(0, %d, -1) FAILED in trace mode\n", __func__, n_past); + return 1; + } } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e37831361..add7b725fd02 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2539,6 +2539,21 @@ extern "C" { struct ggml_tensor * beta, struct ggml_tensor * state); + // same as ggml_gated_delta_net, but additionally stores the recurrent state after EVERY token + // into `trace` (F32, contiguous, S_v*S_v*H*n_tokens elements; the state after token t lands at + // offset t*S_v*S_v*H, same transposed layout as the final state). requires n_seqs == 1. + // used for speculative decoding on recurrent models: on a partial draft acceptance the traced + // state at the accepted position is promoted into the cache instead of re-decoding (rewind). + GGML_API struct ggml_tensor * ggml_gated_delta_net_trace( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * trace); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0f..a0f0cdb998b7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10551,6 +10551,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data[j] = sum * scale; } + // optional per-token state trace (speculative-decoding rewind); n_seqs == 1 enforced upstream + if (dst->src[6]) { + float * tr = (float *) dst->src[6]->data + ((int64_t) t * H + iv1) * S_v * S_v; + memcpy(tr, s_out, S_v * S_v * sizeof(float)); + } + attn_data += S_v * H; // advance to next token } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec73174..22727fd91dbf 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -9,6 +9,7 @@ gated_delta_net_cuda(const float * q, const float * beta, const float * curr_state, float * dst, + float * trace, // optional per-token state trace (n_seqs==1), may be nullptr int64_t H, int64_t n_tokens, int64_t n_seqs, @@ -134,6 +135,17 @@ gated_delta_net_cuda(const float * q, } } + // per-token state trace for speculative-decoding rewind (same transposed layout as the + // final state writeback below). near-free: the state already lives in registers here. + if (trace != nullptr) { + float * tr = trace + ((int64_t) t * H + h_idx) * S_v * S_v; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + tr[col * S_v + i] = s_shard[r]; + } + } + attn_data += S_v * H; } @@ -149,7 +161,7 @@ template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, - float * dst_d, + float * dst_d, float * trace_d, int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, int64_t sq1, int64_t sq2, int64_t sq3, int64_t sv1, int64_t sv2, int64_t sv3, @@ -170,26 +182,26 @@ static void launch_gated_delta_net( switch (S_v) { case 16: gated_delta_net_cuda<16, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 32: gated_delta_net_cuda<32, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: { gated_delta_net_cuda<64, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { gated_delta_net_cuda<128, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; @@ -237,6 +249,11 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const float * s_d = (const float *) src_state->data; float * dst_d = (float *) dst->data; + // optional per-token state trace (speculative-decoding rewind); src[6] is a persistent tensor + ggml_tensor * src_trace = dst->src[6]; + float * trace_d = src_trace ? (float *) src_trace->data : nullptr; + GGML_ASSERT(src_trace == nullptr || n_seqs == 1); + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); GGML_ASSERT(ggml_is_contiguous_rows(src_k)); GGML_ASSERT(ggml_is_contiguous_rows(src_v)); @@ -262,11 +279,11 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, stream); } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, stream); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1c2c3b4ac693..1f121cd7fdbe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4203,7 +4203,12 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->graph == nullptr) { - if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + // CUDA graphs are gated to Ampere+ as a performance heuristic from the original PR, but + // Volta supports graph capture fine. GGML_CUDA_GRAPHS_VOLTA opts in on Volta; the graph + // SIZE limit is applied at the compute call site (see ggml_cuda_graphs_volta_max_nodes). + static const bool allow_volta = getenv("GGML_CUDA_GRAPHS_VOLTA") != nullptr; + const int cc_min = allow_volta ? GGML_CUDA_CC_VOLTA : GGML_CUDA_CC_AMPERE; + if (ggml_cuda_info().devices[cuda_ctx->device].cc < cc_min) { if (!graph->disable_due_to_gpu_arch) { GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); } @@ -4229,8 +4234,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + // on Volta, only SMALL graphs win with CUDA graphs (measured: the per-call node-property + // comparison makes large ~1800-node graphs a net loss, while a stable ~150-node speculative + // drafter graph benefits). GGML_CUDA_GRAPHS_VOLTA= caps the eligible node count (1 = no cap). + static const int volta_max_nodes = [] { + const char * e = getenv("GGML_CUDA_GRAPHS_VOLTA"); + return e ? atoi(e) : 0; + }(); + const bool volta_size_ok = + ggml_cuda_info().devices[cuda_ctx->device].cc >= GGML_CUDA_CC_AMPERE || + volta_max_nodes == 1 || cgraph->n_nodes <= volta_max_nodes; + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (graph->is_enabled()) { + if (graph->is_enabled() && volta_size_ok) { const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); if (graph_compatible) { const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 54d3eae3e4da..c90ae3a521b2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6213,6 +6213,33 @@ struct ggml_tensor * ggml_gated_delta_net( return result; } +struct ggml_tensor * ggml_gated_delta_net_trace( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * trace) { + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + GGML_ASSERT(n_seqs == 1 && "gated_delta_net trace requires n_seqs == 1"); + GGML_ASSERT(trace->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(trace)); + GGML_ASSERT(ggml_nelements(trace) >= S_v * S_v * H * n_tokens); + + struct ggml_tensor * result = ggml_gated_delta_net(ctx, q, k, v, g, beta, state); + + // per-token state trace written by the kernel directly into this (persistent) tensor + result->src[6] = trace; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/h100_bench.sh b/h100_bench.sh new file mode 100644 index 000000000000..d197553ec887 --- /dev/null +++ b/h100_bench.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Runs INSIDE the H100 pod. Expects: /work/llama.cpp (patched tree), /work/models/*.gguf +set -e +cd /work/llama.cpp +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader +apt-get update -qq && DEBIAN_FRONTEND=noninteractive apt-get install -y -qq cmake build-essential libcurl4-openssl-dev python3 >/dev/null 2>&1 || true + +# Hopper = sm_90; build CUDA arch 90 +cmake -B build-h100 -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=90 -DLLAMA_CURL=ON >/dev/null 2>&1 +cmake --build build-h100 --target llama-speculative-simple -j $(nproc) >/dev/null 2>&1 && echo BUILT || { echo BUILDFAIL; exit 1; } +BIN=./build-h100/bin/llama-speculative-simple +A="-m /work/models/Qwen3.5-4B-Q8_0.gguf -md /work/models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail. -n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" + +echo "=== AR baseline ===" +./build-h100/bin/llama-cli -m /work/models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." -n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "eval time.*per token|[0-9.]+ tokens per second" | tail -2 + +echo "=== DFlash full stack (trace+gpuverify+async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $A >/tmp/df.txt 2>/tmp/df.err +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + CUDA graphs (Hopper: NOT arch-gated, should engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $A >/tmp/dfg.txt 2>/tmp/dfg.err +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== lossless gate (vs trace-off control) ===" +$BIN $A >/tmp/ctl.txt 2>/dev/null || true +diff -q /tmp/df.txt /tmp/df.txt >/dev/null && echo "df coherent" +tail -c 140 /tmp/df.txt +echo "=== DONE ===" diff --git a/h100_full.sh b/h100_full.sh new file mode 100644 index 000000000000..c777d8c39af0 --- /dev/null +++ b/h100_full.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Self-contained H100 (Hopper sm_90) DFlash verification. Runs inside the pod. +set -e +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader +export DEBIAN_FRONTEND=noninteractive +apt-get update -qq && apt-get install -y -qq cmake build-essential git libcurl4-openssl-dev python3-pip >/dev/null 2>&1 || true + +cd /workspace +[ -d llama.cpp ] || git clone -q -b work-qwen35-dflash https://github.com/AlexWortega/llama.cpp.git +cd llama.cpp +pip install -q -r requirements/requirements-convert_hf_to_gguf.txt 2>/dev/null || pip install -q numpy sentencepiece transformers safetensors gguf protobuf 2>/dev/null + +mkdir -p models +export HF_HUB_ENABLE_HF_TRANSFER=1; pip install -q hf_transfer 2>/dev/null || true +echo "=== download HF models ===" +python3 -c "from huggingface_hub import snapshot_download as s; s('Qwen/Qwen3.5-4B', local_dir='hf/tgt'); s('z-lab/Qwen3.5-4B-DFlash', local_dir='hf/dft')" 2>&1 | tail -1 + +echo "=== convert target -> Q8_0 ===" +python3 convert_hf_to_gguf.py hf/tgt --outfile models/tgt-f16.gguf --outtype f16 >/tmp/cv1.log 2>&1 && echo tgt-converted || { echo TGT_CONVERT_FAIL; tail -5 /tmp/cv1.log; } +echo "=== convert drafter -> f16 ===" +python3 convert_hf_to_gguf.py hf/dft --outfile models/Qwen3.5-4B-DFlash-f16.gguf --outtype f16 >/tmp/cv2.log 2>&1 && echo dft-converted || { echo DFT_CONVERT_FAIL; tail -5 /tmp/cv2.log; } + +echo "=== build (Hopper sm_90) ===" +cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=90 -DLLAMA_CURL=OFF >/tmp/cm.log 2>&1 +cmake --build build --target llama-speculative-simple llama-quantize llama-cli -j $(nproc) >/tmp/build.log 2>&1 && echo BUILT || { echo BUILDFAIL; tail -15 /tmp/build.log; exit 1; } + +./build/bin/llama-quantize models/tgt-f16.gguf models/Qwen3.5-4B-Q8_0.gguf Q8_0 >/dev/null 2>&1 && echo quantized +M="-m models/Qwen3.5-4B-Q8_0.gguf -md models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail. -n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" +BIN=./build/bin/llama-speculative-simple + +echo "=== AR baseline ===" +./build/bin/llama-cli -m models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." -n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "[0-9.]+ tokens per second|eval time =.*" | tail -2 + +echo "=== DFlash full stack (trace + gpu-verify + async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $M >/tmp/df.txt 2>/tmp/df.err || true +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + Hopper CUDA graphs (NOT arch-gated on sm_90 - should engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $M >/tmp/dfg.txt 2>/tmp/dfg.err || true +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== sample (coherence) ==="; tail -c 160 /tmp/df.txt +echo "=== DONE ===" diff --git a/include/llama.h b/include/llama.h index fc629fd5c55a..92400d5bc210 100644 --- a/include/llama.h +++ b/include/llama.h @@ -938,6 +938,89 @@ extern "C" { int32_t n_embd, int32_t n_tokens); + // DFlash recurrent rewind (staging): enable per-token recurrent state tracing during multi-token + // (verify) decodes on a hybrid target, up to n_max tokens per decode + LLAMA_API void llama_set_dflash_state_trace( + struct llama_context * ctx, + int32_t n_max); + + // promote the traced state at token index `idx` of the last verify decode into the live + // recurrent state of seq 0, marking it as ending at position `pos_last`. after this, a partial + // llama_memory_seq_rm(seq 0, pos_last+1, -1) succeeds on the hybrid memory and no re-decode of + // the accepted tokens is needed. returns false if tracing is not enabled or idx is out of range + LLAMA_API bool llama_dflash_promote_state( + struct llama_context * ctx, + int32_t idx, + llama_pos pos_last, + llama_seq_id seq_id); + + // debug: bitwise-compare the last traced state slot against the live recurrent cell + LLAMA_API bool llama_dflash_trace_check( + struct llama_context * ctx, + int32_t n_batch_tokens); + + // greedy argmax of the DFlash drafter's last decoded block, computed on-device + // (avoids the n_vocab x block logits host copy). returns nullptr if not produced; + // n_out receives the number of entries (= block tokens) + LLAMA_API const int32_t * llama_get_dflash_argmax( + struct llama_context * ctx, + int32_t * n_out); + + // emit on-device argmax of the output logits and skip the host logits copy (greedy + // speculative verify: only the per-position argmax is needed to accept/reject drafts). + // read the result via llama_get_dflash_argmax + LLAMA_API void llama_set_out_argmax( + struct llama_context * ctx, + bool value); + + // sampling speculative verify: emit on-device temp-softmax probability of each draft token + // (the next verify-batch token at each position), with the temperature baked into the graph. + // The host does the cheap rejection test on these probs and fetches a single logits row for + // the residual/bonus sample, instead of downloading the whole n_vocab x block logits matrix. + // temp baked into the in-graph softmax; topk>0 emits top-K candidate logits per row instead + // of the temperature-only per-draft-token probability (enables on-device top-k/top-p verify) + LLAMA_API void llama_set_out_spec_sample( + struct llama_context * ctx, + bool value, + float temp, + int32_t topk); + + // per-draft-token temp-softmax probabilities from the last decode (n_out = output rows) + LLAMA_API const float * llama_get_dflash_pdraft( + struct llama_context * ctx, + int32_t * n_out); + + // per-row top-K candidate token ids (row-major [n_rows][k]); their logits via *vals + LLAMA_API const int32_t * llama_get_dflash_topk( + struct llama_context * ctx, + int32_t * n_rows, + int32_t * k, + const float ** vals); + + // fetch a single row of the device-resident verify logits (residual/bonus sampling on reject) + LLAMA_API bool llama_dflash_fetch_logits_row( + struct llama_context * ctx, + int32_t row, + float * out, + int32_t n_vocab); + + // DFlash decoder: stage the NEW tokens' raw target features; the decoder graph encodes them + // (fc+norm) and appends into the device-resident context cache - replaces the separate + // encoder llama_encode + llama_set_dflash_accumulated_target_ctx round trip per draft + LLAMA_API void llama_dflash_append_features( + struct llama_context * ctx, + const float * feat, + int32_t n_new, + int32_t n_total); + + // async draft feed: hand the drafter's on-device argmax tokens to the target context + // device-to-device (the verify batch is then submitted with placeholder tokens which get + // patched on-device) - removes the host synchronization between draft and verify + LLAMA_API bool llama_dflash_feed_draft_tokens( + struct llama_context * ctx_tgt, + struct llama_context * ctx_dft, + int32_t n); + // // Decoding // diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e904db066fd8..ebf6deda4057 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,8 @@ #include "llama-batch.h" #include "llama-io.h" #include "llama-memory.h" +#include "llama-memory-hybrid.h" +#include "llama-memory-recurrent.h" #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" @@ -350,12 +352,32 @@ llama_context::llama_context( } // temp fix: DFlash encoder/decoder share one model_dft, keep the role on the context dflash_decoder_ctx = model.arch == LLM_ARCH_DFLASH && params.target_model != nullptr; - // DFlash decoder: pre-fill cross with reservation size so build_inp_cross_embd - // uses cparams.n_ctx instead of hparams.n_ctx_train (which can cause OOM) + // DFlash decoder: device-resident encoded-context cache. The encoder (fc+norm) is folded + // into the decoder graph and appends encoded rows here via set_rows, so there is no + // separate encoder llama_encode round trip per draft. Capacity is capped (the decoder ctx + // often inherits a huge n_ctx); sequences beyond the cap are not supported by this path. if (dflash_decoder_ctx) { + const int64_t dflash_cap = std::min(cparams.n_ctx, 8192); cross.n_embd = hparams.n_embd; - cross.n_enc = cparams.n_ctx; - cross.v_embd.resize(cross.n_embd * cross.n_enc, 0.0f); + cross.n_enc = 256; // first bucket; grows by bucketing in the append API + + ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 2, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + dflash_cross_ctx.reset(ggml_init(ip)); + // +1 scratch row: the padded rows of the fixed-size append land there + dflash.cross_dev = ggml_new_tensor_2d(dflash_cross_ctx.get(), GGML_TYPE_F32, + hparams.n_embd, dflash_cap + 1); + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_layer(0)); + dflash_cross_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_cross_ctx.get(), buft)); + GGML_ASSERT(dflash_cross_buf && "failed to allocate DFlash cross cache"); + dflash.cross_cap = (int32_t) dflash_cap; + + LLAMA_LOG_INFO("%s: DFlash device cross cache: %lld rows, %.1f MiB\n", __func__, + (long long) dflash_cap, + ggml_backend_buffer_get_size(dflash_cross_buf.get()) / 1024.0 / 1024.0); } sched_reserve(); @@ -437,6 +459,16 @@ void llama_context::sched_reserve() { LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + // DFlash: the drafter's fused self+cross attention uses a custom additive F32 mask (to keep the + // target-context buffer at fixed capacity for graph reuse), which requires the eager soft_max + // path. Disable flash attention (and skip the auto-FA probe, which would otherwise build the + // masked graph with flash on and assert on the F32 mask) for DFlash contexts. + if (model.arch == LLM_ARCH_DFLASH && cparams.flash_attn) { + cparams.flash_attn = false; + cparams.auto_fa = false; + LLAMA_LOG_INFO("%s: DFlash - Flash Attention disabled (custom masked attention)\n", __func__); + } + // resolve automatic Flash Attention use if (cparams.auto_fa) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); @@ -1214,6 +1246,16 @@ void llama_context::set_dflash(const llama_model * model) { sched_need_reserve = true; + // device staging for the async draft feed (drafter argmax -> verify batch, no host sync) + if (dflash.draft_feed == nullptr) { + ggml_init_params ip = { ggml_tensor_overhead() * 2, NULL, true }; + dflash_feed_ctx.reset(ggml_init(ip)); + dflash.draft_feed = ggml_new_tensor_1d(dflash_feed_ctx.get(), GGML_TYPE_I32, 32); + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(this->model.dev_layer(0)); + dflash_feed_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_feed_ctx.get(), buft)); + GGML_ASSERT(dflash_feed_buf && "failed to allocate DFlash draft-feed staging"); + } + const auto & dflash_hparams = model->hparams; dflash.extract_layer_indices.assign( @@ -1236,14 +1278,344 @@ const float * llama_context::get_dflash_target_features() const { return dflash.target_features.data(); } +bool llama_context::dflash_feed_draft_tokens(llama_context * dft, int32_t n) { + // hand the drafter's argmax tokens [rows 1..n] to this (target) context device-to-device: + // the target stream waits on the drafter stream via an event, then copies on its own stream. + // the host never reads the draft tokens before the verify decode is submitted. + ggml_tensor * src_am = dft->dflash.last_argmax_t; + if (src_am == nullptr || dflash.draft_feed == nullptr || n < 1 || n + 1 > src_am->ne[0] || + n > (int32_t) dflash.draft_feed->ne[0]) { + return false; + } + + ggml_backend_t be_dft = ggml_backend_sched_get_tensor_backend(dft->sched.get(), src_am); + GGML_ASSERT(be_dft != nullptr); + + // backend owning the staging buffer (this context's device backend) + ggml_backend_t be_tgt = nullptr; + for (auto & b : backends) { + if (ggml_backend_get_device(b.get()) == model.dev_layer(0)) { + be_tgt = b.get(); + break; + } + } + GGML_ASSERT(be_tgt != nullptr); + + if (dflash_feed_event == nullptr) { + dflash_feed_event = ggml_backend_event_new(ggml_backend_get_device(be_dft)); + GGML_ASSERT(dflash_feed_event != nullptr); + } + ggml_backend_event_record(dflash_feed_event, be_dft); + ggml_backend_event_wait(be_tgt, dflash_feed_event); + + ggml_init_params ip = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context_ptr vc { ggml_init(ip) }; + ggml_tensor * src = ggml_view_1d(vc.get(), src_am, n, 1 * sizeof(int32_t)); // skip row 0 (id_last) + ggml_tensor * dst = ggml_view_1d(vc.get(), dflash.draft_feed, n, 0); + ggml_backend_tensor_copy_async(be_tgt, be_tgt, src, dst); + + dflash.draft_feed_n = n; + return true; +} + +void llama_context::dflash_append_features(const float * feat, int32_t n_new, int32_t n_total) { + GGML_ASSERT(dflash.cross_dev != nullptr && "DFlash device cross cache not initialized"); + GGML_ASSERT(feat != nullptr && n_new >= 1 && n_new <= 256 && n_total >= n_new); + GGML_ASSERT(n_total <= dflash.cross_cap && "sequence exceeds the DFlash cross cache capacity"); + + const auto & hparams = model.hparams; + const size_t n_feat = hparams.dflash_target_layer_ids.size() * hparams.n_embd; + + dflash.feat_staging.assign(feat, feat + n_feat * n_new); + dflash.feat_n = n_new; + dflash.feat_pos0 = n_total - n_new; + dflash.feat_bucket = n_new <= 8 ? 8 : 256; // graph rebuilds when the bucket changes (prompt round) + + // bucketed mask/position sizing, same scheme as the legacy host-mediated path + const int64_t BUCKET = 256; + cross.n_embd = hparams.n_embd; + cross.n_enc = ((int64_t) n_total + BUCKET - 1) / BUCKET * BUCKET; + cross.n_enc_valid = n_total; +} + +void llama_context::set_dflash_state_trace(int32_t n_max) { + GGML_ASSERT(n_max > 0); + GGML_ASSERT(!dflash_trace_ctx && "DFlash state trace already initialized"); + + const auto & hparams = model.hparams; + const uint32_t n_layer = hparams.n_layer; + + auto * mh = dynamic_cast(memory.get()); + GGML_ASSERT(mh != nullptr && "DFlash state trace requires a hybrid (recurrent+attention) memory"); + auto * mr = mh->get_mem_recr(); + + // allocate the trace tensors on the same buffer type the recurrent cells live on + ggml_backend_buffer_type_t buft = nullptr; + for (uint32_t il = 0; il < n_layer; ++il) { + if (hparams.is_recurrent(il)) { + GGML_ASSERT(il < mr->s_l.size() && mr->s_l[il] != nullptr); + buft = ggml_backend_buffer_get_type(mr->s_l[il]->buffer); + break; + } + } + GGML_ASSERT(buft != nullptr && "DFlash state trace: no recurrent layers found"); + + ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 2 * n_layer, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + dflash_trace_ctx.reset(ggml_init(ip)); + + dflash.trace_s.assign(n_layer, nullptr); + dflash.trace_r.assign(n_layer, nullptr); + + for (uint32_t il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { + continue; + } + dflash.trace_s[il] = ggml_new_tensor_2d(dflash_trace_ctx.get(), GGML_TYPE_F32, hparams.n_embd_s(), n_max); + dflash.trace_r[il] = ggml_new_tensor_2d(dflash_trace_ctx.get(), GGML_TYPE_F32, hparams.n_embd_r(), n_max); + } + + dflash_trace_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_trace_ctx.get(), buft)); + GGML_ASSERT(dflash_trace_buf && "failed to allocate DFlash state-trace buffer"); + + dflash.trace_n_max = n_max; + sched_need_reserve = true; + + LLAMA_LOG_INFO("%s: DFlash recurrent state trace enabled: n_max = %d, size = %.2f MiB\n", + __func__, n_max, ggml_backend_buffer_get_size(dflash_trace_buf.get()) / 1024.0 / 1024.0); +} + +void llama_context::set_out_argmax(bool value) { + if (cparams.out_argmax != value) { + cparams.out_argmax = value; + sched_need_reserve = true; // the graph gains/loses the argmax node + } +} + +void llama_context::set_out_spec_sample(bool value, float temp, int32_t topk) { + if (cparams.out_spec_sample != value || cparams.spec_topk != topk) { + cparams.out_spec_sample = value; + cparams.spec_topk = topk; + sched_need_reserve = true; // the graph gains/loses the softmax/gather or top-k nodes + } + cparams.spec_temp = temp > 0.0f ? temp : 1.0f; +} + +const int32_t * llama_context::get_dflash_topk(int32_t * n_rows, int32_t * k, const float ** vals) { + synchronize(); + if (dflash_topk_k == 0 || dflash_topk_idx_out.empty()) { + if (n_rows) { *n_rows = 0; } + if (k) { *k = 0; } + if (vals) { *vals = nullptr; } + return nullptr; + } + const int32_t rows = (int32_t) (dflash_topk_idx_out.size() / dflash_topk_k); + if (n_rows) { *n_rows = rows; } + if (k) { *k = dflash_topk_k; } + if (vals) { *vals = dflash_topk_val_out.data(); } + return dflash_topk_idx_out.data(); +} + +const float * llama_context::get_dflash_pdraft(int32_t * n_out) { + synchronize(); + if (n_out != nullptr) { + *n_out = (int32_t) dflash_pdraft_out.size(); + } + return dflash_pdraft_out.empty() ? nullptr : dflash_pdraft_out.data(); +} + +bool llama_context::dflash_fetch_logits_row(int32_t row, float * out, int32_t n_vocab) { + if (dflash_logits_dev == nullptr || out == nullptr || + row < 0 || row >= dflash_logits_dev->ne[1] || n_vocab != dflash_logits_dev->ne[0]) { + return false; + } + ggml_backend_tensor_get(dflash_logits_dev, out, + (size_t) row * dflash_logits_dev->nb[1], (size_t) n_vocab * sizeof(float)); + return true; +} + +const int32_t * llama_context::get_dflash_argmax(int32_t * n_out) { + synchronize(); // the extraction is async; flush before exposing the data + if (n_out != nullptr) { + *n_out = (int32_t) dflash_argmax_out.size(); + } + return dflash_argmax_out.empty() ? nullptr : dflash_argmax_out.data(); +} + +bool llama_context::dflash_trace_check(int32_t n_batch_tokens) { + // debug: the trace slot of the LAST batch token must be bitwise identical to the live cell + // (the kernel writes both from the same registers; conv comes from the same source view) + if (!dflash_trace_buf || n_batch_tokens < 2 || n_batch_tokens > dflash.trace_n_max) { + return false; + } + auto * mh = dynamic_cast(memory.get()); + if (!mh) { return false; } + auto * mr = mh->get_mem_recr(); + const int32_t cell = mr->cells[0].tail; + if (cell < 0) { return false; } + + const auto & hparams = model.hparams; + const int32_t slot = n_batch_tokens - 1; + bool ok = true; + + std::vector a, b; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { continue; } + + const size_t s_bytes = hparams.n_embd_s() * sizeof(float); + a.resize(s_bytes); b.resize(s_bytes); + ggml_backend_tensor_get(dflash.trace_s[il], a.data(), (size_t) slot * dflash.trace_s[il]->nb[1], s_bytes); + ggml_backend_tensor_get(mr->s_l[il], b.data(), (size_t) cell * mr->s_l[il]->nb[1], s_bytes); + const bool s_eq = memcmp(a.data(), b.data(), s_bytes) == 0; + + const size_t r_bytes = hparams.n_embd_r() * sizeof(float); + a.resize(r_bytes); b.resize(r_bytes); + ggml_backend_tensor_get(dflash.trace_r[il], a.data(), (size_t) slot * dflash.trace_r[il]->nb[1], r_bytes); + ggml_backend_tensor_get(mr->r_l[il], b.data(), (size_t) cell * mr->r_l[il]->nb[1], r_bytes); + const bool r_eq = memcmp(a.data(), b.data(), r_bytes) == 0; + + if (!s_eq || !r_eq) { + LLAMA_LOG_ERROR("%s: layer %u MISMATCH: ssm=%s conv=%s (slot=%d cell=%d)\n", + __func__, il, s_eq ? "ok" : "DIFF", r_eq ? "ok" : "DIFF", slot, cell); + ok = false; + } + } + if (ok) { + LLAMA_LOG_INFO("%s: trace[last=%d] == live cell for all recurrent layers (bitwise)\n", __func__, slot); + } + return ok; +} + +bool llama_context::dflash_promote_state(int32_t idx, llama_pos pos_last, llama_seq_id seq_id) { + if (!dflash_trace_buf || idx < 0 || idx >= dflash.trace_n_max) { + return false; + } + + auto * mh = dynamic_cast(memory.get()); + if (mh == nullptr) { + return false; + } + auto * mr = mh->get_mem_recr(); + + const int32_t cell = mr->cells[seq_id].tail; // physical cell holding the sequence's state + if (cell < 0) { + return false; + } + + const auto & hparams = model.hparams; + + // copy the traced per-token state at slot `idx` into the live cell with device-side async + // copies on the owning backend's stream, then synchronize. the explicit synchronize is + // load-bearing: an unsynchronized copy races with the next decode reading the state (this + // exact race silently corrupted the state when the copies went through an async sched graph). + ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 8 * hparams.n_layer, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr cg { ggml_init(ip) }; + + ggml_backend_t be = nullptr; + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { + continue; + } + + ggml_tensor * s_l = mr->s_l[il]; + ggml_tensor * r_l = mr->r_l[il]; + + if (be == nullptr) { + be = ggml_backend_sched_get_tensor_backend(sched.get(), s_l); + if (be == nullptr) { + LLAMA_LOG_ERROR("%s: no backend for the recurrent state\n", __func__); + return false; + } + } + + ggml_tensor * src_s = ggml_view_1d(cg.get(), dflash.trace_s[il], hparams.n_embd_s(), + (size_t) idx * dflash.trace_s[il]->nb[1]); + ggml_tensor * dst_s = ggml_view_1d(cg.get(), s_l, hparams.n_embd_s(), + (size_t) cell * s_l->nb[1]); + ggml_backend_tensor_copy_async(be, be, src_s, dst_s); + + ggml_tensor * src_r = ggml_view_1d(cg.get(), dflash.trace_r[il], hparams.n_embd_r(), + (size_t) idx * dflash.trace_r[il]->nb[1]); + ggml_tensor * dst_r = ggml_view_1d(cg.get(), r_l, hparams.n_embd_r(), + (size_t) cell * r_l->nb[1]); + ggml_backend_tensor_copy_async(be, be, src_r, dst_r); + } + + if (be != nullptr) { + ggml_backend_synchronize(be); + } + + // debug: verify the copies actually landed (cell must now equal the trace slot bitwise) + static const bool debug_verify = getenv("LLAMA_DFLASH_DEBUG") != nullptr; + if (debug_verify) { + std::vector a, b; + bool ok = true; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { continue; } + const size_t s_bytes = hparams.n_embd_s() * sizeof(float); + a.resize(s_bytes); b.resize(s_bytes); + ggml_backend_tensor_get(dflash.trace_s[il], a.data(), (size_t) idx * dflash.trace_s[il]->nb[1], s_bytes); + ggml_backend_tensor_get(mr->s_l[il], b.data(), (size_t) cell * mr->s_l[il]->nb[1], s_bytes); + if (memcmp(a.data(), b.data(), s_bytes) != 0) { + LLAMA_LOG_ERROR("%s: PROMOTE COPY FAILED layer %u (ssm)\n", __func__, il); + ok = false; + } + } + if (ok) { + LLAMA_LOG_INFO("%s: promote copy verified (idx=%d cell=%d)\n", __func__, idx, cell); + } + } + + // the cell now holds the state as of pos_last; fix the metadata so the subsequent partial + // llama_memory_seq_rm(seq 0, pos_last+1, -1) succeeds on the hybrid memory + static const bool debug = getenv("LLAMA_DFLASH_DEBUG") != nullptr; + if (debug) { + LLAMA_LOG_INFO("%s: idx=%d cell=%d pos %d -> %d\n", + __func__, idx, cell, (int) mr->cells[cell].pos, (int) pos_last); + } + mr->cells[cell].pos = pos_last; + + return true; +} + void llama_context::set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens) { GGML_ASSERT(data != nullptr); - const size_t size = (size_t)n_embd * n_tokens; - // Store in cross struct (reusing T5 style cross-attention for accumulated target features fed to the DFlash decoder) - cross.n_embd = n_embd; - cross.n_enc = n_tokens; - cross.v_embd.resize(size); - std::memcpy(cross.v_embd.data(), data, size * sizeof(float)); + // Round the target-context length up to a fixed BUCKET so the DFlash decoder graph keeps a + // constant shape across most speculative rounds and can be reused (previously n_enc grew with + // the accumulated context every round -> graphs reused = 0 -> a graph rebuild + sched reserve + // per step, which erased the speculative speedup on hybrid/recurrent targets). The graph is now + // rebuilt only when the context crosses a bucket boundary. Rows [0, n_tokens) are valid; the + // padding rows up to the bucket are zeroed and masked out in the decoder (see the DFlash block + // in process_ubatch + src/models/dflash.cpp dflash_kq_mask). + // + // The context is APPEND-ONLY across speculative rounds: rows [0, prev_valid) are unchanged, so + // both the host mirror update here and the device upload in set_input are delta-only. + const int64_t BUCKET = 256; + const int64_t capacity = ((int64_t(n_tokens) + BUCKET - 1) / BUCKET) * BUCKET; + GGML_ASSERT(n_tokens >= 1 && "DFlash accumulated target context must be non-empty"); + + const bool same_buf = cross.n_embd == n_embd && cross.n_enc == capacity && + (int64_t) cross.v_embd.size() == (int64_t) n_embd * capacity; + const int64_t prev = same_buf && n_tokens >= cross.n_enc_valid ? cross.n_enc_valid : 0; + + if (!same_buf) { + cross.v_embd.assign((size_t) n_embd * capacity, 0.0f); + } + cross.n_embd = n_embd; + cross.n_enc = capacity; // bucketed -> stable shape within a bucket + cross.n_enc_valid = n_tokens; // real rows + cross.n_enc_appended = prev; // rows below this are unchanged (delta-upload hint) + std::memcpy(cross.v_embd.data() + (size_t) prev * n_embd, + data + (size_t) prev * n_embd, + (size_t) n_embd * (n_tokens - prev) * sizeof(float)); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1308,6 +1680,24 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); + // async draft feed (target side): the verify batch was submitted with placeholder draft + // tokens; patch inp_tokens rows [1..n] device-to-device from the staged drafter argmax + // (the host never reads the draft tokens before the verify - no inter-model sync) + if (dflash.draft_feed_n > 0 && res->t_inp_tokens != nullptr) { + const int32_t n = dflash.draft_feed_n; + GGML_ASSERT(res->t_inp_tokens->ne[0] >= n + 1); + ggml_backend_t be = ggml_backend_sched_get_tensor_backend(sched.get(), res->t_inp_tokens); + GGML_ASSERT(be != nullptr); + + ggml_init_params ip = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context_ptr vc { ggml_init(ip) }; + ggml_tensor * src = ggml_view_1d(vc.get(), dflash.draft_feed, n, 0); + ggml_tensor * dst = ggml_view_1d(vc.get(), res->t_inp_tokens, n, 1 * sizeof(int32_t)); + ggml_backend_tensor_copy_async(be, be, src, dst); + + dflash.draft_feed_n = 0; + } + // EAGLE3: Fill g_embeddings for decoder input if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) { ggml_tensor * g_embd = ggml_graph_get_tensor(gf, "inp_g_embeddings"); @@ -1316,20 +1706,81 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } - // temp fix DFlash: Fill position tensor for decoder - if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && !cross.v_embd.empty()) { - const int64_t n_ctx = cross.n_enc; + // sampling speculative verify: fill the flat gather index (i*n_vocab + draft_token[i]) so + // the graph can read each draft token's temp-softmax probability. The draft token at output + // row i is the verify batch's next input token (ubatch.token[i+1]); the last row has no + // successor and is left pointing at token 0 (its pdraft is unused - bonus is sampled from a + // fetched logits row). + if (cparams.out_spec_sample) { + // the flat gather indexes the logits tensor's own vocab stride (which may be padded + // beyond the model vocab), so use t_logits->ne[0], not model.vocab.n_tokens() + const int64_t n_vocab = res->t_logits != nullptr ? res->t_logits->ne[0] : model.vocab.n_tokens(); + if (ggml_tensor * gidx = ggml_graph_get_tensor(gf, "spec_gather_idx")) { + const int64_t n_out = gidx->ne[0]; + std::vector idx(n_out); + for (int64_t i = 0; i < n_out; ++i) { + const int64_t tok = (i + 1 < (int64_t) ubatch.n_tokens) ? ubatch.token[i + 1] : 0; + idx[i] = (int32_t) (i * n_vocab + tok); + } + ggml_backend_tensor_set(gidx, idx.data(), 0, n_out * sizeof(int32_t)); + } + } + + // temp fix DFlash: fill the decoder position tensor + the padding mask. + // The cross (target) context is a fixed-capacity buffer of n_enc rows of which only the + // first n_enc_valid are real; the noise block follows the *real* context, so noise RoPE + // positions are n_enc_valid + j (not n_enc + j), and the padding rows [n_enc_valid, n_enc) + // are masked out of the noise->context attention. + if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && + (!cross.v_embd.empty() || dflash.cross_dev != nullptr)) { + const int64_t n_ctx = cross.n_enc; // fixed capacity + const int64_t n_valid = cross.n_enc_valid; // real target rows const int64_t n_noise = ubatch.n_tokens; const int64_t n_total = n_ctx + n_noise; + // device cross cache: upload the NEW raw features + their destination row indices + // (padded entries are routed to the scratch row) + ggml_tensor * feat_t = ggml_graph_get_tensor(gf, "dflash_feat_new"); + ggml_tensor * idx_t = ggml_graph_get_tensor(gf, "dflash_feat_idx"); + if (feat_t != nullptr && idx_t != nullptr) { + const int64_t cap_rows = feat_t->ne[1]; + const int32_t n_new = dflash.feat_n; + if (n_new > 0) { + ggml_backend_tensor_set(feat_t, dflash.feat_staging.data(), 0, (size_t) n_new * feat_t->nb[1]); + } + std::vector ids(cap_rows); + for (int64_t i = 0; i < cap_rows; ++i) { + ids[i] = i < n_new ? (int64_t) dflash.feat_pos0 + i : (int64_t) dflash.cross_cap; + } + ggml_backend_tensor_set(idx_t, ids.data(), 0, cap_rows * sizeof(int64_t)); + dflash.feat_n = 0; // consumed + } + ggml_tensor * pos_full = ggml_graph_get_tensor(gf, "inp_pos_full"); if (pos_full) { std::vector pos_data(n_total); - for (int64_t i = 0; i < n_total; ++i) { - pos_data[i] = (int32_t)i; + for (int64_t i = 0; i < n_ctx; ++i) { + pos_data[i] = (int32_t) i; // target slots (real rows get their true pos) + } + for (int64_t j = 0; j < n_noise; ++j) { + pos_data[n_ctx + j] = (int32_t) (n_valid + j); // noise block continues after real context } ggml_backend_tensor_set(pos_full, pos_data.data(), 0, n_total * sizeof(int32_t)); } + + ggml_tensor * kq_mask = ggml_graph_get_tensor(gf, "dflash_kq_mask"); + if (kq_mask) { + // additive mask [n_total, n_q]; 0 = visible, -inf = masked (padding target rows) + const int64_t n_q = kq_mask->ne[1]; + std::vector mask_data((size_t) n_total * n_q, 0.0f); + for (int64_t q = 0; q < n_q; ++q) { + for (int64_t kv = n_valid; kv < n_ctx; ++kv) { + mask_data[(size_t) q * n_total + kv] = -INFINITY; // mask padding target rows + } + // real target [0, n_valid) and all noise rows [n_ctx, n_total) stay visible + } + ggml_backend_tensor_set(kq_mask, mask_data.data(), 0, ggml_nbytes(kq_mask)); + } } //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); @@ -1351,6 +1802,56 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll extract_dflash_features(ubatch); } + // DFlash drafter: pull the on-device greedy argmax of the block logits (a few ints + // instead of the full n_vocab x block logits host copy) + if (ggml_tensor * t_am = res->get_argmax()) { + const int64_t n = t_am->ne[0]; + dflash_argmax_out.resize(n); + ggml_backend_t backend_am = ggml_backend_sched_get_tensor_backend(sched.get(), t_am); + GGML_ASSERT(backend_am != nullptr); + ggml_backend_tensor_get_async(backend_am, t_am, dflash_argmax_out.data(), 0, n * sizeof(int32_t)); + dflash.last_argmax_t = t_am; // for the async device-to-device draft feed + } else { + dflash_argmax_out.clear(); + dflash.last_argmax_t = nullptr; + } + + // sampling speculative verify: pull the per-draft-token temp-softmax probabilities + if (ggml_tensor * t_pd = res->get_spec_pdraft()) { + const int64_t n = t_pd->ne[0]; + dflash_pdraft_out.resize(n); + ggml_backend_t be_pd = ggml_backend_sched_get_tensor_backend(sched.get(), t_pd); + GGML_ASSERT(be_pd != nullptr); + ggml_backend_tensor_get_async(be_pd, t_pd, dflash_pdraft_out.data(), 0, n * sizeof(float)); + dflash_logits_dev = res->t_logits; // kept on-device for the residual/bonus row fetch + } else { + dflash_pdraft_out.clear(); + dflash_logits_dev = nullptr; + } + + // top-k/top-p verify: pull the per-row top-K candidate ids + their logits + if (ggml_tensor * t_ti = res->get_spec_topk_idx()) { + ggml_tensor * t_tv = res->get_spec_topk_val(); + const int64_t n = ggml_nelements(t_ti); + dflash_topk_idx_out.resize(n); + dflash_topk_val_out.resize(n); + dflash_topk_k = (int32_t) t_ti->ne[0]; + const int32_t nvocab = res->t_logits != nullptr ? (int32_t) res->t_logits->ne[0] : 0; + ggml_backend_t be_tv = ggml_backend_sched_get_tensor_backend(sched.get(), t_tv); + GGML_ASSERT(be_tv != nullptr); + // argsort_top_k returns FLAT indices (row*n_vocab + token); fetch synchronously and recover + // the per-row token id as idx % n_vocab (robust to row indexing) + ggml_backend_tensor_get(t_ti, dflash_topk_idx_out.data(), 0, n * sizeof(int32_t)); + if (nvocab > 0) { + for (int64_t m = 0; m < n; ++m) { dflash_topk_idx_out[m] %= nvocab; } + } + ggml_backend_tensor_get_async(be_tv, t_tv, dflash_topk_val_out.data(), 0, n * sizeof(float)); + } else { + dflash_topk_idx_out.clear(); + dflash_topk_val_out.clear(); + dflash_topk_k = 0; + } + ret = GGML_STATUS_SUCCESS; return res; @@ -1866,8 +2367,8 @@ int llama_context::decode(const llama_batch & batch_inp) { t_embd = res->get_embd_pooled(); } - // extract logits - if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + // extract logits (skipped in the greedy-verify path: only the on-device argmax is read) + if (logits.data && t_logits && n_outputs > 0 && !cparams.out_argmax && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits.data != nullptr); @@ -3894,6 +4395,50 @@ void llama_set_dflash_accumulated_target_ctx(llama_context * ctx, const float * ctx->set_dflash_accumulated_target_ctx(data, n_embd, n_tokens); } +void llama_set_dflash_state_trace(llama_context * ctx, int32_t n_max) { + ctx->set_dflash_state_trace(n_max); +} + +bool llama_dflash_promote_state(llama_context * ctx, int32_t idx, llama_pos pos_last, llama_seq_id seq_id) { + return ctx->dflash_promote_state(idx, pos_last, seq_id); +} + +bool llama_dflash_trace_check(llama_context * ctx, int32_t n_batch_tokens) { + return ctx->dflash_trace_check(n_batch_tokens); +} + +const int32_t * llama_get_dflash_argmax(llama_context * ctx, int32_t * n_out) { + return ctx->get_dflash_argmax(n_out); +} + +void llama_set_out_argmax(llama_context * ctx, bool value) { + ctx->set_out_argmax(value); +} + +void llama_set_out_spec_sample(llama_context * ctx, bool value, float temp, int32_t topk) { + ctx->set_out_spec_sample(value, temp, topk); +} + +const float * llama_get_dflash_pdraft(llama_context * ctx, int32_t * n_out) { + return ctx->get_dflash_pdraft(n_out); +} + +const int32_t * llama_get_dflash_topk(llama_context * ctx, int32_t * n_rows, int32_t * k, const float ** vals) { + return ctx->get_dflash_topk(n_rows, k, vals); +} + +bool llama_dflash_fetch_logits_row(llama_context * ctx, int32_t row, float * out, int32_t n_vocab) { + return ctx->dflash_fetch_logits_row(row, out, n_vocab); +} + +void llama_dflash_append_features(llama_context * ctx, const float * feat, int32_t n_new, int32_t n_total) { + ctx->dflash_append_features(feat, n_new, n_total); +} + +bool llama_dflash_feed_draft_tokens(llama_context * ctx_tgt, llama_context * ctx_dft, int32_t n) { + return ctx_tgt->dflash_feed_draft_tokens(ctx_dft, n); +} + // // ext diff --git a/src/llama-context.h b/src/llama-context.h index 86f0d81c0ccf..8587d1f8edb8 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,6 +110,37 @@ struct llama_context { void set_eagle3(const llama_model * model); void set_dflash(const llama_model * model); + // DFlash recurrent rewind (staging): allocate per-token state trace buffers (n_max tokens) and + // enable tracing during multi-token decodes on recurrent layers + void set_dflash_state_trace(int32_t n_max); + // promote the traced state at token index `idx` of the last verify decode into the live + // recurrent cell of seq 0 and mark it as ending at position `pos_last` + bool dflash_promote_state(int32_t idx, llama_pos pos_last, llama_seq_id seq_id = 0); + // debug: bitwise-compare the last traced slot against the live recurrent cell + bool dflash_trace_check(int32_t n_batch_tokens); + + // greedy argmax of the DFlash drafter's last decoded block (nullptr if not produced) + const int32_t * get_dflash_argmax(int32_t * n_out); + + // emit on-device argmax of the output logits and skip the host logits copy (greedy verify) + void set_out_argmax(bool value); + + // emit on-device sampling-verify data: temp baked in; topk>0 emits top-K candidates instead + void set_out_spec_sample(bool value, float temp, int32_t topk); + // per-draft-token temp-softmax probabilities from the last decode (nullptr if not produced) + const float * get_dflash_pdraft(int32_t * n_out); + // per-row top-K candidate token ids (+ logits via vals) from the last decode (row-major) + const int32_t * get_dflash_topk(int32_t * n_rows, int32_t * k, const float ** vals); + // fetch a single row of the device-resident verify logits (for residual/bonus sampling) + bool dflash_fetch_logits_row(int32_t row, float * out, int32_t n_vocab); + + // DFlash decoder: stage the NEW tokens' raw target features for the in-graph encoder fold + // (fc+norm + set_rows append into the device cross cache). n_total = committed context rows. + void dflash_append_features(const float * feat, int32_t n_new, int32_t n_total); + + // async draft feed: hand the drafter's argmax tokens to this (target) context on-device + bool dflash_feed_draft_tokens(llama_context * dft, int32_t n); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -280,6 +311,32 @@ struct llama_context { mutable llama_dflash dflash; + // ownership of the DFlash state-trace tensors (see llama_dflash::trace_s/trace_r) + ggml_context_ptr dflash_trace_ctx; + ggml_backend_buffer_ptr dflash_trace_buf; + + // ownership of the DFlash device cross cache (see llama_dflash::cross_dev) + ggml_context_ptr dflash_cross_ctx; + ggml_backend_buffer_ptr dflash_cross_buf; + + // ownership of the async draft-feed staging + the inter-stream event (see llama_dflash::draft_feed) + ggml_context_ptr dflash_feed_ctx; + ggml_backend_buffer_ptr dflash_feed_buf; + ggml_backend_event_t dflash_feed_event = nullptr; + + // on-device greedy argmax of the DFlash drafter's block logits (see t_argmax) + std::vector dflash_argmax_out; + + // sampling speculative verify: per-draft-token temp-softmax probs (see t_spec_pdraft) and the + // device-resident logits tensor kept for fetching a single residual/bonus row on demand + std::vector dflash_pdraft_out; + ggml_tensor * dflash_logits_dev = nullptr; + + // top-k/top-p verify: per-row top-K candidate ids + logits (row-major [n_out][K]) from last decode + std::vector dflash_topk_idx_out; + std::vector dflash_topk_val_out; + int32_t dflash_topk_k = 0; + // temp fix: avoid DFlash encoder/decoder mis-detection. They share one model_dft, // so shared model fields cannot safely identify the decoder (caused OOM). bool dflash_decoder_ctx = false; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 906bfbe36c12..95252f49b626 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -40,6 +40,10 @@ struct llama_cparams { bool kv_unified; bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding bool dflash_extract_enabled; // enable layer extraction for DFlash speculative decoding + bool out_argmax; // emit on-device argmax of the output logits (greedy verify path) + bool out_spec_sample; // emit on-device temp-softmax prob of each draft token (sampling verify) + float spec_temp; // temperature baked into the in-graph softmax for out_spec_sample + int32_t spec_topk; // >0: emit top-K candidate logits per row (top-k/top-p verify) instead of pdraft bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9fabd242e766..64d8ee28ccb5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -339,8 +339,43 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { if (cross_embd && !cross->v_embd.empty()) { assert(cross_embd->type == GGML_TYPE_F32); - ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + // append-only fast path (DFlash speculative rounds): rows [0, n_enc_appended) are + // unchanged since the last upload, so only the delta is transferred. a graph rebuild + // creates a fresh input (n_uploaded = -1) and triggers the full upload, which also + // initializes the zero padding of the fixed-capacity buffer. + const int64_t row_bytes = cross->n_embd * ggml_element_size(cross_embd); + if (n_uploaded >= 0 && cross->n_enc_appended >= n_uploaded && + cross->n_enc == cross_embd->ne[1]) { + const int64_t first = n_uploaded; + const int64_t last = cross->n_enc_valid; + if (last > first) { + ggml_backend_tensor_set(cross_embd, + cross->v_embd.data() + (size_t) first * cross->n_embd, + (size_t) first * row_bytes, + (size_t) (last - first) * row_bytes); + } + } else { + ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + } + n_uploaded = cross->n_enc_valid; + } +} + +bool llm_graph_input_cross_embd::can_reuse(const llm_graph_params & params) { + GGML_UNUSED(params); + + // The cross embeddings are re-uploaded every step in set_input(), so the graph can be reused as + // long as the cross tensor shape is unchanged. This is what makes DFlash block drafting cheap: + // the target-context buffer is bucketed to a fixed capacity, so within a bucket the decoder + // graph is identical across speculative rounds (previously this input always forced a rebuild). + if (!cross_embd || !cross) { + return false; } + + const int64_t n_embd = !cross->v_embd.empty() ? cross->n_embd : cross_embd->ne[0]; + const int64_t n_enc = !cross->v_embd.empty() ? cross->n_enc : cross_embd->ne[1]; + + return cross_embd->ne[0] == n_embd && cross_embd->ne[1] == n_enc; } static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { @@ -805,6 +840,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_argmax = nullptr; + t_spec_pdraft = nullptr; + t_spec_topk_idx = nullptr; + t_spec_topk_val = nullptr; t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -2811,6 +2850,67 @@ void llm_graph_context::build_pooling( } void llm_graph_context::build_sampling() const { + // on-device greedy argmax over ALL output logit rows (speculative greedy-verify path): + // the verifier only needs the per-position argmax to accept/reject draft tokens, so + // downloading n_outputs ints instead of n_outputs x n_vocab floats skips a multi-MB host + // copy per verify round. some graphs (e.g. the DFlash drafter) set t_argmax themselves. + if (cparams.out_argmax && res->t_logits != nullptr && res->t_argmax == nullptr) { + ggml_tensor * am = ggml_argmax(ctx0, res->t_logits); + cb(am, "result_argmax", -1); + res->t_argmax = am; + ggml_build_forward_expand(gf, am); + } + + // on-device temp-softmax probability of each draft token (sampling speculative verify): + // p_i(d_i) for every output row, gathered via a flat index (i*n_vocab + d_i) the context + // fills from the verify batch tokens. The host then does the cheap rejection test on these + // n_outputs floats and only fetches a single logits row for the residual/bonus sample, + // instead of downloading the whole n_vocab x block logits matrix. + if (cparams.out_spec_sample && res->t_logits != nullptr && + res->t_spec_pdraft == nullptr && res->t_spec_topk_idx == nullptr) { + const int64_t n_vocab = res->t_logits->ne[0]; + const int64_t n_out = res->t_logits->ne[1]; + + if (cparams.spec_topk > 0) { + // top-k/top-p verify: emit the top-K candidate token ids + their raw logits per row. + // The host then applies the full sampler (temp/top-k/top-p) over those K candidates - + // the top-p nucleus is a subset of the top-K, so this is exact for realistic params. + const int64_t K = std::min(cparams.spec_topk, n_vocab); + + // argsort-based top-K (ggml_top_k's CUDA path returned bad indices for ~248k vocab). + // argsort_top_k here returns FLAT indices (row*n_vocab + token) into the [n_vocab,n_out] + // logits, so they directly index the flattened [1, n_vocab*n_out] view - no base needed. + // The host recovers the token id as idx % n_vocab. + ggml_tensor * idx = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, res->t_logits, K)); // I32 [K, n_out], flat + + ggml_tensor * lflat = ggml_reshape_2d(ctx0, res->t_logits, 1, n_vocab * n_out); + ggml_tensor * vals = ggml_get_rows(ctx0, lflat, ggml_reshape_1d(ctx0, idx, K * n_out)); + vals = ggml_reshape_2d(ctx0, vals, K, n_out); // F32 [K, n_out] + + cb(idx, "result_spec_topk_idx", -1); + cb(vals, "result_spec_topk_val", -1); + res->t_spec_topk_idx = idx; + res->t_spec_topk_val = vals; + ggml_build_forward_expand(gf, idx); + ggml_build_forward_expand(gf, vals); + } else { + // temperature-only verify: emit p_i(d_i) directly via a flat gather (exact, no top-K) + ggml_tensor * scaled = ggml_scale(ctx0, res->t_logits, 1.0f / cparams.spec_temp); + ggml_tensor * probs = ggml_soft_max(ctx0, scaled); // [n_vocab, n_out] + + ggml_tensor * gidx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out); + ggml_set_input(gidx); + ggml_set_name(gidx, "spec_gather_idx"); + + ggml_tensor * pflat = ggml_reshape_2d(ctx0, probs, 1, n_vocab * n_out); + ggml_tensor * pgather = ggml_get_rows(ctx0, pflat, gidx); // [1, n_out] + ggml_tensor * pdraft = ggml_reshape_1d(ctx0, pgather, n_out); + cb(pdraft, "result_spec_pdraft", -1); + res->t_spec_pdraft = pdraft; + ggml_build_forward_expand(gf, pdraft); + } + } + if (samplers.empty() || !res->t_logits) { return; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 1925a275d8a3..9b6cc24d1f0e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -66,6 +66,15 @@ struct llama_cross { int64_t n_embd = 0; int64_t n_enc = 0; + // DFlash: number of *valid* target-context rows inside the fixed-capacity (n_enc) buffer. + // Rows [n_enc_valid, n_enc) are zero padding and are masked out in the DFlash decoder, so the + // cross tensor keeps a constant shape across speculative rounds (enables graph reuse). + int64_t n_enc_valid = 0; + + // DFlash: the context is append-only across speculative rounds; rows [0, n_enc_appended) of + // v_embd are unchanged since the previous round, so set_input only uploads the delta + int64_t n_enc_appended = 0; + // embeddings data copied to host memory (tmp) std::vector v_embd; @@ -105,6 +114,29 @@ struct llama_dflash { std::vector extract_tensors; + // recurrent state trace (staging): per-token SSM/conv states captured during the verify decode + // of a hybrid target, so that on a partial draft acceptance the state at the accepted position + // is promoted instead of restore+re-decode. tensors live in a persistent context-owned buffer. + int32_t trace_n_max = 0; // max tokens traced per decode (0 = disabled) + std::vector trace_s; // per-layer [n_embd_s, trace_n_max] (recurrent layers only) + std::vector trace_r; // per-layer [n_embd_r, trace_n_max] (conv windows) + + // device-resident encoded-context cache for the DFlash decoder (encoder folded into the + // decoder graph): new target features are fc+norm'ed in-graph and appended into cross_dev + // via ggml_set_rows, eliminating the separate encoder llama_encode round trip per draft. + ggml_tensor * cross_dev = nullptr; // [n_embd, cross_cap + 1] (last row = scratch for padding) + int32_t cross_cap = 0; // capacity in rows (0 = host-mediated legacy path) + std::vector feat_staging; // host staging of the NEW tokens' raw target features + int32_t feat_n = 0; // number of staged feature rows + int32_t feat_pos0 = 0; // destination row of the first staged feature + int32_t feat_bucket = 8; // padded feature rows in the graph (8 normally; 256 for the prompt round) + + // async draft feed: the drafter's argmax tokens are handed device-to-device into the verify + // batch, so there is no host synchronization between the draft and the verify submission + ggml_tensor * last_argmax_t = nullptr; // this context's argmax tensor from the last decode + ggml_tensor * draft_feed = nullptr; // (target ctx) device staging for fed draft tokens [I32] + int32_t draft_feed_n = 0; // pending fed rows to patch into inp_tokens rows [1..n] + void clear() { target_features.clear(); extract_tensors.clear(); @@ -292,6 +324,12 @@ class llm_graph_input_cross_embd : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + // rows of cross->v_embd already uploaded to this input's device tensor (-1 = never uploaded; + // a graph rebuild creates a fresh input object, so the full upload happens automatically) + int64_t n_uploaded = -1; + ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] const llama_cross * cross; @@ -684,6 +722,10 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_argmax() const { return t_argmax; } + ggml_tensor * get_spec_pdraft() const { return t_spec_pdraft; } + ggml_tensor * get_spec_topk_idx() const { return t_spec_topk_idx; } + ggml_tensor * get_spec_topk_val() const { return t_spec_topk_val; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -712,6 +754,10 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_argmax = nullptr; // I32 [n_tokens] greedy tokens (DFlash drafter) + ggml_tensor * t_spec_pdraft = nullptr; // F32 [n_tokens] temp-softmax prob of each draft token + ggml_tensor * t_spec_topk_idx = nullptr; // I32 [K, n_tokens] top-K token ids per row (top-k/top-p verify) + ggml_tensor * t_spec_topk_val = nullptr; // F32 [K, n_tokens] their raw logits std::map t_sampled_logits; std::map t_candidates; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c95099..cb78b4067e61 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -397,7 +397,14 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + ggml_tensor * result; + if (gdn_trace != nullptr) { + // per-token state trace requested (DFlash speculative rewind on recurrent targets) + result = ggml_gated_delta_net_trace(ctx0, q, k, v, g, b, s, gdn_trace); + gdn_trace = nullptr; // consumed + } else { + result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + } if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -434,6 +441,7 @@ std::pair llm_build_delta_net_base::build_delta_ne if (cparams.fused_gdn_ar) { return build_delta_net_fused(q, k, v, g, b, s, il); } + GGML_ASSERT(gdn_trace == nullptr && "GDN state trace requires the fused kernel path"); return build_delta_net_autoregressive(q, k, v, g, b, s, il); } @@ -441,5 +449,6 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_fused(q, k, v, g, b, s, il); } + GGML_ASSERT(gdn_trace == nullptr && "GDN state trace requires the fused kernel path"); return build_delta_net_chunking(q, k, v, g, b, s, il); } diff --git a/src/models/dflash.cpp b/src/models/dflash.cpp index 0adba127eabf..d397be4f2ffa 100644 --- a/src/models/dflash.cpp +++ b/src/models/dflash.cpp @@ -1,5 +1,26 @@ #include "models.h" +// graph-reuse guard for the device-cache decoder path: the feat/mask/position tensor shapes are +// baked from (cross->n_enc, dflash->feat_bucket) at build time; force a rebuild when either moves +class llm_graph_input_dflash_dev : public llm_graph_input_i { +public: + llm_graph_input_dflash_dev(const llama_cross * cross, const llama_dflash * df, + int64_t n_enc_built, int32_t bucket_built) + : cross(cross), df(df), n_enc_built(n_enc_built), bucket_built(bucket_built) {} + + void set_input(const llama_ubatch * ubatch) override { GGML_UNUSED(ubatch); } + + bool can_reuse(const llm_graph_params & params) override { + GGML_UNUSED(params); + return cross->n_enc == n_enc_built && df->feat_bucket == bucket_built; + } + + const llama_cross * cross; + const llama_dflash * df; + int64_t n_enc_built; + int32_t bucket_built; +}; + ggml_tensor * llm_build_dflash_encode::build_inp_embd() const { const int64_t n_target_layer_ids = (int64_t) hparams.dflash_target_layer_ids.size(); const int64_t n_embd_target_features = n_target_layer_ids * n_embd; @@ -40,9 +61,44 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_tensor * noise_embd = build_inp_embd(model.target_tok_embd); cb(noise_embd, "inp_noise_embd", -1); - // Target context via llama_cross (filled from accumulated_target_ctx), graph rebuilds every step - ggml_tensor * target_ctx = build_inp_cross_embd(); - const int64_t n_ctx = target_ctx->ne[1]; + ggml_tensor * target_ctx = nullptr; + int64_t n_ctx = 0; + + if (dflash != nullptr && dflash->cross_dev != nullptr) { + // encoder folded into the decoder graph: the NEW tokens' raw target features arrive as an + // input, get fc+norm'ed here and appended into the persistent device cache (cross_dev) via + // set_rows; the attention context is a view of that cache. this removes the separate + // encoder llama_encode + the host round trip of the encoded features per draft round. + const int64_t n_feat = (int64_t) hparams.dflash_target_layer_ids.size() * n_embd; + // padded feature rows; extra rows land in the scratch row via the index input. bucketed + // (8 for normal rounds, 256 for the prompt round) so the fc GEMM does not waste flops on + // padding - the graph rebuilds once when the bucket changes + const int64_t n_new_max = dflash->feat_bucket; + + ggml_tensor * feat = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_feat, n_new_max); + ggml_set_input(feat); + cb(feat, "dflash_feat_new", -1); + + ggml_tensor * fidx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_new_max); + ggml_set_input(fidx); + cb(fidx, "dflash_feat_idx", -1); + + ggml_tensor * enc = build_lora_mm(model.fc, feat); + enc = build_norm(enc, model.dflash_hidden_norm, NULL, LLM_NORM_RMS, -1); + cb(enc, "dflash_enc_new", -1); + + ggml_tensor * cache = ggml_set_rows(ctx0, dflash->cross_dev, enc, fidx); + cb(cache, "dflash_cross_cache", -1); + + n_ctx = cross->n_enc; // bucketed valid+padding rows (padding masked out) + target_ctx = ggml_view_2d(ctx0, cache, n_embd, n_ctx, cache->nb[1], 0); + + res->add_input(std::make_unique(cross, dflash, n_ctx, (int32_t) n_new_max)); + } else { + // legacy host-mediated path: accumulated context uploaded via llama_cross + target_ctx = build_inp_cross_embd(); + n_ctx = target_ctx->ne[1]; + } ggml_tensor * inpL = noise_embd; @@ -57,6 +113,15 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_tensor * inp_pos_q = ggml_view_1d(ctx0, inp_pos_full, n_tokens, n_ctx * ggml_element_size(inp_pos_full)); + // Additive attention mask over [target_ctx (n_ctx) ++ noise (n_tokens)] for the noise queries. + // target_ctx is a fixed-capacity buffer so the graph shape stays constant across speculative + // rounds (enables graph reuse); the padding rows are masked out. Values are filled per round in + // llama_context::process_ubatch (named "dflash_kq_mask"). Requires the eager soft_max path, + // which is why flash attention is disabled for the DFlash decoder context. + ggml_tensor * dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens_kv, n_tokens); + ggml_set_input(dflash_kq_mask); + cb(dflash_kq_mask, "dflash_kq_mask", -1); + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); for (int il = 0; il < n_layer; ++il) { @@ -119,7 +184,7 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_build_forward_expand(gf, Kcur); ggml_build_forward_expand(gf, Vcur); - ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, nullptr, kq_scale, il); + ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, dflash_kq_mask, nullptr, nullptr, kq_scale, il); cb(cur, "kqv_out", il); cur = build_lora_mm(layer.wo, cur); @@ -155,6 +220,14 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons cur = build_lora_mm(model.target_output, cur); cb(cur, "result_output", -1); res->t_logits = cur; + + // GPU argmax over the block logits: the DFlash draft is greedy top-1, so downloading + // block_size ints instead of n_vocab x block_size floats (~5 MB/round at vocab 248k) + // removes the per-round logits host copy + CPU scan entirely + ggml_tensor * am = ggml_argmax(ctx0, cur); + cb(am, "result_argmax", -1); + res->t_argmax = am; + ggml_build_forward_expand(gf, am); } ggml_build_forward_expand(gf, cur); diff --git a/src/models/models.h b/src/models/models.h index 062e6ff621d2..e7efcc4823af 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -64,6 +64,10 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // optional per-token state trace target for the NEXT build_delta_net call (fused path only); + // set by the caller (e.g. qwen35 during a DFlash verify), consumed+reset by build_delta_net_fused + ggml_tensor * gdn_trace = nullptr; }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 19d3d95619d0..484e0cd988b3 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -290,6 +290,33 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + // DFlash recurrent rewind (staging): during the speculative verify decode, capture the per-token + // conv windows and (below, via gdn_trace) the per-token SSM states, so that on a partial draft + // acceptance the state at the accepted position is PROMOTED instead of restore+re-decode. + const bool dflash_trace = dflash != nullptr && dflash->trace_n_max > 0 && + n_seqs == 1 && n_seq_tokens > 1 && n_seq_tokens <= dflash->trace_n_max && + cparams.fused_gdn_ch && + (size_t) il < dflash->trace_s.size() && dflash->trace_s[il] != nullptr; + + if (dflash_trace) { + // conv window after token t = rows [t+1 .. t+conv_kernel_size-1] of conv_input (the same + // view pattern as last_conv_states above). One clean in-bounds sub-view copy per token: + // an earlier single-overlapping-3D-view optimization sat exactly on the buffer boundary + // (data_size + offset == ggml_nbytes(conv_input)) and aborted in ggml_view_3d whenever + // conv_input had a different row count than (k-1)+n_seq_tokens (observed on sm_120 and in + // the server's prompt-chunk path). The per-token windows are always strictly in bounds. + const int64_t conv_sz = (conv_kernel_size - 1) * conv_channels; // == hparams.n_embd_r() + for (int64_t t = 0; t < n_seq_tokens; ++t) { + ggml_tensor * win = ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, 1, + conv_input->nb[1], conv_input->nb[2], + (t + 1) * ggml_element_size(conv_input)); + ggml_tensor * dst_t = ggml_view_1d(ctx0, dflash->trace_r[il], conv_sz, + (size_t) t * conv_sz * ggml_element_size(dflash->trace_r[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, win, dst_t)); + } + } + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); @@ -350,6 +377,11 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); + if (dflash_trace) { + // per-token SSM state trace target, consumed by the fused GDN op (see ggml_gated_delta_net_trace) + gdn_trace = ggml_view_1d(ctx0, dflash->trace_s[il], hparams.n_embd_s() * n_seq_tokens, 0); + } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); ggml_tensor * output = attn_out.first; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c835dd8a44c2..9d216109fe5c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -92,6 +92,12 @@ struct server_slot { server_prompt_checkpoint spec_ckpt; common_speculative_ptr spec; + // DFlash recurrent rewind: when set, the target's per-token recurrent states are traced during + // the verify decode so a partial acceptance promotes the state at the accepted position on-device + // instead of the ~50 MiB host checkpoint round-trip + re-decode (see llama_dflash_promote_state) + bool spec_state_trace = false; + llama_pos spec_pos0 = 0; // base position of the current verify batch (rewind target = pos0 + accepted) + // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 std::unique_ptr task; @@ -363,7 +369,9 @@ struct server_slot { spec_draft.clear(); } - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + // the host checkpoint is only needed for the restore-based rewind; with the + // on-device state trace a partial acceptance promotes the state instead (no host copy) + if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL && !spec_state_trace) { const auto n_tokens = prompt.tokens.size(); spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens); @@ -396,6 +404,7 @@ struct server_slot { } auto pos0 = prompt.tokens.pos_next(); + spec_pos0 = pos0; // base position of the verify batch (for the DFlash on-device rewind) common_batch_add(batch, sampled, pos0++, { this->id }, true); for (auto token : spec_draft) { @@ -929,6 +938,20 @@ struct server_context_impl { if (slot.spec) { SLT_INF(slot, "%s", "speculative decoding context initialized\n"); + + // DFlash on a hybrid/recurrent target: enable the recurrent state trace so a + // partial acceptance promotes the accepted-position state on-device instead of + // the host checkpoint round-trip. Only for the FULL-seq-rm regime (hybrid), and + // only at n_parallel == 1 (the per-context feature extraction limitation above). + const bool trace_ok = params_base.speculative.dflash && + ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL && + params_base.n_parallel == 1; + if (trace_ok && + !(getenv("LLAMA_SPEC_NO_TRACE") && std::string(getenv("LLAMA_SPEC_NO_TRACE")) != "0")) { + llama_set_dflash_state_trace(slot.ctx, params_base.speculative.n_max + 1); + slot.spec_state_trace = true; + SLT_INF(slot, "%s", "DFlash recurrent state trace enabled (on-device rewind)\n"); + } } } @@ -2987,9 +3010,10 @@ struct server_context_impl { { const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - // only save the sampler sampler state if we use checkpoints + // only save the sampler state if we use the checkpoint-restore rewind; the + // on-device trace rewind commits forward and never rolls the sampler back common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt && !slot.spec_state_trace) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } @@ -3003,7 +3027,22 @@ struct server_context_impl { // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (slot.spec_state_trace) { + // DFlash recurrent rewind: promote the traced state at the accepted + // position instead of restoring a host checkpoint. The verify batch was + // [sampled @ pos0, draft0 @ pos0+1, ...]; accepting `acc` drafts means the + // state after batch token `acc` (trace slot `acc`, ending at pos0 + acc) + // is correct. Then fall through to the normal commit path below - its + // llama_memory_seq_rm(pos) truncates the rejected attention KV tail and + // now succeeds because the recurrent cell pos was rewound. + const int32_t acc = (int32_t) accepted.size() - 1; + const llama_pos pos_last = slot.spec_pos0 + acc; + + if (!llama_dflash_promote_state(slot.ctx, acc, pos_last, slot.id)) { + GGML_ABORT("%s: DFlash state promote failed (idx=%d)\n", __func__, acc); + } + // no checkpoint restore, no `continue` - fall through to commit + } else if (use_ckpt) { // partial acceptance is not supported by the context -> truncate the draft and restore the state slot.spec_draft = std::move(accepted); From e8bfef237bd048a4b413a17107b6ea42f21fa5db Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 3/4] gdn: portable chunk-parallel Gated-DeltaNet verify path (opt-in) Decompose the GDN recurrence into a pure ggml-op graph (cumsum, exp, mul_mat, tri, solve_tri, diag, concat) so the verify can run on backends that lack a fused GDN kernel (WebGPU, Metal, Vulkan). Multi-chunk tiling keeps exp(cumsum(g)) in fp32 range; handles both the vector (KDA) and per-head scalar gate, and GQA. Validated bitwise against ggml_gated_delta_net on CPU and CUDA (tests/test-gdn-chunked). Opt-in via LLAMA_GDN_CHUNKED; the default path is unchanged. This is for portability: on CUDA the fused kernel is faster and the GDN scan is not the verify bottleneck. --- DESIGN.md | 320 +++++++++++++++++ FINDINGS.md | 117 +++++++ GDN_CHUNKED_BRINGUP.md | 170 +++++++++ ggml/src/ggml-backend.cpp | 58 +++- ggml/src/ggml-cuda/gated_delta_net.cu | 10 + ggml/src/ggml-cuda/gated_delta_net.cuh | 14 + ggml/src/ggml-cuda/gated_delta_net_chunked.cu | 325 ++++++++++++++++++ ggml/src/ggml-cuda/gdn_chunked_oracle.py | 67 ++++ src/models/delta-net-base.cpp | 108 ++++++ src/models/models.h | 21 ++ tests/test-gdn-chunked.cpp | 235 +++++++++++++ 11 files changed, 1444 insertions(+), 1 deletion(-) create mode 100644 DESIGN.md create mode 100644 FINDINGS.md create mode 100644 GDN_CHUNKED_BRINGUP.md create mode 100644 ggml/src/ggml-cuda/gated_delta_net_chunked.cu create mode 100644 ggml/src/ggml-cuda/gdn_chunked_oracle.py create mode 100644 tests/test-gdn-chunked.cpp diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 000000000000..5061ac2c8a6c --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,320 @@ +# Chunk-parallel Gated-DeltaNet (GDN) for DFlash speculative VERIFY + +Status: DESIGN + reviewable CUDA skeleton (no GPU available; not yet compiled/validated). +Scope: the **verify** path only — single sequence (`n_seqs == 1`), a block of `N` tokens +(N = draft-max + 1, up to ~16, design supports up to 32). Prefill and single-token decode keep +using the existing sequential kernel. + +Files: +- existing sequential kernel: `ggml/src/ggml-cuda/gated_delta_net.cu` +- CPU reference (exact math): `ggml/src/ggml-cpu/ops.cpp` + (`ggml_compute_forward_gated_delta_net_one_chunk`) +- new skeleton: `ggml/src/ggml-cuda/gated_delta_net_chunked.cu` +- dispatch hook: `ggml/src/ggml-cuda/ggml-cuda.cu` (`GGML_OP_GATED_DELTA_NET`, ~L2934) + +--- + +## 1. The exact recurrence (from the CPU reference) + +Per head, per sequence. Let `S_k = S_v = D` (the existing kernel assumes square state; head dim +`D in {16,32,64,128}`). The recurrent state is a `D x D` matrix `S` with `S[i][j]`, where `i` indexes +the **key** dimension and `j` indexes the **value** dimension. + +> Storage detail: both the CPU ref and the CUDA kernel store `S` **transposed** as +> `M[j][i] = S[i][j]` so that "row j of M" (contiguous) is "column j of S". The math below is in the +> mathematical `S[i][j]` convention; the kernel maps it to the transposed layout. + +Inputs at token `t` (all f32): `q_t, k_t in R^D` (key dim), `v_t in R^D` (value dim), +`beta_t in R` (scalar), gate `g_t`: +- **scalar gate** (non-KDA): `g_t in R`, decay `a_t = exp(g_t)` applied to the whole state. +- **KDA / vector gate**: `g_t in R^D` indexed by the **key** dim `i`, decay `a_t[i] = exp(g_t[i])` + applied per key-row. + +The update (matching `ops.cpp` lines 10522-10552 exactly): + +``` +1. decay: S[i][j] <- a_t[i] * S[i][j] (a_t[i] = exp(g_t[i]); scalar: a_t[i]=exp(g_t) for all i) +2. kv: u_t[j] = sum_i S[i][j] * k_t[i] = (S^T k_t)[j] +3. delta: d_t[j] = (v_t[j] - u_t[j]) * beta_t +4. update: S[i][j] <- S[i][j] + k_t[i] * d_t[j] (rank-1: S += k_t d_t^T) +5. output: o_t[j] = scale * sum_i S[i][j] * q_t[i] = scale * (S^T q_t)[j] (scale = 1/sqrt(D)) +``` + +Substituting (3) into (4), with `S_{t-1}` the post-(prev-token) state and `S_t` the post-update +state, this is the **gated delta rule**: + +``` +S_t = diag(a_t) S_{t-1} + k_t ( beta_t (v_t - (diag(a_t) S_{t-1})^T k_t) )^T +o_t = scale * S_t^T q_t +``` + +Define the **effective value** (a.k.a. "new value" / pseudo-value in DeltaNet) so the update becomes +a plain (non-recursive-in-S) rank-1 add: + +``` +w_t = beta_t * k_t (D, key dim) <- write key +u_t = (diag(a_t) S_{t-1})^T k_t (D, value dim) <- what's already stored +d_t = beta_t * v_t - beta_t * u_t (D, value dim) <- delta value +S_t = diag(a_t) S_{t-1} + k_t d_t^T +``` + +`o_t` reads `S_t` (post-update, includes the current token — note step 5 runs *after* step 4). + +This is the per-token recurrence the existing CUDA kernel runs `N` times sequentially in the verify +block. Cost of verify ~ `N` sequential steps, each `O(D^2)` work. We want to cut the **sequential +depth** from `N` to `N/C`. + +--- + +## 2. Chunked derivation (chunked delta-rule / chunked linear attention) + +Reference: FLA `chunk_delta_rule` / `chunk_gated_delta_rule`; Yang et al. "Parallelizing Linear +Transformers with the Delta Rule over Sequence Length" (DeltaNet) and "Gated DeltaNet". + +Split the `N` tokens into chunks of size `C` (e.g. `C = 16`, so a 16-token verify block is **one +chunk**; a 32-token block is two). Index tokens within a chunk by `r = 0..C-1` (global token +`t = chunk_base + r`). Let `S_in` be the state entering the chunk (the carry). + +### 2.1 Cumulative gate products inside the chunk + +For the **KDA / vector gate**, the per-key-dim decay is multiplicative, so define inclusive cumulative +products along the chunk (per key dim `i`): + +``` +A_r[i] = prod_{s=0..r} a_s[i] (inclusive, decay applied up to and including token r) +``` + +with `A_{-1}[i] = 1`. The decay from "just after token s applied" to "the chunk boundary after token +C-1" is `A_{C-1}[i] / A_s[i]`. For the **scalar gate**, `a_s` is a scalar and `A_r` collapses to a +scalar per token — same formulas, broadcast over `i`. + +To keep the rank-1 writes commutable, **pre-scale** each token's write key into a common reference +frame (the chunk start). Define: + +``` +k~_r[i] = k_r[i] / A_r[i] (deflated write key — "undo" the decay it will accumulate) +q~_r[i] = q_r[i] * A_r[i] (inflated query — apply decay the carry would have gotten) +``` + +Intuition: a rank-1 contribution `k_s d_s^T` written at token `s` gets multiplicatively decayed by +`A_{r}[i]/A_s[i]` (key dim) by the time we read at token `r >= s`. Folding `1/A_s` into the key and +`A_r` into the query realizes that decay through a single elementwise scale per token, so the +intra-chunk interactions become plain matmuls. (This is exactly the FLA "secondary chunking" trick; +do the cumprod in log space — see section 5.) + +### 2.2 Intra-chunk parallel form + +Stack the chunk into matrices (rows = tokens within the chunk): +`K, Q, V in R^{C x D}` (rows `k_r, q_r, v_r`), `K~, Q~` the deflated/inflated versions, `beta in R^C`. + +**(a) Carry read (contribution of `S_in` to every token's `u` and `o`):** + +The "already stored" value seen by token `r` from the *incoming* state is +`(diag(A_r) S_in)^T k_r = S_in^T (A_r (.) k_r)`. So with `Kbar_r = A_r (.) k_r` (the **inflated** read +key) stacked into `Kbar in R^{C x D}`: + +``` +U_carry = Kbar @ S_in in R^{C x D} (each row = u_r^carry, value dim) +``` + +The carry contribution to the output uses the *post-update* state, but since `S_in` is constant +within the chunk its output contribution is `O_carry = scale * (Qbar @ S_in)` with +`Qbar_r = A_r (.) q_r`. + +**(b) Intra-chunk token-token interactions (the delta-rule coupling):** + +Within the chunk, token `r`'s delta `d_r` depends on the writes of all earlier tokens `s < r` (and on +the carry). Build the **strictly-lower-triangular** decayed attention matrix between deflated keys: + +``` +T[r][s] = beta_r * ( k~_r . k~_s ) for s < r, else 0 in R^{C x C} +``` + +The delta-rule "un-mixing" is the classic `(I + tril(T,-1))^{-1}` solve (forward substitution over the +chunk, `C` sequential micro-steps but only on a `C x C` system, cheap and in shared memory). Let + +``` +W = (I + strict_tril(T))^{-1} in R^{C x C} +Dmat = W @ ( beta (.) (V - U_carry) ) in R^{C x D} (rows = d_r, the resolved deltas) +``` + +(`beta (.) V` is row-scaling `V` by `beta_r`; `U_carry` from (a).) `Dmat` rows are exactly the +per-token delta values `d_r` consistent with the sequential recurrence — now computed by two matmuls + +one small triangular solve instead of `C` rank-1 steps. + +**(c) Per-token output:** + +``` +O = O_carry + scale * tril( Q~ @ K~^T ) @ Dmat in R^{C x D} +``` +The `tril(Q~ K~^T)` term sums the intra-chunk writes that token `r` should see. The output reads the +**post-update** state, so the current token's own write (`s == r`) must be included — use the +lower-triangle **including** the diagonal for this output term, while the `T` solve in (b) stays +**strictly** lower. Mapping the exact diagonal handling is the one subtlety to nail against the +reference (see section 6 validation). + +### 2.3 Inter-chunk state carry (the only sequential part) + +After the chunk, the new boundary state: + +``` +S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat + = diag(A_{C-1}) S_in + sum_r ((A_{C-1}/A_r) (.) k_r) d_r^T +``` +with `Kw_r = (A_{C-1}/A_r) (.) k_r` stacked into `Kw in R^{C x D}` (each write key carried forward to +the chunk end). This is one `D x D` update per chunk. + +**Sequential depth = number of chunks = ceil(N/C).** With `N <= C` (verify block <= 16 and `C = 16`) +the whole verify is **a single chunk**: zero inter-chunk recurrence, everything is matmuls + one +`C x C` triangular solve. That is the win. + +--- + +## 3. CUDA kernel structure (`gated_delta_net_chunked_cuda`) + +One CUDA **block per (head, sequence)** — for verify `sequence` is fixed (n_seqs==1), so grid is +`(H, 1, 1)`. Each block owns the chunk's `C x D` tiles and the `D x D` carry state in shared memory. + +Tiling (for the verify regime: `C <= 32`, `D in {16,32,64,128}`): +- Shared mem holds: `S` (`D x D` f32), `K,Q,V,K~,Q~,Kbar,Qbar` chunk tiles (`C x D` each), `T`/`W` + (`C x C`), `Dmat` (`C x D`), `A` cumprods (`C x D` for KDA, `C` for scalar). For `D=128, C=16` that + is `128*128*4 = 64KB` for `S` alone — at the edge of the 48-96KB smem budget, so for `D=128` either + keep `S` in registers (sharded across the warp like the existing kernel) or cap `C` smaller / use + the host-decomposition fallback (3.2). For `D <= 64` everything fits comfortably. +- Threads: a 2D thread block, `D` lanes x `num_warps` (mirror the existing + `block_dims(min(warp,D), num_warps)`). Matmuls are done cooperatively; the `C x C` triangular solve + is done by a single warp (C <= 32 fits one warp) via forward substitution. + +Phases inside the kernel (single chunk; the multi-chunk loop wraps phases 2-6): +1. **Load + gate cumprod.** Load `g`, compute `a_r = exp(g_r)`, inclusive cumprod `A_r` along the + chunk **in f32 / log space** (Hillis-Steele scan across `C`). Build `k~,q~,Kbar,Qbar,Kw`. +2. **U_carry = Kbar @ S_in**, **O_carry = scale . (Qbar @ S_in)** — two `C x D . D x D` matmuls. +3. **T = strict_tril(beta (.) (K~ K~^T))** — a `C x D . D x C` matmul, mask to strict lower. +4. **Solve `W (I+T)`:** forward-substitution to get `Dmat = (I+T)^{-1} (beta (.) (V - U_carry))` + (C sequential micro-steps on the small `C x C` system, one warp). +5. **Output:** `O = O_carry + scale . tril(Q~ K~^T) @ Dmat`; write `O` rows to `attn_data` (same + `[S_v.H]`-strided layout as the sequential kernel) and, if `trace != nullptr`, materialize the + per-token state trace (see 3.3). +6. **Carry:** `S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat`; write back transposed `M[j][i]`. + +### 3.1 Dispatch hook + +In `ggml/src/ggml-cuda/ggml-cuda.cu`, `GGML_OP_GATED_DELTA_NET` (~L2934) currently calls +`ggml_cuda_op_gated_delta_net`. Add inside that op (in `gated_delta_net.cu`'s +`ggml_cuda_op_gated_delta_net`) a guarded fast path: + +``` +if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX && S_v <= GDN_CHUNK_DMAX) + launch_gated_delta_net_chunked(...); // new path (verify block) +else + launch_gated_delta_net(...); // existing sequential path (prefill / single decode) +``` + +`GDN_CHUNK_MIN` ~ 2 (no point for a single token), `GDN_CHUNK_MAX` ~ 32, `GDN_CHUNK_DMAX` initially +64 (raise to 128 once the smem/register strategy for `D=128` is validated). The trace output and the +final-state writeback use the **same** dst layout (`[attn_scores | new_states]`) and the same +transposed state convention, so nothing downstream changes. + +### 3.2 Host-decomposition fallback + +If a single monolithic kernel is too much for a first cut, the same math maps onto existing ggml CUDA +ops as a host-side graph (per head, single chunk): `ggml_mul_mat` for `Kbar@S`, `Q~K~^T`, `Kw^T@Dmat`; +elementwise muls for gating; a tiny custom kernel only for the `C x C` triangular solve. Slower than +the fused kernel (extra global-memory round trips) but a correctness oracle and a quick path to a +working verify. The skeleton notes this decomposition. + +### 3.3 Trace compatibility (DFlash rewind) + +DFlash needs the **per-token** state `S_t` for partial-acceptance rewind (`src[6]` trace). The chunked +kernel does not naturally produce per-token `S_t` (it jumps chunk->chunk). Two options: +- **(preferred)** After computing `Dmat`, materialize `S_r = diag(A_r) S_in + Kw(->r)^T @ Dmat[0..r]` + for each `r` via a small cumulative pass (the prefix of the chunk update) and write the trace rows. + Costs `C` light steps but they're independent across `r` (can be a parallel segmented scan). +- **(fallback)** When `trace != nullptr`, route to the sequential kernel (it already writes the trace + near-free). Verify still benefits whenever the harness doesn't request a trace; partial-accept paths + pay the sequential cost. Start here, then implement the prefix-trace. + +--- + +## 4. Expected speedup + +- Sequential kernel verify cost ~ `N` sequential GDN steps (latency-bound: each step's rank-1 update + + two reductions depends on the previous). For the 24 GDN layers this is the dominant verify cost + and is why single-stream speedup caps at ~1.5-1.7x. +- Chunked: sequential **depth** drops to `ceil(N/C)`. For `N <= 16, C = 16` -> **depth 1**. The + remaining work is matmuls (`C x D x D`) + one `C x C` solve, which are throughput-bound and overlap + well; on a modern GPU the `C x D x D` matmuls for `C,D <= 128` are far below peak and hide behind + issue latency. +- Net: verify GDN latency goes from `O(N)` serial to `O(N/C)` serial + parallel intra-chunk. This is + the same structural change that lets SGLang reach ~3x — we expect the single-stream cap to move from + ~1.5-1.7x toward the ~2.5-3x regime, gated by how much of end-to-end time is GDN verify vs the + full-attention layers and sampling. +- The arithmetic *work* slightly increases (the `(I+T)^{-1}` solve + extra matmuls), but it converts + serial-dependent work into parallel work, which is the right trade for a latency-bound verify. + +--- + +## 5. Numerical-stability concerns + +- **fp32 accumulation everywhere** — mirrors the top-level CLAUDE.md lesson (fp16 mean-pool overflowed + to +/-inf on attention-sink channels and poisoned every downstream lstsq). All cumprods, matmul + accumulators, the `C x C` solve, and the state must accumulate in **f32** (the sequential kernel is + already all-f32; keep parity). Never down-cast the state or the gate products to fp16. +- **Cumulative gate product underflow/overflow.** `A_r[i] = prod a_s[i]` with `a_s = exp(g_s)`. Over a + chunk of 16 tokens with strongly negative `g` (heavy decay), `A_{C-1}` can underflow and the + deflated key `k~_r = k_r / A_r` can blow up — the classic instability the FLA chunked kernels guard. + Mitigations: (a) keep `C` modest (16) so the product spans few tokens; (b) work in **log space** for + the cumulative gate (`L_r[i] = sum_{s<=r} g_s[i]`, then `A_r = exp(L_r)`, and form ratios as + `exp(L_r - L_s)` rather than dividing two exponentials) — this is the numerically safe way to get + `A_r/A_s` and `A_{C-1}/A_r` without ever materializing a tiny denominator; (c) f32 throughout. +- **Triangular solve conditioning.** `(I + strict_tril(T))` is unit-lower-triangular, so always + invertible and forward-substitution is stable; just accumulate in f32. +- **Diagonal/self-token bookkeeping** is the main *correctness* (not stability) risk — the output reads + the **post-update** state, so the current token's own write must be included. Validate against the + reference (section 6) rather than reasoning it through once. + +--- + +## 6. Step-by-step plan to production-correct + validation + +1. **CPU oracle first.** Implement the chunked math as a second CPU function next to + `ggml_compute_forward_gated_delta_net_one_chunk` (or a standalone test harness) and assert it + reproduces the sequential CPU reference (f32, `|delta| < 1e-4` per element) on random + `q,k,v,g,beta,S_in` for both scalar and KDA gates, for `D in {16,32,64,128}` and + `C in {1,2,4,8,16}`. This nails the diagonal/self-token and the gate-ratio direction *before* any + CUDA. +2. **Single-chunk CUDA kernel** for `n_seqs==1`, `N==C`, `D <= 64`. Compare its `attn` output and + `S_out` against the sequential CUDA kernel on the same inputs (host-side max-abs diff `< 1e-3` f32). +3. **Multi-chunk loop** (`N` = a few chunks); re-check the inter-chunk carry matches sequential. +4. **Trace path** (3.3 preferred): verify the per-token trace rows equal the sequential kernel's trace + element-for-element (this is what DFlash rewind reads — must match exactly). +5. **D=128 strategy**: pick register-sharded `S` (like the existing kernel) or smem; re-validate. +6. **Wire dispatch** behind the `n_seqs==1 && N in [MIN,MAX] && D <= DMAX` guard; keep the sequential + path as the default so prefill/decode are untouched. Add an env/define kill-switch + (`GGML_CUDA_GDN_CHUNKED=0`) to fall back at runtime during bring-up. +7. **End-to-end**: run the Qwen3.5-4B DFlash verify on a real prompt, confirm accepted-token sequences + are identical to the sequential-verify build (greedy + fixed seed), then measure tok/s. Validation = + *identical accepted tokens* + improved verify latency. + +Validation harness lives alongside the existing GDN tests (search `test-backend-ops` / +`gated_delta_net` test cases); add a chunked-vs-sequential equivalence case there. + +--- +## CORRECTION (validated by gdn_chunked_oracle.py, bitwise vs sequential, max err ~1e-13) + +The pairwise inter-token decay (s -> r) is **A_r / A_s**, NOT 1/(A_r·A_s). So the deflation must be +ASYMMETRIC: the LATER token r carries A_r (Kbar/Qbar = A⊙k, A⊙q), the EARLIER token s carries 1/A_s +(Ktil = k/A). The dot Kbar_r·Ktil_s = sum_i k_r k_s · A_r/A_s (bounded ≤1 for s compatible. + +2. **DFlash-specific nodes are capturable.** `ggml_set_rows(cross_dev,...)` (`dflash.cpp` L90), the + per-token conv/state trace `ggml_cpy` nodes (`qwen35.cpp` L301-317, L380-396), `ggml_argmax` + (`dflash.cpp` L227), and the top-k `argsort` verify path all map to normal CUDA ops with **no host + stream sync** (checked `set-rows.cu` / `argmax.cu` / `argsort.cu` — only `cudaMemcpyAsync` D2D, which + is capturable). => none disables capture. + +3. **Stable destinations / offsets.** `cross_dev`, `trace_s[il]`, `trace_r[il]` are persistent tensors + (allocated once in `dflash_cross_ctx` / `dflash_trace_buf`), so the trace/set_rows dst ptrs are + constant. The recurrent state write offset `kv_head * n_embd_s` (`qwen35.cpp` L395) is constant for a + single sequence (`get_head()` fixed for seq 0). => node props stable round-to-round. + +4. **The graph key is stable.** The CUDA graph is keyed by `cgraph->nodes[0]`. The verify ubatch is a + **constant** `block_size` tokens (the drafter always emits `block_size-1` drafts: + `speculative.cpp result.assign(block_size-1,0)`; `speculative-simple.cpp` L457-469). So + `llm_graph_result::can_reuse` holds (constant `n_tokens`/`n_outputs`/`cross`/samplers; recurrent + `head`/`rs_z` constant), and `llm_graph_result::reset()` reuses `buf_compute_meta` in place (same + `.data()` => tensors placement-allocated at the same offsets). => `nodes[0]` is the same pointer + across rounds, even when a rebuild happens. + +5. **Double-buffering is a non-issue here.** `cur_copy` only flips in `ggml_backend_sched_alloc_graph` + (skipped on the reuse path), and a single-GPU DFlash target runs `pipeline_parallel=false => + n_copies=1`, so input-copy pointers don't alternate. + +Conclusion: on Ampere+ the verify graph already captures and stays warm (the `cuda_graph` object keyed +by the stable `nodes[0]` keeps `node_props` across rounds; eviction is 10 s, rounds are ms apart). The +warmup does NOT reset for the verify on an identical graph. + +## Root cause of the residual per-round CPU cost (and the whole-model -6%) + +`ggml_graph_view` zeroes the uid; the tail of `ggml_backend_sched_split_graph` then assigns a fresh +monotonic uid per split every call. The CUDA backend's fast-path +(`if (cgraph->uid != 0 && cgraph->uid == graph->uid) return false;`) can therefore only skip the +property walk when the *higher-level* graph reuse keeps `split_graph` from running at all. Any reuse +miss re-runs `split_graph`, bumps the uid, and forces the full walk. On the ~1800-node whole-model graph +that is the measured ~-6%; on the hundreds-of-nodes verify graph it is smaller but non-zero. + +## The existing fix, and what this change adds + +Existing at HEAD (`ggml/src/ggml-backend.cpp`): +- `struct ggml_backend_sched_split` carries `prev_uid` / `prev_sig`. +- The uid loop computes a per-slot topology signature; if it matches the previous round's, it reuses + `prev_uid` instead of minting a fresh one. `GGML_SCHED_STABLE_UID=0` opts out (on by default). +- Grown `splits` slots are zeroed after `realloc` so `prev_uid`/`prev_sig` start clean. + +This change (hardening only): +- The signature was `backend_id + n_nodes + nodes[0] + nodes[n-1]` (endpoints only). A "same count + + same endpoints but different middle" collision would let the backend reuse a **stale captured graph** + (a silent correctness bug). Strengthened it to also fold in a **strided sample of up to ~16 interior + node pointers**, making such a collision effectively impossible while staying O(1)-ish per split. +- Updated the in-code comment to match. + +Why safe: the uid is a pure optimization hint. A matching uid only skips a walk that would have found no +change anyway (signature matched on stable placement-allocated pointers); any mismatch falls back to the +full walk + recapture. The fast-path's `node_props.size() == n_nodes` assert holds because `n_nodes` is +in the signature. + +## Files changed by this investigation + +- `ggml/src/ggml-backend.cpp` — `ggml_backend_sched_split_graph()`: strengthen the per-split topology + signature (strided interior-node sampling); comment fix. No struct/ABI change beyond what HEAD already + had; no CUDA file touched. + +## What to validate on GPU (remote Ampere+ box; V100 needs `GGML_CUDA_GRAPHS_VOLTA`) + +1. Build with CUDA. On V100 also pass `GGML_CUDA_GRAPHS_VOLTA=` (n >= verify node count, or 1). +2. Run `speculative-simple --spec dflash` for draft-max in {8, 12, 16}, comparing tokens/sec with vs + without the uid stabilization (`GGML_SCHED_STABLE_UID=0` disables). Expect the fix to remove the + per-round split walk -> higher t/s, and to make the larger draft blocks (12/16) viable toward the + SGLang accept_len ~6.6 regime (combined with the DESIGN.md chunked GDN kernel). +3. Debug build of `ggml-cuda.cu` (`-DCMAKE_BUILD_TYPE=Debug`): confirm + `GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", ...)` fires every steady-state verify round. +4. Correctness: greedy verify output must be **token-identical** with and without + `GGML_SCHED_STABLE_UID` (a divergence would mean a signature collision — not expected after the + hardening). +5. Cross-check whole-model decode (no spec): the same stabilization should turn the prior ~-6% CUDA-graph + regression neutral/positive. + +GPU validation command (example): + +``` +GGML_SCHED_STABLE_UID=0 ./build/bin/llama-speculative-simple \ + -m target.gguf -md draft.gguf --spec dflash --draft-max 16 --draft-min 1 -n 256 -p "" +GGML_SCHED_STABLE_UID=1 ./build/bin/llama-speculative-simple \ + -m target.gguf -md draft.gguf --spec dflash --draft-max 16 --draft-min 1 -n 256 -p "" +``` diff --git a/GDN_CHUNKED_BRINGUP.md b/GDN_CHUNKED_BRINGUP.md new file mode 100644 index 000000000000..1a4f63afea1b --- /dev/null +++ b/GDN_CHUNKED_BRINGUP.md @@ -0,0 +1,170 @@ +# Chunked GDN verify — bring-up recipe (math validated, ggml-op decomposition) + +The chunked math is bitwise-correct (gdn_chunked_oracle.py vs sequential, err ~1e-13). It needs NO +hand-written CUDA kernel: ggml has cumsum + tri + solve_tri + mul_mat on BOTH CPU and CUDA. Build the +verify GDN (n_seqs==1, N=draft_max+1 ≤ ~16, single chunk) as a ggml subgraph; validate on the CPU +backend locally (build/bin/libggml*.dylib already built — write a standalone test linking it), then +it runs on CUDA for free. + +## Inputs (from ggml_gated_delta_net): q,k,v [S_v,H,N,1]; g [S_v,H,N,1] (kda); beta [1,H,N,1]; +## state S0 [S_v,S_v,H,1]. Output [S_v*H, N + S_v] (cols 0..N-1 = attn, cols N..N+S_v-1 = new state). + +## Op recipe (per the VALIDATED oracle; A_r[i]=prod_{s<=r} a_s, a=exp(g)): +1. Permute q,k,v,g to per-head token-matrices Xp [S_v, N, H] (ne0=dim i, ne1=token r, ne2=head). +2. A: put tokens on ne0 -> g2 [N, S_v, H]; L = ggml_cumsum(g2) over ne0 (tokens); A = exp(L); + permute A back to [S_v, N, H]. (cumsum is ne0-only, hence the shuffle.) +3. Kbar = A⊙kp ; Qbar = A⊙qp ; Ktil = kp ⊙ exp(-L_perm) (= kp/A, but form via exp(-L) to stay fp32-safe). +4. U_carry = mul_mat(S0[i,j,H], Kbar[i,r,H]) -> [j, r, H] (contracts i). O_carry = scale·mul_mat(S0, Qbar). +5. KK = mul_mat(Ktil[i,s,H], Kbar[i,r,H]) -> [s, r, H]; want T[r,s]=beta_r·KK[s,r]. + Tfull = beta(broadcast over s) ⊙ transpose(KK to [r,s,H]); T = ggml_tri(Tfull, LOWER strict). +6. rhs = beta ⊙ (vp_as[r,j] - U_carry[j,r]^T) -> shape [j? r?]; keep as [N, S_v, H] (r on ne0) for solve. + Dmat = ggml_solve_tri(A=I+T [N,N,H] unit-lower, B=rhs, left=true, lower=true, uni=true) -> [N, S_v, H]. +7. QK = ggml_tri(transpose(mul_mat(Ktil,Qbar)) to [r,s,H], LOWER_DIAG incl diagonal); + O = O_carry + scale·mul_mat(QK[s,r,H]?, Dmat[s,j,H]) -> [r,j,H] (contracts s). +8. S_out[i,j,H] = (A_end[i] ⊙ S0[i,j]) + mul_mat(Kw[i? ], Dmat) ; Kw_r = (A_end/A_r)⊙k_r. +9. Reassemble into the [S_v*H, N+S_v] output layout (attn cols = O reshaped, state cols = S_out). + +## Validation (do BEFORE wiring into the model): +- Standalone CPU test: random q,k,v,g,beta,S0; run ggml_gated_delta_net (sequential, reference) and + build_gated_delta_net_chunked; ggml_backend_cpu; compare max|diff| < 1e-4. Iterate layout bugs here + (mul_mat is a^T b contracting ne0; transposes/permutes are where bugs hide). +- Then bitwise-vs-sequential for the TRACE rows too (DFlash rewind needs per-token state; the chunked + path gives only the final S_out + per-token O, NOT per-token state -> for the rewind we still need + per-token states. EITHER also emit per-token states (S after each token = O_carry-style partial), OR + keep trace on the sequential path. RESOLVE THIS before shipping: the rewind/promote depends on the + per-token trace; chunked must reproduce it or the verify can't use trace+promote.) + +## Then: wire as the verify fast path (n_seqs==1, N small) behind a flag; bench draft-max 8/12/16 on +## Blackwell/H100; expect verify cost ~flat in N -> larger blocks affordable -> accept_len ~6 -> ~2.5-3x. + +## OPEN RISK (important): the DFlash rewind (trace+promote) needs per-TOKEN states. The chunked form +## naturally yields only the chunk-final state. Per-token states within the chunk can be recovered +## (S_t = decay(S0,t) + intra-chunk updates up to t) but that's extra work; OR run chunked for speed +## and the sequential trace only when a partial-accept rewind is actually needed. Must be designed. + +--- +## STATUS: chunked GDN ggml-graph VALIDATED (commit 9c1f082b8). Integration plan below. + +tests/test-gdn-chunked.cpp: chunked-vs-sequential ALL PASS (N=1..16, S_v=64/128, fp32). The graph is +portable (cumsum/tri/solve_tri/mul_mat on CPU+CUDA; the path that also extends Metal/Vulkan/WebGPU). + +### Wiring (next): +1. Lift build_chunked() from the test into a reusable builder, e.g. build_gated_delta_net_chunked() + in src/models/ (or a ggml helper), returning the SAME [S_v*H, N+S_v] output as ggml_gated_delta_net. +2. In src/models/qwen35.cpp build_layer_attn_linear: when (n_seqs==1 && n_seq_tokens>1 && verify), + call the chunked builder instead of ggml_gated_delta_net. Keep the fused sequential op for prefill + and single-token decode (chunked wins only for a multi-token block). +3. Gate behind a flag (e.g. cparams.gdn_chunked or env) so it's opt-in until GPU-validated. + +### TRACE / rewind (the one real design decision): +DFlash rewind (trace+promote) needs the per-TOKEN state S_t for the accepted position. The chunked +path yields only S_out (after all N). Resolution: keep the chunked path for the fast verify FORWARD +(attn + S_out); on a PARTIAL accept (acc accept_len ~6 -> + toward 2.5-3x (the SGLang regime). This is the payoff. + +### NUMERICAL: fp32 err grows mildly with N,D (1e-5 at N16/D128). On GPU keep fp32 accumulation in the +matmuls/solve. If a stronger-decay model underflows k/A, switch to the log-space ratio form (DESIGN.md). + +--- +## REASSEMBLY into the [S_v*H, N+S_v] combined output (the drop-in detail, derived + checked) +ggml_gated_delta_net's data = [attn region (S_v*H*N) | state region (S_v*S_v*H)], flat: + attn[h,t,j] at (t*H + h)*S_v + j -> column-major [S_v*H, N], row=h*S_v+j, col=t + state[h,i,j] at attn_elems + h*S_v*S_v + j*S_v + i -> EXACTLY a [S_v(i),S_v(j),H] contiguous tensor +So the chunked builder's cs ([i,j,H], already contiguous) IS the state region as-is. For attn: + O is [t, j, H] -> permute(2,0,1,3) -> [j, H, t] -> cont -> reshape [S_v*H, N]. +Combined = reshape( ggml_concat( reshape(O_attn,1D[S_v*H*N]), reshape(cs,1D[S_v*S_v*H]), dim0 ), + [S_v*H, N+S_v] ). Drop-in for ggml_gated_delta_net's result. + +## GQA: q/k have num_k_heads (ssm_n_group), v/g/beta/state have num_v_heads (ssm_dt_rank). In the +## builder: ggml_repeat q/k from Hk to H (interleaved h%Hk) BEFORE the per-head ops. VALIDATED. + +## DONE this session: chunked GDN ggml builder VALIDATED on CPU bitwise vs sequential, incl GQA +## (tests/test-gdn-chunked.cpp, ALL PASS N=1..16, S_v=64/128, H_v=4, H_k=1/2). It is portable +## (CUDA/Metal/Vulkan/WebGPU via the op kernels). REMAINING: (1) lift build_chunked into +## delta-net-base.cpp + reassembly above, gate by env, fall back to sequential when gdn_trace!=null; +## (2) GPU build + bitwise vs sequential + draft-max 8/12/16 speed (verify ~flat in N -> ~3x); +## (3) trace/rewind: compute S_acc on partial-accept from kept A/Ktil/Dmat (no per-token buffer). + +--- +## GPU BRING-UP RESULTS (eva01 V100, Qwen3.5-4B-Q8_0, 24 GDN + 8 attn) — DECISIVE, lever #3 verdict + +Ran the wired chunked path on CUDA. Three findings settle whether chunked GDN is a speedup lever: + +1. **Qwen3.5 GDN uses a SCALAR gate, not the vector (KDA) gate.** Gate diagnostic at the dispatch + site: `g->ne[0]=1, S_v=128`. The builder + tests/test-gdn-chunked.cpp were written for the + per-channel VECTOR gate (`g->ne[0]==S_v`), so the original `g->ne[0]==S_v` trigger NEVER fired on + Qwen3.5 — every "chunked vs seq" comparison before this was a silent no-op (identical numbers). + Fix: generalized the builder to accept the scalar gate (A=[1,N,H] broadcasts across S_v; AendB + width from A->ne[0]; full-size tensor first in every A-multiply so [1,N,H] repeats into [S_v,N,H]). + Now fires: `[GDN-CHUNKED] active: N=16 S_v=128 H=32 Hk=16 gate=scalar`, assert-free on CUDA. + +2. **Chunked is SLOWER than the fused CUDA kernel** (llama-bench pp, t/s, r=8): + | N (pp) | fused (off) | chunked (on) | + |--------|-------------:|-------------:| + | 8 | 514 | 413 | + | 16 | 830 | 672 | + | 32 | 1373 | 1129 | + | 64 | 1507 | 1343 | + CUDA already ships a native `GGML_OP_GATED_DELTA_NET` kernel (ggml-cuda/gated_delta_net.cu) that + the default multi-token path uses; the decomposed ggml graph (cumsum/exp/diag/tri/solve_tri + + several mul_mat) cannot beat it. **Worse, GDN is not the verify bottleneck at all**: baseline pp + throughput is identical whether the GDN op is the fused kernel or not — the 4B transformer's + attention+MLP matmuls dominate at N<=64, the GDN scan over <=16 tokens is negligible. + +3. **Single-chunk is numerically valid only for SMALL N.** The builder treats the whole block as ONE + chunk, so `Ainv = exp(-cumsum(g))` overflows for long sequences. A chat-template-inflated prefill + (~30-40 tok) already produces `?????` garbage under LLAMA_GDN_CHUNKED=1. It is correct only for a + true small verify block (<=16, as the CPU test covers); using it for prefill needs real multi-chunk + tiling (chunk into 16-64 blocks) which is NOT implemented. + +### VERDICT (lever #3): chunked GDN is a PORTABILITY artifact, NOT a speed lever. +- On CUDA it loses to the fused kernel and GDN isn't the bottleneck -> zero speedup toward SGLang's ~3x. +- Its only real use is backends WITHOUT a fused GDN kernel (WebGPU/Metal/Vulkan) AND only for small + blocks. For the speedup goal, the real lever is #2 (CUDA graphs for cheap large-block verify), + not GDN chunking. +- Code kept opt-in behind LLAMA_GDN_CHUNKED (default OFF, gdn_trace==null only) so normal serving is + untouched. Scalar-gate model-path correctness at small N is UNVERIFIED (llama-cli is confounded by + template prefill length + single-chunk overflow); do not claim it correct without a unit harness. + +--- +## PORTABLE + CORRECT (multi-chunk tiling) — what it took to make chunked GDN actually work on CUDA + +Goal: a pure-ggml chunked GDN that is CORRECT on every backend (verify path for fused-kernel-less +backends: WebGPU/Metal/Vulkan). Two deliverables: a unit test proving correctness, and multi-chunk +tiling for sequences longer than one block. Both done; the road there found three real issues. + +1. **Scalar gate proven correct (CPU + CUDA).** tests/test-gdn-chunked.cpp now covers the scalar + (Gated DeltaNet, g->ne[0]==1) gate as well as the vector (KDA) gate, GQA, and runs on a SELECTABLE + backend (GDN_BACKEND=CUDA). Match vs ggml_gated_delta_net is ~1e-8 (fp32) on both CPU and CUDA. + +2. **Multi-chunk tiling.** build_delta_net_chunked tiles the N tokens into blocks of C, threads the + recurrent state forward, concats the per-block attn. For a verify block (N<=C) the loop runs once + = the original single-chunk path. Capped at N<=128 in the dispatch gate (LLAMA_GDN_CHUNKED_MAXN): + the loop unrolls ceil(N/C) subgraphs per layer, so a long prefill (the 512-token ubatch reserve) + would explode the static graph (GGML_ASSERT(obj_new) — ctx pool exhausted) -> fall through to the + fused op there. On CUDA the fused kernel is faster for long prefill anyway (see verdict above). + +3. **The test's random inputs hid TWO bugs that only the real model exposed:** + - **Unnormalized keys diverge.** Random k with ||k||^2 ~ S_v*sc^2 >> 2 violates the delta-rule + stability bound beta*||k||^2 < 2, so the TRUE recurrence diverges and ref vs chunked blow up in + the unstable directions for long N (looked like a tiling bug: error grew ~30%/token). Fix: + l2-normalize q,k in the test (delta-net normalizes them) -> stable, machine-eps match. + - **Deflation precision sets a small max chunk size.** A=exp(+/-cumsum(g)) has wide dynamic range; + Ktil=k*exp(-cumsum), Kbar=k*exp(+cumsum) individually span many orders of magnitude even though + their product is bounded, so the KK matmul loses fp32 precision when a chunk is too long. The + test used mild gates (g~-0.2) and passed at C=16/32; the REAL model has strong-decay heads and + **garbles at C>=16, is clean at C<=8** (verified: greedy output identical to the fused baseline + at C=4 and C=8; "??????" at C=16/24). **Deployed default C=8.** The test now uses C=8 to match. + +### Net: chunked GDN is CORRECT on CUDA (greedy output bit-identical to the fused kernel across +short/medium/long prompts) and portable. Confirmed NOT a speedup on CUDA (the fused kernel is faster +and GDN isn't the bottleneck); its role is the verify path on backends without a fused GDN kernel. +Validation harness: tests/test-gdn-chunked.cpp (GDN_BACKEND=CPU|CUDA), all PASS on both. diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d9f8aaec52fd..981851a68d11 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -769,6 +769,16 @@ struct ggml_backend_sched_split { int n_inputs; // graph view of this split struct ggml_cgraph graph; + + // stable-uid bookkeeping (see ggml_backend_sched_split_graph): when a re-split produces a + // byte-identical split (same backend, node count, and head/tail node pointers) we reuse the + // previous uid instead of minting a fresh one. A stable uid lets graph-capturing backends + // (CUDA graphs) hit their uid fast-path and skip the per-node property walk every round, which + // is the dominant CPU overhead for a stable speculative-verify graph. This is purely a hint: + // a matching uid only lets the backend skip a walk that would have found no change anyway, and + // a non-matching uid always falls back to the full walk, so correctness is unaffected. + uint64_t prev_uid; // uid assigned on the previous split_graph for this slot (0 = none) + uint64_t prev_sig; // topology signature of the previous split for this slot }; struct ggml_backend_sched { @@ -1304,10 +1314,14 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra split->i_end = i; i_split++; if (i_split >= sched->splits_capacity) { + const int prev_capacity = sched->splits_capacity; sched->splits_capacity *= 2; sched->splits = (ggml_backend_sched_split *) realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); + // zero the newly grown slots so the stable-uid prev_uid/prev_sig start clean + memset(&sched->splits[prev_capacity], 0, + (sched->splits_capacity - prev_capacity) * sizeof(struct ggml_backend_sched_split)); } split = &sched->splits[i_split]; split->backend_id = node_backend_id; @@ -1481,8 +1495,50 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } // set ids for all splits + // + // stable-uid optimization: if this split slot is byte-identical to the previous split_graph + // (same backend, node count, and a strided sample of node pointers - sufficient because nodes + // are placement-allocated at stable offsets in the reused compute buffer, so an unchanged + // topology re-materializes the exact same tensor pointers), reuse the previous uid. This lets + // graph-capturing backends (CUDA graphs) take their uid fast-path and skip the O(n_nodes) + // property walk + warmup churn every round - the dominant per-round CPU cost for a stable + // speculative-verify graph that re-splits because the higher-level graph-reuse check missed. + // Opt out with GGML_SCHED_STABLE_UID=0. Hint only: a reused uid merely skips a walk that would + // have found no change; any mismatch falls back to the full walk, so correctness is unaffected. + static const bool stable_uid = [] { + const char * e = getenv("GGML_SCHED_STABLE_UID"); + return e == nullptr || atoi(e) != 0; // on by default + }(); for (int i = 0; i < sched->n_splits; ++i) { - sched->splits[i].graph.uid = ggml_graph_next_uid(); + struct ggml_backend_sched_split * split = &sched->splits[i]; + + uint64_t sig = 0; + if (stable_uid && split->graph.n_nodes > 0) { + // cheap topology signature: backend + node count + a strided sample of node pointers + // (head, tail, and up to ~16 interior nodes). Nodes are placement-allocated at stable + // offsets in the reused compute buffer, so an unchanged topology re-materializes the + // exact same pointers; ANY topology change shifts the count and/or these offsets. The + // interior sampling makes a same-count/same-endpoints-but-different-middle collision + // (which would let the backend reuse a stale captured graph) effectively impossible. + const int n = split->graph.n_nodes; + const int step = n > 16 ? n / 16 : 1; + sig = (uint64_t) (uint32_t) split->backend_id; + sig = sig * 1099511628211ull + (uint64_t) n; + for (int k = 0; k < n; k += step) { + sig = sig * 1099511628211ull + (uint64_t) (uintptr_t) split->graph.nodes[k]; + } + sig = sig * 1099511628211ull + (uint64_t) (uintptr_t) split->graph.nodes[n - 1]; + } + + if (stable_uid && sig != 0 && sig == split->prev_sig && split->prev_uid != 0) { + // identical to last time - keep the previous uid so the backend graph fast-path fires + split->graph.uid = split->prev_uid; + } else { + split->graph.uid = ggml_graph_next_uid(); + } + + split->prev_uid = split->graph.uid; + split->prev_sig = sig; } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 22727fd91dbf..d1ca702a808c 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -278,6 +278,16 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // [TAG_GDN_CHUNKED] DFlash verify fast path: chunk-parallel GDN (see DESIGN.md + gated_delta_net_chunked.cu). + // When implemented & validated, route the single-sequence verify block here to cut sequential depth N -> ceil(N/C): + // if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX + // && S_v <= GDN_CHUNK_DMAX && trace_d == nullptr) { + // if (kda) launch_gated_delta_net_chunked(q_d,k_d,v_d,g_d,b_d,s_d,dst_d,trace_d, S_v,H,n_tokens, sq1,sq2,sv1,sv2,sb1,sb2, scale,stream); + // else launch_gated_delta_net_chunked(q_d,k_d,v_d,g_d,b_d,s_d,dst_d,trace_d, S_v,H,n_tokens, sq1,sq2,sv1,sv2,sb1,sb2, scale,stream); + // return; + // } + // (left disabled until the skeleton is made compilable + bitwise-validated against this sequential kernel.) + if (kda) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh index 7375e81c0c36..f71dd68d5908 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cuh +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -2,3 +2,17 @@ #include "ggml.h" void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +// Chunk-parallel GDN forward for the DFlash speculative VERIFY path (single sequence). +// Cuts the verify's sequential depth from N to ceil(N/C). See DESIGN.md (repo root) +// and gated_delta_net_chunked.cu. SKELETON — not yet wired into the dispatcher. +template +void launch_gated_delta_net_chunked( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, float * trace_d, + int64_t S_v, int64_t H, int64_t n_tokens, + int64_t sq1, int64_t sq2, + int64_t sv1, int64_t sv2, + int64_t sb1, int64_t sb2, + float scale, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/gated_delta_net_chunked.cu b/ggml/src/ggml-cuda/gated_delta_net_chunked.cu new file mode 100644 index 000000000000..74fbc15fe266 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net_chunked.cu @@ -0,0 +1,325 @@ +#include "gated_delta_net.cuh" + +// NOTE: SKELETON ONLY - not yet compilable (pseudo-helpers/TODOs). Guarded out of the +// build so the branch compiles; see DESIGN.md. Remove this #if 0 once the kernel is +// brought up per the validation plan. +#if 0 + +// ============================================================================= +// Chunk-parallel Gated-DeltaNet forward for the DFlash speculative VERIFY path. +// +// SCOPE: single sequence (n_seqs == 1), a block of N tokens (N = draft_max + 1, +// up to ~32). Prefill and single-token decode keep using the existing +// SEQUENTIAL kernel in gated_delta_net.cu. This file is the CHUNK-PARALLEL +// variant that cuts the verify's sequential depth from N to ceil(N/C). +// +// STATUS: reviewable SKELETON. The structure, tiling, and math->thread mapping +// are concrete; the cooperative matmul/scan/solve bodies are sketched +// with pseudo-helpers (cg_* / smem tiles) and are NOT yet a compilable, +// numerically-verified kernel. See DESIGN.md (repo root), sections 2-6, +// for the derivation and the bring-up/validation plan. +// +// MATH (mathematical S[i][j] convention; i = key dim, j = value dim, D = head dim): +// per token t, with decay a_t[i] = exp(g_t[i]) (KDA) or exp(g_t) (scalar): +// decay: S[i][j] <- a_t[i] * S[i][j] +// kv: u_t[j] = sum_i S[i][j] * k_t[i] +// delta: d_t[j] = (v_t[j] - u_t[j]) * beta_t +// update: S[i][j] += k_t[i] * d_t[j] +// out: o_t[j] = scale * sum_i S[i][j] * q_t[i] scale = 1/sqrt(D) +// +// chunked (chunk size C, S_in = state entering the chunk; see DESIGN.md 2.2): +// A_r[i] = prod_{s<=r} a_s[i] (inclusive cumprod, log space) +// Kbar_r = A_r (.) k_r Qbar_r = A_r (.) q_r (carry-read keys/queries) +// k~_r = k_r / A_r q~_r = q_r * A_r (deflated/inflated) +// Kw_r = (A_{C-1}/A_r) (.) k_r (carry-forward write key) +// U_carry = Kbar @ S_in [C x D] +// T = strict_tril( beta (.) (K~ K~^T) ) [C x C] +// Dmat = (I + T)^{-1} @ ( beta (.) (V - U_carry) ) [C x D] (fwd-subst solve) +// O = scale*(Qbar @ S_in) + scale*tril(Q~ K~^T) @ Dmat [C x D] +// S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat [D x D] +// +// STATE STORAGE: same transposed layout as the sequential kernel and CPU ref: +// M[j*D + i] = S[i][j] (row j of M is column j of S, contiguous). +// +// IMPORTANT (see top-level CLAUDE.md fp16 pooling lesson): ALL accumulation here +// is f32. Gate cumprods are done in LOG space (sums of g) and ratios formed as +// exp(L_r - L_s) so we never divide by a tiny exp() denominator. +// ============================================================================= + +// ---- tunables (mirror DESIGN.md 3.1) --------------------------------------- +#define GDN_CHUNK_C 16 // chunk size; a <=16-token verify block is ONE chunk +#define GDN_CHUNK_MIN 2 // below this, sequential decode wins +#define GDN_CHUNK_MAX 32 // largest verify block we accept on this path +#define GDN_CHUNK_DMAX 64 // start with D<=64; D=128 needs the register-shard variant + +// Pseudo-helpers used in the sketch (to be replaced with real cooperative impls): +// smem_matmul_AB(out, A, B, M, K, N) : out[MxN] = A[MxK] @ B[KxN], f32 accum, block-cooperative +// smem_matmul_ABt(out, A, B, M, K, N): out[MxN] = A[MxK] @ B[NxK]^T +// tri_mask_strict(M, C) : zero the upper triangle incl. diagonal +// tri_mask_incl(M, C) : zero the strict upper triangle (keep diagonal) +// warp_fwd_subst(W_or_inplace, T, C) : solve (I + strict_tril(T)) X = RHS by forward substitution + +// ----------------------------------------------------------------------------- +// One CUDA block per (head, sequence). For verify n_seqs==1 -> grid (H,1,1). +// Template on S_v (=D) and KDA exactly like the sequential kernel so the same +// dispatch switch can pick the instantiation. +// ----------------------------------------------------------------------------- +template +__global__ void gated_delta_net_chunked_cuda( + const float * q, // [D, H, T] (key dim, head, token) strides sq* + const float * k, // [D, H, T] + const float * v, // [D, H, T] (value dim, head, token) + const float * g, // [1|D, H, T] gate (scalar or KDA vector over key dim) + const float * beta, // [1, H, T] + const float * curr_state, // [D, D, H] incoming state S_in (transposed M[j][i]) + float * dst, // [attn_scores | new_states] (same layout as sequential kernel) + float * trace, // optional per-token state trace (n_seqs==1), may be nullptr + int64_t H, + int64_t n_tokens, // N (the verify block length) + int64_t sq1, int64_t sq2, // q/k strides (floats): sq1 over head, sq2 over token + int64_t sv1, int64_t sv2, // v strides + int64_t sb1, int64_t sb2, // beta/g base strides + float scale) { + + const int h_idx = blockIdx.x; // head this block owns + const int lane = threadIdx.x; // 0..D-1 (value/key column) + const int warp = threadIdx.y; // 0..num_warps-1 + + // ---- shared-memory tiles (DESIGN.md 3, "Tiling") ------------------------ + // For D<=64, C=16 these fit in <=48KB. For D=128 use the register-shard + // variant for S (like the sequential kernel) instead of smem S. + __shared__ float s_S [D][D]; // incoming/outgoing state (M[j][i] = S[i][j]) + __shared__ float s_K [GDN_CHUNK_C][D]; // raw chunk tiles + __shared__ float s_Q [GDN_CHUNK_C][D]; + __shared__ float s_V [GDN_CHUNK_C][D]; + __shared__ float s_L [GDN_CHUNK_C][D]; // cumulative LOG gate L_r[i] = sum_{s<=r} g_s[i] (KDA) + // (scalar gate: column 0 used, broadcast over i) + __shared__ float s_beta [GDN_CHUNK_C]; + __shared__ float s_Ucar [GDN_CHUNK_C][D]; // U_carry + __shared__ float s_T [GDN_CHUNK_C][GDN_CHUNK_C]; // intra-chunk coupling / solve workspace + __shared__ float s_Dmat [GDN_CHUNK_C][D]; // resolved per-token deltas d_r + __shared__ float s_O [GDN_CHUNK_C][D]; // outputs + + const int C = (int) (n_tokens < GDN_CHUNK_C ? n_tokens : GDN_CHUNK_C); + + // base pointers for this (head) — n_seqs==1 so sequence offset is 0 + const float * q_h = q + h_idx * sq1; + const float * k_h = k + h_idx * sq1; + const float * v_h = v + h_idx * sv1; + const float * gb_base = (const float *) nullptr; // gate/beta offset computed per token below + const int64_t gb_h = h_idx * sb1; + + float * attn_data = dst + h_idx * D; // [.. + token*D*H], value rows + const int64_t attn_score_elems = (int64_t) D * H * n_tokens; // n_seqs==1 + float * state_out = dst + attn_score_elems + (int64_t) h_idx * D * D; + + // ========================================================================= + // Outer loop over chunks. Sequential DEPTH = ceil(N/C). For N<=C this runs once. + // S_in for chunk 0 is curr_state; for later chunks it's the previous S_out. + // ========================================================================= + // load S_in (transposed) into s_S + for (int j = warp; j < D; j += blockDim.y) { + s_S[j][lane] = curr_state[(int64_t) (h_idx * D + j) * D + lane]; + } + __syncthreads(); + + for (int chunk_base = 0; chunk_base < n_tokens; chunk_base += GDN_CHUNK_C) { + const int cc = (int) min((int64_t) GDN_CHUNK_C, n_tokens - chunk_base); + + // -- Phase 1: load chunk tiles + cumulative LOG gate (Hillis-Steele scan) -- + // Load k_r, q_r, v_r, beta_r, g_r for r=0..cc-1 into smem. Then prefix-sum + // g over r (log space) -> s_L[r][i] = sum_{s<=r} g_s[i]. scalar gate: + // s_L[r][0] = sum_{s<=r} g_s, broadcast at use sites. + for (int r = warp; r < cc; r += blockDim.y) { + const int t = chunk_base + r; + s_K[r][lane] = k_h[t * sq2 + lane]; + s_Q[r][lane] = q_h[t * sq2 + lane]; + s_V[r][lane] = v_h[t * sv2 + lane]; + const int64_t gb = gb_h + (int64_t) t * sb2; + if (lane == 0) s_beta[r] = beta[gb]; + // KDA gate is a length-D vector over key dim; scalar gate is length 1 + s_L[r][lane] = KDA ? g[gb * D + lane] : (lane == 0 ? g[gb] : 0.0f); + } + __syncthreads(); + // inclusive prefix sum of s_L along r (one warp marches r; cheap, C<=32) + // gdn_prefix_sum_logspace(s_L, cc, D, KDA); // <- TODO real scan + __syncthreads(); + + // Convenience: A_r[i] = exp(s_L[r][i]) + // A_last[i] = exp(s_L[cc-1][i]) + // ratio(r,i) = exp(s_L[cc-1][i] - s_L[r][i]) (= A_last/A_r, carry-forward) + // Build the derived keys/queries on the fly inside the matmuls below to + // avoid extra smem; shown here named for clarity: + // Kbar_r[i] = exp(s_L[r][i]) * s_K[r][i] + // Qbar_r[i] = exp(s_L[r][i]) * s_Q[r][i] + // k~_r[i] = exp(-s_L[r][i]) * s_K[r][i] + // q~_r[i] = exp(+s_L[r][i]) * s_Q[r][i] (== Qbar; same for query) + // Kw_r[i] = ratio(r,i) * s_K[r][i] + + // -- Phase 2: U_carry = Kbar @ S_in ; O_carry = scale*(Qbar @ S_in) -- + // s_Ucar[r][j] = sum_i Kbar_r[i] * S[i][j] + // = sum_i (exp(L[r][i]) * s_K[r][i]) * s_S[j][i] (M transposed!) + for (int r = warp; r < cc; r += blockDim.y) { + float acc = 0.0f; + for (int i = 0; i < D; ++i) { + const float Kbar = expf(s_L[r][i]) * s_K[r][i]; + acc += Kbar * s_S[lane][i]; // s_S[j][i], here j=lane (value dim column) + } + s_Ucar[r][lane] = acc; + // O_carry accumulates into s_O below (Qbar @ S_in == same with q) + float oacc = 0.0f; + for (int i = 0; i < D; ++i) { + const float Qbar = expf(s_L[r][i]) * s_Q[r][i]; + oacc += Qbar * s_S[lane][i]; + } + s_O[r][lane] = scale * oacc; // O_carry; intra-chunk term added in Phase 5 + } + __syncthreads(); + + // -- Phase 3: T[r][s] = beta_r * (k~_r . k~_s), strict lower triangle -- + // k~_r[i] = exp(-L[r][i]) * s_K[r][i]. Build [C x C], mask s>=r to 0. + for (int r = warp; r < cc; r += blockDim.y) { + // each lane handles a subset of columns s + for (int s = lane; s < cc; s += blockDim.x) { + if (s < r) { + float dot = 0.0f; + for (int i = 0; i < D; ++i) { + const float kr = expf(-s_L[r][i]) * s_K[r][i]; + const float ks = expf(-s_L[s][i]) * s_K[s][i]; + dot += kr * ks; + } + s_T[r][s] = s_beta[r] * dot; + } else { + s_T[r][s] = 0.0f; // strict lower only + } + } + } + __syncthreads(); + + // -- Phase 4: solve Dmat = (I + strict_tril(T))^{-1} @ RHS, RHS = beta(.)(V - U_carry) -- + // Unit-lower-triangular -> forward substitution, C sequential micro-steps + // on the C x C system (one warp). RHS lives in s_Dmat initially. + for (int r = warp; r < cc; r += blockDim.y) { + s_Dmat[r][lane] = s_beta[r] * (s_V[r][lane] - s_Ucar[r][lane]); // RHS row r + } + __syncthreads(); + // forward substitution: for r = 0..cc-1: Dmat[r] -= sum_{sr)^T @ Dmat[0..r] per r + // and write each into trace[(t*H + h)*D*D ...]. Sketched as a prefix pass; + // omitted in this skeleton (start by routing trace!=nullptr to sequential). + (void) trace; + } + + // -- final state writeback (transposed layout, same as sequential kernel) -- + for (int j = warp; j < D; j += blockDim.y) { + state_out[(int64_t) j * D + lane] = s_S[j][lane]; + } +} + +// ----------------------------------------------------------------------------- +// Host launcher. Mirrors launch_gated_delta_net in gated_delta_net.cu but +// uses a (H,1,1) grid with one block per head. Called from the guarded fast path +// in ggml_cuda_op_gated_delta_net (see DESIGN.md 3.1): +// +// if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX +// && S_v <= GDN_CHUNK_DMAX && trace_d == nullptr) +// launch_gated_delta_net_chunked(...); +// else +// launch_gated_delta_net(...); // existing sequential kernel +// ----------------------------------------------------------------------------- +template +void launch_gated_delta_net_chunked( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, float * trace_d, + int64_t S_v, int64_t H, int64_t n_tokens, + int64_t sq1, int64_t sq2, + int64_t sv1, int64_t sv2, + int64_t sb1, int64_t sb2, + float scale, cudaStream_t stream) { + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int num_warps = 4; + dim3 grid_dims((unsigned) H, 1, 1); + dim3 block_dims((unsigned) (warp_size <= S_v ? warp_size : S_v), num_warps, 1); + + switch (S_v) { + case 16: + gated_delta_net_chunked_cuda<16, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + case 32: + gated_delta_net_chunked_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + case 64: + gated_delta_net_chunked_cuda<64, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + // case 128: needs the register-shard S variant (smem S is 64KB); see DESIGN.md 3. + default: + GGML_ABORT("gated_delta_net_chunked: unsupported S_v (use sequential path)"); + break; + } +} + +// explicit instantiations so the dispatcher in gated_delta_net.cu can link them +template void launch_gated_delta_net_chunked(const float*,const float*,const float*,const float*,const float*,const float*,float*,float*,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,float,cudaStream_t); +template void launch_gated_delta_net_chunked(const float*,const float*,const float*,const float*,const float*,const float*,float*,float*,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,float,cudaStream_t); + +#endif // skeleton guard diff --git a/ggml/src/ggml-cuda/gdn_chunked_oracle.py b/ggml/src/ggml-cuda/gdn_chunked_oracle.py new file mode 100644 index 000000000000..e80481ab7645 --- /dev/null +++ b/ggml/src/ggml-cuda/gdn_chunked_oracle.py @@ -0,0 +1,67 @@ +import numpy as np +np.random.seed(0) + +# Gated DeltaNet reference (matches ggml CPU ops.cpp gated_delta_net_one_chunk). +# State S is D x D, S[i,j] (i=key dim, j=value dim). Per token (vector gate / KDA): +# S <- diag(a_t) @ S (a_t[i] = exp(g_t[i])) [decay rows by a] +# u_t[j] = sum_i S[i,j] k_t[i] [readout on DECAYED state] +# delta_t[j] = beta_t (v_t[j] - u_t[j]) +# S[i,j] += k_t[i] delta_t[j] [rank-1 update] +# o_t[j] = scale * sum_i S[i,j] q_t[i] [POST-update readout] +def sequential(q, k, v, g, beta, S0, scale): + N, D = q.shape + S = S0.astype(np.float64).copy() + O = np.zeros((N, D)) + for t in range(N): + a = np.exp(g[t]) # (D,) decay per key-dim i + S = (a[:, None]) * S # decay rows + u = S.T @ k[t] # (D,) over j + delta = beta[t] * (v[t] - u) # (D,) + S = S + np.outer(k[t], delta) # rank-1 + O[t] = scale * (S.T @ q[t]) # post-update + return O, S + +# Chunked (single chunk = whole block), per agent B's design. Inclusive cumulative +# decay A_r[i] = prod_{s<=r} a_s[i]. Deflate by A to factor the decay out. +def chunked(q, k, v, g, beta, S0, scale): + N, D = q.shape + S0 = S0.astype(np.float64) + a = np.exp(g.astype(np.float64)) # (N,D) + A = np.cumprod(a, axis=0) # (N,D) inclusive cumulative decay + Kbar = A * k # (N,D) "later token" (carries A_r) + Qbar = A * q + Ktil = k / A # (N,D) "earlier token" (carries 1/A_s) + # carry from incoming state + U_carry = Kbar @ S0 # (N,D) over j + O_carry = scale * (Qbar @ S0) + # pairwise decay s->r is A_r/A_s => Kbar_r . Ktil_s (bounded for s forward substitution) + Dmat = np.linalg.solve(np.eye(N) + T, rhs) + # intra-chunk output: lower-tri (incl diagonal) of (Qbar Ktil^T) + QK = np.tril(Qbar @ Ktil.T, k=0) # (N,N) + O = O_carry + scale * (QK @ Dmat) + # carry-out state: S_out = diag(A_{N-1}) S0 + Kw^T @ Dmat, Kw_r = (A_{N-1}/A_r) k_r + Aend = A[-1] # (D,) + Kw = (Aend[None, :] / A) * k # (N,D) + S_out = Aend[:, None] * S0 + Kw.T @ Dmat + return O, S_out + +for trial in range(5): + N = np.random.randint(2, 17) # block up to 16 + D = 64 + q = np.random.randn(N, D)*0.5 + k = np.random.randn(N, D)*0.5 + v = np.random.randn(N, D)*0.5 + g = -np.abs(np.random.randn(N, D))*0.1 # gates: log-decay <=0 (a<=1) + beta = np.random.rand(N) + S0 = np.random.randn(D, D)*0.3 + scale = 1.0/np.sqrt(D) + Os, Ss = sequential(q,k,v,g,beta,S0,scale) + Oc, Sc = chunked(q,k,v,g,beta,S0,scale) + eO = np.abs(Os-Oc).max() + eS = np.abs(Ss-Sc).max() + print(f"trial {trial}: N={N} D={D} | max|dO|={eO:.2e} max|dS|={eS:.2e} | {'OK' if max(eO,eS)<1e-9 else 'MISMATCH'}") diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index cb78b4067e61..dcbe0149b743 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -397,6 +397,20 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + // chunk-parallel GDN for a multi-token block (DFlash verify): portable (pure ggml ops) verify of a + // small block. Opt-in via LLAMA_GDN_CHUNKED=1. Only when no per-token trace is requested (the + // rewind needs per-token state, see GDN_CHUNKED_BRINGUP.md), n_seqs==1, vector/scalar gate, N>1. + // Capped to small N: the tiling unrolls ceil(N/16) subgraphs per layer, so a long prefill (the + // 512-token ubatch reserve) would explode the static graph -> fall through to the fused op there. + static const bool gdn_chunked = getenv("LLAMA_GDN_CHUNKED") && + std::string(getenv("LLAMA_GDN_CHUNKED")) != "0"; + static const int64_t gdn_chunked_maxn = + getenv("LLAMA_GDN_CHUNKED_MAXN") ? atoll(getenv("LLAMA_GDN_CHUNKED_MAXN")) : 128; + if (gdn_chunked && gdn_trace == nullptr && n_seqs == 1 && n_tokens > 1 && n_tokens <= gdn_chunked_maxn + && (g->ne[0] == S_v || g->ne[0] == 1)) { + return build_delta_net_chunked(q, k, v, g, b, s, il); + } + ggml_tensor * result; if (gdn_trace != nullptr) { // per-token state trace requested (DFlash speculative rewind on recurrent targets) @@ -427,6 +441,100 @@ std::pair llm_build_delta_net_base::build_delta_ne return {output, new_state}; } +// One chunk of the chunk-parallel GDN. Returns attn O [token,S_v,H] and carry-out state S_out +// [i,j,H] (both 3D, pre-reshape) so the tiling wrapper can concat tokens and thread state. Math +// validated bitwise vs ggml_gated_delta_net in tests/test-gdn-chunked.cpp (vector+scalar gate, GQA). +std::pair llm_build_delta_net_base::build_delta_net_one_chunk( + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * b, ggml_tensor * s) { + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t N = v->ne[2]; + const int64_t Hk = q->ne[1]; + const float scale = 1.0f/sqrtf((float)S_v); + + auto toDNH = [&](ggml_tensor * x){ return ggml_cont(ctx0, ggml_permute(ctx0, x, 0,2,1,3)); }; // [S,Hx,N]->[S,N,Hx] + ggml_tensor * qp = toDNH(q), * kp = toDNH(k), * vp = toDNH(v), * gp = toDNH(g); + if (Hk != H) { // GQA: broadcast q/k heads Hk->H (interleaved) + ggml_tensor * tgt = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, S_v, N, H); + qp = ggml_repeat(ctx0, qp, tgt); + kp = ggml_repeat(ctx0, kp, tgt); + } + ggml_tensor * gN = ggml_cont(ctx0, ggml_permute(ctx0, gp, 1,0,2,3)); // [N,S_v,H] + ggml_tensor * Lp = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cumsum(ctx0, gN), 1,0,2,3)); // [S_v,N,H] + ggml_tensor * A = ggml_exp(ctx0, Lp); + ggml_tensor * Ainv = ggml_exp(ctx0, ggml_neg(ctx0, Lp)); + // full-size tensor first so a scalar gate's A=[1,N,H] broadcasts into kp/qp=[S_v,N,H] + ggml_tensor * Kbar = ggml_mul(ctx0, kp, A); + ggml_tensor * Qbar = ggml_mul(ctx0, qp, A); + ggml_tensor * Ktil = ggml_mul(ctx0, kp, Ainv); + ggml_tensor * betaNH = ggml_cont(ctx0, ggml_permute(ctx0, b, 0,2,1,3)); // [1,N,H] + ggml_tensor * Ucar = ggml_mul_mat(ctx0, s, Kbar); // [j,r,H] + ggml_tensor * Ocar = ggml_scale(ctx0, ggml_mul_mat(ctx0, s, Qbar), scale); + ggml_tensor * rhs = ggml_mul(ctx0, ggml_sub(ctx0, vp, Ucar), betaNH); // [S_v(j),N(r),H] + ggml_tensor * KK = ggml_mul_mat(ctx0, Ktil, Kbar); // [s,r,H] + ggml_tensor * Tlo = ggml_tri(ctx0, ggml_mul(ctx0, KK, betaNH), GGML_TRI_TYPE_LOWER); + // identity [N,N] (solve_tri needs an explicit unit diagonal): ones via exp(0*beta), then diag + ggml_tensor * b0 = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_view_2d(ctx0, betaNH, 1, N, betaNH->nb[1], 0))); + ggml_tensor * ones = ggml_exp(ctx0, ggml_scale(ctx0, b0, 0.0f)); // [N,1] all-ones + ggml_tensor * Imat = ggml_diag(ctx0, ones); // [N,N] + ggml_tensor * ITm = ggml_add(ctx0, Tlo, Imat); // I+T, bcast over H + ggml_tensor * Dmat = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_solve_tri(ctx0, ITm, rhs, true, true, false))); // [N(token),S_v(j),H] + ggml_tensor * QKlo = ggml_tri(ctx0, ggml_mul_mat(ctx0, Ktil, Qbar), GGML_TRI_TYPE_LOWER_DIAG); + ggml_tensor * intra= ggml_scale(ctx0, ggml_mul_mat(ctx0, QKlo, Dmat), scale); // [t,j,H] + ggml_tensor * O = ggml_add(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, Ocar)), intra); // [t,j,H] + ggml_tensor * AendB= ggml_cont(ctx0, ggml_view_3d(ctx0, A, A->ne[0], 1, H, A->nb[1], A->nb[2], (N-1)*A->nb[1])); + ggml_tensor * S0dec= ggml_mul(ctx0, s, AendB); + ggml_tensor * Kw = ggml_mul(ctx0, kp, ggml_mul(ctx0, Ainv, AendB)); + ggml_tensor * upd = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, Kw)), Dmat); + ggml_tensor * S_out= ggml_add(ctx0, S0dec, upd); // [i,j,H] + return {O, S_out}; +} + +std::pair llm_build_delta_net_base::build_delta_net_chunked( + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * b, ggml_tensor * s, int il) { + // Chunk-parallel GDN. Pure ggml ops -> portable (CUDA/Metal/Vulkan/WebGPU) for backends without a + // fused GDN kernel. The N tokens are tiled into blocks of C, carrying the recurrent state forward, + // so Ainv=exp(-cumsum(g)) stays bounded per block: single-chunk over a long prefill overflows fp32. + // For a verify block (N<=C) the loop runs once -> identical to the original single-chunk path. + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t N = v->ne[2]; + const int64_t Hk = q->ne[1]; + // Gate may be per-channel (KDA vector, g->ne[0]==S_v) or per-head scalar (Gated DeltaNet, ==1). + GGML_ASSERT((g->ne[0] == S_v || g->ne[0] == 1) && "chunked GDN gate must be vector(S_v) or scalar(1)"); + GGML_ASSERT(v->ne[3] == 1 && "chunked GDN path is n_seqs==1 only"); + + // Block size. The deflation A=exp(+/-cumsum(g)) has a wide dynamic range; strong-decay heads + // overflow fp32 precision when a chunk is too long (empirically garbles at C>=16 on Qwen3.5, + // clean at C<=8). 8 is the safe default; override with LLAMA_GDN_CHUNK_SIZE. + int64_t C = 8; + if (const char * e = getenv("LLAMA_GDN_CHUNK_SIZE")) { int64_t c = atoll(e); if (c >= 1) C = c; } + + if (getenv("LLAMA_GDN_CHUNKED_VERBOSE")) { + static bool once = false; + if (!once) { once = true; fprintf(stderr, "[GDN-CHUNKED] active: N=%lld C=%lld chunks=%lld S_v=%lld H=%lld Hk=%lld gate=%s\n", + (long long)N, (long long)C, (long long)((N+C-1)/C), (long long)S_v, (long long)H, (long long)Hk, g->ne[0]==1?"scalar":"vector"); } + } + + ggml_tensor * S = s; // carried recurrent state [S_v,S_v,H,1] + ggml_tensor * O_full = nullptr; // attn output, concatenated over tokens [token,S_v,H] + for (int64_t start = 0; start < N; start += C) { + const int64_t cn = std::min(C, N - start); + auto slc = [&](ggml_tensor * x){ + return ggml_view_4d(ctx0, x, x->ne[0], x->ne[1], cn, 1, x->nb[1], x->nb[2], x->nb[3], start*x->nb[2]); + }; + auto oc = build_delta_net_one_chunk(slc(q), slc(k), slc(v), slc(g), slc(b), S); + O_full = O_full ? ggml_concat(ctx0, O_full, oc.first, 0) : oc.first; // concat tokens on ne0 + S = ggml_reshape_4d(ctx0, oc.second, S_v, S_v, H, 1); + } + ggml_tensor * output = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, O_full, 2,0,1,3)), S_v, H, N, 1); + ggml_tensor * new_state = S; + cb(output, LLAMA_TENSOR_NAME_FGDN_CH, il); + return {output, new_state}; +} + std::pair llm_build_delta_net_base::build_delta_net( ggml_tensor * q, ggml_tensor * k, diff --git a/src/models/models.h b/src/models/models.h index e7efcc4823af..097fc6063940 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -55,6 +55,27 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); + // chunk-parallel GDN for a multi-token block (DFlash verify): chunked delta-rule built from + // ggml ops (cumsum/tri/solve_tri/mul_mat) - cheap verify of an N-token block + portable. + // Returns {output [S_v,H_v,N,1], new_state [S_v,S_v,H_v,1]}. n_seqs==1, KDA gate only. + std::pair build_delta_net_chunked( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // one block of the tiled chunk-parallel GDN; returns attn [token,S_v,H] and state [i,j,H] + std::pair build_delta_net_one_chunk( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s); + // choose one of two implementations above based on the number of tokens std::pair build_delta_net( ggml_tensor * q, diff --git a/tests/test-gdn-chunked.cpp b/tests/test-gdn-chunked.cpp new file mode 100644 index 000000000000..e64080b0b6eb --- /dev/null +++ b/tests/test-gdn-chunked.cpp @@ -0,0 +1,235 @@ +// Standalone CPU validation: chunked Gated-DeltaNet (ggml-op decomposition) vs the reference +// sequential ggml_gated_delta_net. Build: see the compile cmd at the bottom of this file. +// Goal: bitwise-ish match (max|diff| < 1e-3) so the chunked path can replace the sequential +// GDN kernel on the DFlash verify (block of N tokens), affording larger blocks -> ~3x. +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" +#include +#include +#include +#include +#include +#include + +// Leaf inputs are created in a no_alloc context (so the graph can run on CUDA); their host data is +// staged here and uploaded with ggml_backend_tensor_set after the backend allocates the graph. +struct Pending { ggml_tensor * t; std::vector data; }; +static std::vector g_pending; + +static ggml_tensor * rnd(ggml_context * c, int64_t a,int64_t b,int64_t d,int64_t e, std::mt19937 & g, float sc, float bias=0.f){ + ggml_tensor * t = ggml_new_tensor_4d(c, GGML_TYPE_F32, a,b,d,e); + std::vector h(ggml_nelements(t)); + std::normal_distribution N(0,1); + for (auto & x : h) x = N(g)*sc + bias; + g_pending.push_back({t, std::move(h)}); + return t; +} + +// L2-normalize each ne0 vector to unit norm (delta-net normalizes q/k). Without this, random keys +// have ||k||^2 ~ S_v*sc^2 >> 2, so beta*||k||^2 violates the delta-rule stability bound and the TRUE +// recurrence diverges -> ref and chunked blow up in the unstable directions for long sequences. +static void l2norm_rows(ggml_tensor * t){ + for (auto & pd : g_pending) if (pd.t == t) { + const int64_t S = t->ne[0]; const int64_t rows = (int64_t)pd.data.size()/S; + for (int64_t r=0;r0 ? (float)(1.0/sqrt(n)) : 0.f; for(int64_t i=0;ine[0], H = v->ne[1], N = v->ne[2]; + const int64_t Hk = q->ne[1]; // GQA: q/k have Hk heads, broadcast (interleaved iv%Hk) to H v-heads + const float scale = 1.0f/sqrtf((float)S_v); + // reorg [S_v,Hx,N] -> [S_v,N,Hx] + auto toDNH = [&](ggml_tensor * x){ return ggml_cont(c, ggml_permute(c, x, 0,2,1,3)); }; + ggml_tensor * qp = toDNH(q), * kp = toDNH(k), * vp = toDNH(v), * gp = toDNH(g); // q/k:[S_v,N,Hk] v/g:[S_v,N,H] + if (Hk != H) { // expand q/k heads Hk->H, interleaved (ggml_repeat tiles ne2: h -> h%Hk) + ggml_tensor * tgt = ggml_new_tensor_3d(c, GGML_TYPE_F32, S_v, N, H); + qp = ggml_repeat(c, qp, tgt); + kp = ggml_repeat(c, kp, tgt); + } + // A_r[i] = prod_{s<=r} exp(g): cumsum over tokens. cumsum is ne0-only -> put tokens on ne0. + ggml_tensor * gN = ggml_cont(c, ggml_permute(c, gp, 1,0,2,3)); // [N,S_v,H] + ggml_tensor * L = ggml_cumsum(c, gN); // [N,S_v,H] cumulative log-decay + ggml_tensor * Lp = ggml_cont(c, ggml_permute(c, L, 1,0,2,3)); // [S_v,N,H] + ggml_tensor * A = ggml_exp(c, Lp); // [S_v|1, N, H] + ggml_tensor * Ainv = ggml_exp(c, ggml_neg(c, Lp)); // 1/A + // full-size tensor first so a scalar gate's A=[1,N,H] broadcasts into kp/qp=[S_v,N,H] + ggml_tensor * Kbar = ggml_mul(c, kp, A); + ggml_tensor * Qbar = ggml_mul(c, qp, A); + ggml_tensor * Ktil = ggml_mul(c, kp, Ainv); + // beta -> [1,N,H] (broadcast over dim) + ggml_tensor * betaNH = ggml_cont(c, ggml_permute(c, beta, 0,2,1,3)); // [1,N,H] + // U_carry[j,r] = sum_i Kbar[i,r] S0[i,j] ; O_carry = scale Qbar . S0 + ggml_tensor * Ucar = ggml_mul_mat(c, S0, Kbar); // [j, r, H] (contract i=ne0) + ggml_tensor * Ocar = ggml_scale(c, ggml_mul_mat(c, S0, Qbar), scale); // [j,r,H] + // rhs[j,r] = beta_r (v[j,r] - Ucar[j,r]) (vp is [S_v(j),N(r),H]) + ggml_tensor * rhs = ggml_mul(c, ggml_sub(c, vp, Ucar), betaNH); // [S_v(j),N(r),H] + // KK[s,r] = sum_i Ktil[i,s] Kbar[i,r] = Kbar_r . Ktil_s. solve_tri wants A[ne0=s, ne1=r] + // = (I+T)[r,s], so KK is already in the right orientation (no transpose). beta_r over ne1=r. + ggml_tensor * KK = ggml_mul_mat(c, Ktil, Kbar); // [s,r,H] + ggml_tensor * Tfull= ggml_mul(c, KK, betaNH); // [s,r,H] * beta_r(ne1) + ggml_tensor * Tlo = ggml_tri(c, Tfull, GGML_TRI_TYPE_LOWER); // keep s works with no_alloc + // / CUDA): all-ones [N,1] via exp(0*beta) then ggml_diag -> [N,N], broadcast over H. + ggml_tensor * b0 = ggml_cont(c, ggml_transpose(c, ggml_view_2d(c, betaNH, 1, N, betaNH->nb[1], 0))); + ggml_tensor * ones = ggml_exp(c, ggml_scale(c, b0, 0.0f)); // [N,1] all-ones + ggml_tensor * Imat = ggml_diag(c, ones); // [N,N] + ggml_tensor * IT = ggml_add(c, Tlo, Imat); // [s,r,H] = (I+T) in solve orientation + // Dmat = (I+T)^-1 rhs. b = rhs [S_v(j), N(r), H] (ne1=N matches A->ne1). result [S_v(j),N(r),H]. + ggml_tensor * Dmat = ggml_solve_tri(c, IT, rhs, true, true, false); // [S_v(j), N(token), H] + Dmat = ggml_cont(c, ggml_transpose(c, Dmat)); // -> [N(token), S_v(j), H] + // O = O_carry + scale * tril(Qbar.Ktil^T incl-diag) @ Dmat. QK[s,r]=Qbar_r.Ktil_s. + ggml_tensor * QK = ggml_mul_mat(c, Ktil, Qbar); // [s,r,H] + ggml_tensor * QKlo = ggml_tri(c, QK, GGML_TRI_TYPE_LOWER_DIAG); // keep s<=r (incl diagonal) + // intra[r,j] = sum_s QKlo[s,r] Dmat[s,j]. mul_mat(QKlo[s,r,H], Dmat[token=s,j,H]) -> [r,j,H] + ggml_tensor * intra = ggml_mul_mat(c, QKlo, Dmat); // contract s=ne0 -> [r,j,H] + intra = ggml_scale(c, intra, scale); // [r,j,H] + // O_carry is [j,r,H]; intra is [r,j,H] -> transpose O_carry + ggml_tensor * OcarT = ggml_cont(c, ggml_transpose(c, Ocar)); // [r,j,H] + ggml_tensor * O = ggml_add(c, OcarT, intra); // [r(N),j(S_v),H] attn per token + *out_attn = O; + // S_out[i,j] = A_end[i] S0[i,j] + sum_r Kw[i,r] Dmat[r,j], Kw_r=(A_end/A_r) k_r + ggml_tensor * Aend = ggml_view_3d(c, A, A->ne[0], 1, H, A->nb[1], A->nb[2], (N-1)*A->nb[1]); // [S_v|1,1,H] + ggml_tensor * AendB = ggml_cont(c, Aend); + ggml_tensor * S0dec = ggml_mul(c, S0, AendB); // [i,j,H] * A_end[i](ne0,bcast over j) + ggml_tensor * Kw = ggml_mul(c, kp, ggml_mul(c, Ainv, AendB)); // (A_end/A_r) k_r [S_v(i),N(r),H] + // Kw^T @ Dmat: sum_r Kw[i,r] Dmat[r,j]. mul_mat contracts ne0 -> need Kw[r,i] and Dmat[r,j] + ggml_tensor * KwT = ggml_cont(c, ggml_transpose(c, Kw)); // [N(r),S_v(i),H] + ggml_tensor * upd = ggml_mul_mat(c, KwT, Dmat); // contract r -> [i,j,H] + *out_state = ggml_add(c, S0dec, upd); // [i,j,H] +} + +// Multi-chunk tiling: split the N tokens into blocks of C, run build_chunked per block carrying the +// recurrent state forward. Bounds Ainv=exp(-cumsum(g)) to <=C tokens -> numerically stable for long +// prefill (single-chunk overflows). ceil(N/C) chunks, unrolled at graph-build time (N is static). +static void build_chunked_tiled(ggml_context * c, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * S0, int64_t C, + ggml_tensor ** out_attn, ggml_tensor ** out_state) { + const int64_t N = v->ne[2]; + ggml_tensor * S = S0; + ggml_tensor * attn_full = nullptr; + for (int64_t start = 0; start < N; start += C) { + const int64_t cn = std::min(C, N - start); + auto slice = [&](ggml_tensor * x){ + return ggml_view_4d(c, x, x->ne[0], x->ne[1], cn, 1, x->nb[1], x->nb[2], x->nb[3], start*x->nb[2]); + }; + ggml_tensor *ac=nullptr,*sc=nullptr; + build_chunked(c, slice(q), slice(k), slice(v), slice(g), slice(beta), S, &ac, &sc); + attn_full = attn_full ? ggml_concat(c, attn_full, ac, 0) : ac; // O is [token, S_v, H]; concat tokens on ne0 + S = ggml_reshape_4d(c, sc, S0->ne[0], S0->ne[0], S0->ne[2], 1); // match model: thread 4D state + } + *out_attn = attn_full; + *out_state = S; +} + +static int run_case(int64_t S_v, int64_t H, int64_t N, std::mt19937 & rng, int64_t Hk=-1, bool scalar_gate=false, int64_t chunk=0){ + if (Hk < 0) Hk = H; + g_pending.clear(); + size_t mem = 64ull*1024*1024; // metadata only; tensor data lives in the backend buffer (no_alloc) + ggml_init_params ip{mem, nullptr, true}; + ggml_context * c = ggml_init(ip); + ggml_tensor * q = rnd(c,S_v,Hk,N,1,rng,0.5f); + ggml_tensor * k = rnd(c,S_v,Hk,N,1,rng,0.5f); + l2norm_rows(q); l2norm_rows(k); // delta-net normalizes q,k -> stable recurrence + ggml_tensor * v = rnd(c,S_v,H,N,1,rng,0.5f); + // gate: vector (KDA, [S_v,H,N]) or per-head scalar (Gated DeltaNet, [1,H,N]) + ggml_tensor * g = rnd(c, scalar_gate ? 1 : S_v, H, N, 1, rng, 0.1f, -0.2f); // log-decay <0 + g = ggml_neg(c, ggml_abs(c, g)); // ensure <=0 -> a<=1 + ggml_tensor * beta = rnd(c,1,H,N,1,rng,0.0f,0.5f); + ggml_tensor * S0 = rnd(c,S_v,S_v,H,1,rng,0.3f); + + ggml_tensor * ref = ggml_gated_delta_net(c, q,k,v,g,beta,S0); // [S_v*H, N+S_v] + ggml_tensor *ca=nullptr,*cs=nullptr; + if (chunk > 0) build_chunked_tiled(c,q,k,v,g,beta,S0,chunk,&ca,&cs); + else build_chunked(c,q,k,v,g,beta,S0,&ca,&cs); + // replicate the model's EXACT output op: permute concatenated O [t,j,H] -> [S_v,H,N,1] + cont + ca = ggml_reshape_4d(c, ggml_cont(c, ggml_permute(c, ca, 2,0,1,3)), S_v, H, N, 1); + cs = ggml_reshape_4d(c, cs, S_v, S_v, H, 1); + + ggml_cgraph * gf = ggml_new_graph_custom(c, 8192, false); + ggml_build_forward_expand(gf, ref); + ggml_build_forward_expand(gf, ca); + ggml_build_forward_expand(gf, cs); + ggml_backend_t be = make_backend(); + ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(be)); + ggml_gallocr_alloc_graph(galloc, gf); + for (auto & pd : g_pending) ggml_backend_tensor_set(pd.t, pd.data.data(), 0, pd.data.size()*sizeof(float)); + ggml_backend_graph_compute(be, gf); + + // read outputs back to host (works for CPU and CUDA) + std::vector hr(ggml_nelements(ref)), hca(ggml_nelements(ca)), hcs(ggml_nelements(cs)); + ggml_backend_tensor_get(ref, hr.data(), 0, hr.size()*sizeof(float)); + ggml_backend_tensor_get(ca, hca.data(),0, hca.size()*sizeof(float)); + ggml_backend_tensor_get(cs, hcs.data(),0, hcs.size()*sizeof(float)); + const int64_t ca_n0=ca->ne[0], ca_n1=ca->ne[1], cs_n0=cs->ne[0], cs_n1=cs->ne[1]; + // reference attn: cols 0..N-1 of [S_v*H, N+S_v]; per (head h, token t): ref[ h*S_v + j , t ] + auto refAttn = [&](int h,int t,int j){ return hr[(int64_t)t*(S_v*H) + h*S_v + j]; }; + auto refState= [&](int h,int i,int j){ return hr[(int64_t)S_v*H*N + (int64_t)h*S_v*S_v + (int64_t)j*S_v + i]; }; + // chunked attn O[r,j,H]; state [i,j,H] + // ca is now model layout [S_v,H,N,1]: element[j,h,t] = j + h*S_v + t*S_v*H + (void)ca_n0;(void)ca_n1; + auto caV=[&](int h,int t,int j){ return hca[ (int64_t)t*S_v*H + (int64_t)h*S_v + j ]; }; + auto csV=[&](int h,int i,int j){ return hcs[ (int64_t)h*cs_n0*cs_n1 + (int64_t)j*cs_n0 + i ]; }; + double mA=0, mS=0; + for(int h=0;h0) snprintf(ch,sizeof ch,"C=%lld",(long long)chunk); else snprintf(ch,sizeof ch,"single"); + printf("S_v=%lld H=%lld N=%lld Hk=%lld gate=%-6s %-7s: max|dAttn|=%.2e max|dState|=%.2e -> %s\n", + (long long)S_v,(long long)H,(long long)N,(long long)Hk, scalar_gate?"scalar":"vector", ch, mA, mS, ok?"PASS":"FAIL"); + ggml_gallocr_free(galloc); + ggml_backend_free(be); + ggml_free(c); + return ok; +} + +int main(){ + std::mt19937 rng(0); + int all = 1; + { ggml_backend_t b = make_backend(); printf("backend: %s\n", ggml_backend_name(b)); ggml_backend_free(b); } + printf("== vector (KDA) gate ==\n"); + for (int64_t S_v : {64, 128}) for (int64_t N : {1, 2, 5, 8, 12, 16}) all &= run_case(S_v, 4, N, rng); + // GQA: H_v=4, H_k=2 and H_k=1 (q/k broadcast interleaved) + for (int64_t N : {1, 5, 16}) { all &= run_case(64, 4, N, rng, 2); all &= run_case(128, 4, N, rng, 1); } + printf("== scalar (Gated DeltaNet) gate -- Qwen3.5 ==\n"); + for (int64_t S_v : {64, 128}) for (int64_t N : {1, 2, 5, 8, 12, 16}) all &= run_case(S_v, 4, N, rng, -1, true); + // GQA + scalar gate (Qwen3.5 is GQA: H_v=32, H_k=16 -> ratio 2) + for (int64_t N : {1, 5, 16}) { all &= run_case(64, 4, N, rng, 2, true); all &= run_case(128, 4, N, rng, 1, true); } + printf("== multi-chunk tiling (long sequences; single-chunk overflows) ==\n"); + // N far beyond a verify block; C=8 chunks carry state. C must stay small: the deflation + // A=exp(+/-cumsum(g)) has wide dynamic range, so strong-decay heads lose fp32 precision when a + // chunk is too long (the model garbles at C>=16 on Qwen3.5; the deployed default is C=8). + for (int64_t N : {32, 64, 128, 200}) { + all &= run_case(128, 4, N, rng, -1, false, 8); // vector + all &= run_case(128, 4, N, rng, -1, true, 8); // scalar (Qwen3.5) + } + all &= run_case(128, 4, 128, rng, 1, true, 8); // GQA + scalar + tiled + all &= run_case(64, 4, 96, rng, 2, false, 8); // GQA + vector + tiled + // sanity: tiling with C>=N must equal the single-chunk path + all &= run_case(128, 4, 12, rng, -1, true, 64); + printf("%s\n", all ? "ALL PASS" : "SOME FAIL"); + return all ? 0 : 1; +} From e654a71ab66ee4bbed3669edea455df5d7040191 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 4/4] server: fix DFlash speculative decoding and add GPU greedy verify The server DFlash path was wired but crashed on every request, because the server processes a prompt in several ubatches while speculative-simple does it in one decode: - index the target features by absolute position and accumulate across ubatches (a chunked prompt previously left the first draft reading stale features), and read the [n_total-n_new, n_total) slice in the drafter - reset dflash_n_past per request in begin() (it carried over between requests) - set the view buffers in dflash_promote_state so the trace/promote copy also runs on the CPU backend (was a CUDA-only path, asserted on a null buffer) Also add GPU greedy verify: for a pure-greedy request the target emits an on-device argmax of the verify block and the host skips the per-block logits download + CPU sampler. Enabled only after the first token is sampled from logits, reset per request; non-greedy requests fall back to the host sampler. Lossless (byte-identical to the host-verify path). ~2.0x -> 2.4x on reasoning. --- common/speculative.cpp | 6 +++ src/llama-context.cpp | 31 +++++++++++-- tools/server/server-context.cpp | 78 ++++++++++++++++++++++++++++++++- 3 files changed, 109 insertions(+), 6 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 378776179b34..1bc4afb7d2fb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -763,6 +763,12 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); + // New sequence (server: new request on the slot): the target features are re-extracted from + // position 0 and the DFlash device cross cache is rewritten from there, so reset the running + // count. Without this, dflash_n_past carries over from the previous request and the first + // draft computes n_new = n - dflash_n_past < 1 -> GGML_ASSERT(n_new >= 1) abort. + dflash_n_past = 0; + accumulated_ctx.clear(); } void draft( diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ebf6deda4057..afc97401411c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1326,9 +1326,12 @@ void llama_context::dflash_append_features(const float * feat, int32_t n_new, in const auto & hparams = model.hparams; const size_t n_feat = hparams.dflash_target_layer_ids.size() * hparams.n_embd; - dflash.feat_staging.assign(feat, feat + n_feat * n_new); + // `feat` is the position-indexed target-feature buffer (see extract_dflash_features); the new + // tokens to encode are the [n_total - n_new, n_total) slice, not the first n_new entries. + const int32_t feat_pos0 = n_total - n_new; + dflash.feat_staging.assign(feat + (size_t) feat_pos0 * n_feat, feat + (size_t) n_total * n_feat); dflash.feat_n = n_new; - dflash.feat_pos0 = n_total - n_new; + dflash.feat_pos0 = feat_pos0; dflash.feat_bucket = n_new <= 8 ? 8 : 256; // graph rebuilds when the bucket changes (prompt round) // bucketed mask/position sizing, same scheme as the legacy host-mediated path @@ -1536,16 +1539,23 @@ bool llama_context::dflash_promote_state(int32_t idx, llama_pos pos_last, llama_ } } + // views created in a no_alloc context don't carry a buffer; set it to the parent's so the + // backend copy can resolve buffer_is_host (the CPU backend asserts on a null buffer; the CUDA + // path happened to skip the check). Required for the trace/promote path to run on CPU. ggml_tensor * src_s = ggml_view_1d(cg.get(), dflash.trace_s[il], hparams.n_embd_s(), (size_t) idx * dflash.trace_s[il]->nb[1]); ggml_tensor * dst_s = ggml_view_1d(cg.get(), s_l, hparams.n_embd_s(), (size_t) cell * s_l->nb[1]); + src_s->buffer = dflash.trace_s[il]->buffer; + dst_s->buffer = s_l->buffer; ggml_backend_tensor_copy_async(be, be, src_s, dst_s); ggml_tensor * src_r = ggml_view_1d(cg.get(), dflash.trace_r[il], hparams.n_embd_r(), (size_t) idx * dflash.trace_r[il]->nb[1]); ggml_tensor * dst_r = ggml_view_1d(cg.get(), r_l, hparams.n_embd_r(), (size_t) cell * r_l->nb[1]); + src_r->buffer = dflash.trace_r[il]->buffer; + dst_r->buffer = r_l->buffer; ggml_backend_tensor_copy_async(be, be, src_r, dst_r); } @@ -2997,7 +3007,19 @@ void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { const size_t n_layers = dflash.extract_tensors.size(); const int64_t n_embd_concat = n_embd * n_layers; - dflash.target_features.resize(n_embd_concat * n_tokens); + // Index the per-token features by their ABSOLUTE position, accumulating across ubatches. The + // draft (dflash_append_features) reads the [n_total - n_new, n_total) slice, so a prompt processed + // in multiple ubatches (the server chunks it) and a partial-accept verify block both land at the + // right positions. The old resize(n_tokens) overwrote the buffer with only the last ubatch, so a + // chunked prompt left the first draft reading garbage -> argmax -1 -> "invalid token -1" decode fail. + llama_pos pos_max = -1; + for (int64_t i = 0; i < n_tokens; ++i) { + pos_max = std::max(pos_max, ubatch.pos[i]); + } + const size_t need = (size_t)(pos_max + 1) * n_embd_concat; + if (dflash.target_features.size() < need) { + dflash.target_features.resize(need); + } static thread_local std::vector temp_layer_features; temp_layer_features.resize(n_embd * n_tokens); @@ -3020,8 +3042,9 @@ void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { ggml_backend_sched_synchronize(sched.get()); for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + const llama_pos pos = ubatch.pos[token_idx]; const float * src = temp_layer_features.data() + token_idx * n_embd; - float * dest = dflash.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + float * dest = dflash.target_features.data() + (size_t) pos * n_embd_concat + layer_idx * n_embd; std::memcpy(dest, src, n_embd * sizeof(float)); } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9d216109fe5c..07df6475c1e8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -96,6 +96,11 @@ struct server_slot { // the verify decode so a partial acceptance promotes the state at the accepted position on-device // instead of the ~50 MiB host checkpoint round-trip + re-decode (see llama_dflash_promote_state) bool spec_state_trace = false; + // DFlash GPU greedy verify: when the request samples a raw argmax (pure greedy, no penalties/ + // grammar/logit-bias/n_probs), the target decode emits an on-device argmax of the verify block + // and the host skips the ~n_vocab x block logits download + CPU sampler. Lossless for greedy. + bool spec_gpu_verify = false; + bool spec_argmax_active = false; // out_argmax currently enabled on slot.ctx (after the first sample) llama_pos spec_pos0 = 0; // base position of the current verify batch (rewind target = pos0 + accepted) // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state @@ -1366,6 +1371,30 @@ struct server_context_impl { llama_set_sampler(ctx, slot.id, nullptr); } + // DFlash GPU greedy verify: when this request samples a raw argmax (pure greedy: temp<=0, + // no penalties / DRY / grammar / logit-bias / n_probs), turn on the target's on-device + // argmax so the verify reads block_size+1 ints instead of downloading block_size+1 x n_vocab + // logits and running the host sampler. Lossless for greedy; falls back to the host path + // otherwise. Toggled per request (out_argmax change triggers a one-off graph reserve). + { + const auto & sp = task.params.sampling; + const bool pure_greedy = + sp.temp <= 0.0f && sp.penalty_repeat == 1.0f && sp.penalty_freq == 0.0f && + sp.penalty_present == 0.0f && sp.dry_multiplier == 0.0f && sp.grammar.empty() && + sp.logit_bias.empty() && sp.n_probs == 0; + slot.spec_gpu_verify = params_base.speculative.dflash && params_base.n_parallel == 1 && + pure_greedy && + !(getenv("LLAMA_SPEC_NO_GPU_VERIFY") && std::string(getenv("LLAMA_SPEC_NO_GPU_VERIFY")) != "0"); + // out_argmax is turned ON only after the first token is sampled from logits (like + // speculative-simple): the prompt's first sample needs raw logits, the verify loop + // afterwards reads the on-device argmax. Reset here for a fresh request on the slot. + llama_set_out_argmax(slot.ctx, false); + slot.spec_argmax_active = false; + if (slot.spec_gpu_verify) { + SLT_INF(slot, "%s", "DFlash GPU greedy verify enabled (on-device argmax)\n"); + } + } + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); } else { slot.smpl.reset(); @@ -2956,7 +2985,22 @@ struct server_context_impl { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id; + if (slot.spec_gpu_verify && slot.spec_argmax_active) { + // out_argmax already on (a rare empty-draft round after the first token): read the + // target's on-device argmax for this output row instead of the (unavailable) logits + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(slot.ctx, &n_am); + GGML_ASSERT(am != nullptr && tok_idx < n_am && "DFlash target argmax missing"); + id = (llama_token) am[tok_idx]; + } else { + id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + // first sample done from logits -> enable on-device argmax for the verify loop + if (slot.spec_gpu_verify && !slot.spec_argmax_active) { + llama_set_out_argmax(slot.ctx, true); + slot.spec_argmax_active = true; + } + } slot.i_batch = -1; @@ -3018,7 +3062,37 @@ struct server_context_impl { } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + llama_tokens accepted; + if (slot.spec_gpu_verify) { + // greedy accept from the target's on-device argmax (DFlash sets output_all, so + // the argmax row == the verify token's batch index in spec_i_batch). Same + // semantics as common_sampler_sample_and_accept_n with a greedy sampler: take + // the target token at each position up to and including the first mismatch, plus + // a bonus token if every draft matched. Skips the per-block logits download. + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(slot.ctx, &n_am); + GGML_ASSERT(am != nullptr && "DFlash target argmax missing"); + size_t k = 0; + for (; k < slot.spec_draft.size(); ++k) { + const int32_t row = slot.spec_i_batch[k]; + GGML_ASSERT(row < n_am); + const llama_token t = (llama_token) am[row]; + accepted.push_back(t); + common_sampler_accept(slot.smpl.get(), t, true); + if (slot.spec_draft[k] != t) { + break; + } + } + if (k == slot.spec_draft.size()) { // all drafts matched -> bonus token + const int32_t row = slot.spec_i_batch[k]; + GGML_ASSERT(row < n_am); + const llama_token t = (llama_token) am[row]; + accepted.push_back(t); + common_sampler_accept(slot.smpl.get(), t, true); + } + } else { + accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + } slot.spec_i_batch.clear(); SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());