diff --git a/autoresearch.ideas.md b/autoresearch.ideas.md new file mode 100644 index 000000000000..525d6ca1acf3 --- /dev/null +++ b/autoresearch.ideas.md @@ -0,0 +1,9 @@ +- Revisit sizing the DFlash draft context from target feature window + block instead of fixed 2048/4096; one discarded run reduced target_feat set_inputs from ~4 ms to ~0.7 ms but did not improve decode TPS, and it still did not fit fast-rollback persist. +- Persist compression remains the route to fast rollback at 64k/full-draft: q4->tq3 KV, draft n_ctx reduction, and ngld4 did not free enough; ngld3 fit but was slower. Conv persist is still F32 and `ggml_ssm_conv_tree_persist` asserts F32, so F16 conv persist would need kernel + rollback support. +- GPU-side draft logits top-k/argmax is likely the real top-k optimization; CPU heap/fixed-array tweaks and budget K=1 did not move primary TPS enough. +- 2026-04-30 autoresearch Phase 2 findings: + - Multi-prompt batched/exact validation: 7/8 prompts perfect match; 56k prompt had 1 diff at min_margin=0.0045 (step 2). A margin threshold >=0.01 might be safe for exact-skip on tested prompts, but broader validation needed. + - Draft compute buffer reduction (capping draft n_batch to 64) saves ~1.1 GiB GPU memory and is now committed. Does not affect exact-path TPS but enables smaller-n_batch fast-rollback fits. + - Fast-batched + fast-rollback path consistently slower than exact-validation for budgets 8-16 (7.3-7.8 TPS vs ~10 TPS exact). Target_tree decode dominates cost (~70-120ms depending on budget). + - GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 allows budget 22 persist to allocate (cudaMallocManaged), but target_tree decode slows to ~1344ms/step due to page-fault thrashing — unusable. + - To make high-budget fast rollback viable without performance collapse, need either: (a) persist representation smaller than F16 (custom 8-bit kernel support), (b) extract accepted-node state from tree-kernel result tensor instead of separate persist buffer, or (c) free ~500MB-1GB additional contiguous GPU memory. diff --git a/autoresearch.jsonl b/autoresearch.jsonl new file mode 100644 index 000000000000..e05c1589a332 --- /dev/null +++ b/autoresearch.jsonl @@ -0,0 +1,190 @@ +{"type":"config","name":"Optimize DFlash DDTree TPS on Castle","metricName":"tps","metricUnit":"","bestDirection":"higher"} +{"run":1,"commit":"b63fd0d","metric":0,"metrics":{"tps":0},"status":"crash","description":"baseline harness failed before benchmark during sync/ssh setup","timestamp":1777494462566,"segment":0,"confidence":null,"asi":{"hypothesis":"establish Castle e2e baseline using local-to-remote sync and CUDA harness","rollback_reason":"script exited with ssh/rsync code 255 before producing output","next_action_hint":"debug autoresearch.sh sync command; likely rsync invoking ssh differently than direct ssh"}} +{"run":2,"commit":"b63fd0d","metric":0,"metrics":{"tps":0},"status":"crash","description":"baseline harness failed due shell quoting in remote kill command","timestamp":1777494502610,"segment":0,"confidence":null,"asi":{"hypothesis":"fix pkill self-termination in Castle benchmark harness","rollback_reason":"bash expanded $1 locally under set -u in the awk kill snippet","next_action_hint":"escape awk field as \\$1 or avoid awk entirely with pgrep pattern"}} +{"run":3,"commit":"0374171","metric":0.665751,"metrics":{"tps":0.665751,"spec_sec":24.033,"gen_tokens":16,"steps":9,"committed":0,"step_ms":282.13,"pack_ms":10.25,"draft_ms":203.26,"topk_ms":20.85,"exact_ms":47.73,"exact_decode_ms":0,"acceptance":1.889},"status":"keep","description":"baseline Castle e2e chain-only DDTree benchmark with 1024 target feature window","timestamp":1777494603864,"segment":0,"confidence":null,"asi":{"hypothesis":"establish baseline using current stable chain-only exact validation, full target/draft GPU offload, q4 KV, 64k ctx, real rendered prompt","benchmark":"./autoresearch.sh on Castle; gen=16; prompt=/tmp/real_rendered_prompt.txt; LLAMA_DDTREE_TARGET_FEAT_CTX=1024","note":"committed parser returned 0 likely due regex choosing absent/overwritten field; acceptance parsed from e2e output is the useful commit-rate signal"}} +{"type":"config","name":"Optimize DFlash DDTree decode TPS on Castle","metricName":"tps","metricUnit":"","bestDirection":"higher"} +{"run":4,"commit":"6510b31","metric":6.419361,"metrics":{"tps":6.419361,"e2e_tps":1.197874,"spec_sec":26.714,"gen_tokens":32,"steps":18,"committed":0,"step_ms":276.94,"pack_ms":8.88,"draft_ms":201.75,"topk_ms":20.64,"exact_ms":45.63,"exact_decode_ms":0,"acceptance":1.833},"status":"keep","description":"decode-TPS baseline with gen32 after fixing primary metric to exclude prompt/model overhead","timestamp":1777494737663,"segment":1,"confidence":null,"asi":{"hypothesis":"use decode-only TPS as the primary metric because current optimization target is per-step decode cost, not prompt prefill/model load","benchmark":"Castle e2e gen=32, real_rendered_prompt, 64k ctx, q4 KV, target_feat_ctx=1024","note":"committed parser still reports 0 due escaped whitespace regex; output shows committed=33, fix parser next"}} +{"run":5,"commit":"6510b31","metric":6.32999,"metrics":{"tps":6.32999,"e2e_tps":1.207183,"spec_sec":26.508,"gen_tokens":32,"steps":18,"committed":33,"step_ms":280.85,"pack_ms":8.91,"draft_ms":204.58,"topk_ms":21.5,"exact_ms":45.83,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"fix committed metric parser; benchmark noise slightly lower than baseline","timestamp":1777494820583,"segment":1,"confidence":null,"asi":{"hypothesis":"correct benchmark committed-token parsing without changing runtime code","rollback_reason":"primary decode TPS was slightly worse than baseline, likely noise; no runtime optimization tested","next_action_hint":"parser fix is in autoresearch.sh and should be preserved as an autoresearch file; proceed to source-level optimization"}} +{"run":6,"commit":"73f2fb0","metric":7.73072,"metrics":{"tps":7.73072,"e2e_tps":1.251564,"spec_sec":25.568,"gen_tokens":32,"steps":17,"committed":33,"step_ms":243.49,"pack_ms":4.49,"draft_ms":171.03,"topk_ms":19.51,"exact_ms":48.43,"exact_decode_ms":0,"acceptance":1.941},"status":"keep","description":"benchmark target feature window 512 for lower draft/pack cost","timestamp":1777494901661,"segment":1,"confidence":14.673204954627211,"asi":{"hypothesis":"smaller DFlash target feature context may reduce pack/upload and draft compute enough to improve decode TPS on the real rendered prompt","result":"LLAMA_DDTREE_TARGET_FEAT_CTX=512 improved decode TPS from 6.42 to 7.73; pack 8.88->4.49 ms, draft 201.75->171.03 ms, acceptance 1.83->1.94","next_action_hint":"confirm 512 on repeated run and test 256/768; if stable, consider default/window recommendation for server workload"}} +{"run":7,"commit":"2a8eba2","metric":8.388756,"metrics":{"tps":8.388756,"e2e_tps":1.262128,"spec_sec":25.354,"gen_tokens":32,"steps":17,"committed":33,"step_ms":224.39,"pack_ms":2.24,"draft_ms":154.39,"topk_ms":19.77,"exact_ms":47.96,"exact_decode_ms":0,"acceptance":1.941},"status":"keep","description":"benchmark target feature window 256 for further draft/pack reduction","timestamp":1777494974670,"segment":1,"confidence":2.8119551947912877,"asi":{"hypothesis":"target feature context can be reduced below 512 on this prompt without hurting proposal quality, further lowering draft compute and pack cost","result":"256 improved decode TPS to 8.39; draft dropped to 154 ms and pack to 2.24 ms with same acceptance as 512","next_action_hint":"test 128 and 384; watch acceptance/regressions because too short a feature window may harm draft quality on broader prompts"}} +{"run":8,"commit":"3f125a4","metric":8.751467,"metrics":{"tps":8.751467,"e2e_tps":1.277649,"spec_sec":25.046,"gen_tokens":32,"steps":17,"committed":33,"step_ms":215.09,"pack_ms":1.09,"draft_ms":144.76,"topk_ms":20.79,"exact_ms":48.41,"exact_decode_ms":0,"acceptance":1.941},"status":"keep","description":"benchmark target feature window 128","timestamp":1777495047214,"segment":1,"confidence":2.2847052207843856,"asi":{"hypothesis":"a 128-token target feature window may retain enough recent signal while reducing draft attention and upload cost","result":"128 improved decode TPS to 8.75; draft 144.76 ms, pack 1.09 ms, acceptance unchanged vs 256/512","next_action_hint":"test 64 and longer generation/prompt coverage before changing default; risk is hidden acceptance drop on other prompts"}} +{"run":9,"commit":"3f125a4","metric":8.481742,"metrics":{"tps":8.481742,"e2e_tps":1.274139,"spec_sec":25.115,"gen_tokens":32,"steps":17,"committed":33,"step_ms":221.93,"pack_ms":0.47,"draft_ms":153.01,"topk_ms":20.45,"exact_ms":47.96,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark target feature window 64; pack lower but draft slower than 128","timestamp":1777495120133,"segment":1,"confidence":4.187908592095226,"asi":{"hypothesis":"shrinking target feature window below 128 may further reduce pack and draft context work","rollback_reason":"64 lowered pack to 0.47 ms but draft compute rose to 153 ms and decode TPS fell below the 128-window best","next_action_hint":"treat 128 as the current sweet spot on this prompt; test 96/160 only if tuning window further matters"}} +{"run":10,"commit":"3f125a4","metric":7.958279,"metrics":{"tps":7.958279,"e2e_tps":1.256676,"spec_sec":25.464,"gen_tokens":32,"steps":19,"committed":33,"step_ms":211.63,"pack_ms":0.79,"draft_ms":147.96,"topk_ms":19.88,"exact_ms":42.96,"exact_decode_ms":0,"acceptance":1.737},"status":"discard","description":"benchmark target feature window 96; lower per-step cost lost acceptance","timestamp":1777495203736,"segment":1,"confidence":4.455149647635071,"asi":{"hypothesis":"96-token target feature context might sit between 64 and 128 with lower draft cost but enough proposal quality","rollback_reason":"acceptance dropped from 1.94 to 1.74 and required 19 steps; decode TPS below 128-window best","next_action_hint":"128 remains better than smaller windows for this prompt; test 160/192 or move to non-window optimizations"}} +{"run":11,"commit":"3f125a4","metric":8.742524,"metrics":{"tps":8.742524,"e2e_tps":1.276935,"spec_sec":25.06,"gen_tokens":32,"steps":17,"committed":33,"step_ms":215.31,"pack_ms":1.34,"draft_ms":147.17,"topk_ms":18.84,"exact_ms":47.92,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark target feature window 160; tied 128 within noise but no improvement","timestamp":1777495276127,"segment":1,"confidence":4.609797945056553,"asi":{"hypothesis":"160-token window may reduce noise versus 128 while keeping acceptance and similar draft cost","rollback_reason":"decode TPS was effectively tied but below the 128-window best; no primary improvement","next_action_hint":"128/160 are close; choose 128 for lower memory/pack unless longer prompt validation favors 160"}} +{"run":12,"commit":"3f125a4","metric":8.341175,"metrics":{"tps":8.341175,"e2e_tps":1.263873,"spec_sec":25.319,"gen_tokens":32,"steps":17,"committed":33,"step_ms":225.67,"pack_ms":1.61,"draft_ms":155.69,"topk_ms":20.27,"exact_ms":48.07,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark target feature window 192; slower draft than 128/160","timestamp":1777495349316,"segment":1,"confidence":5.810668520414904,"asi":{"hypothesis":"192-token window may retain proposal quality while staying much cheaper than 256/512","rollback_reason":"acceptance matched 128 but draft/step time regressed, so primary TPS fell","next_action_hint":"do not tune larger windows for this prompt unless broader validation shows 128 harms acceptance"}} +{"run":13,"commit":"3f125a4","metric":8.477158,"metrics":{"tps":8.477158,"e2e_tps":1.282257,"spec_sec":24.956,"gen_tokens":32,"steps":17,"committed":33,"step_ms":222.05,"pack_ms":1.08,"draft_ms":153.39,"topk_ms":19.58,"exact_ms":47.97,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"default target feature window 128 plus smaller draft context","timestamp":1777495532959,"segment":1,"confidence":6.104510117006526,"asi":{"hypothesis":"make 128 the code default and size draft n_ctx to target feature window + block so server avoids oversized draft context","rollback_reason":"primary decode TPS was below the prior 128-window env best despite lower set_inputs; draft compute/noise dominated and source change did not improve best","next_action_hint":"revisit draft n_ctx sizing separately with repeated runs; the log showed draft n_ctx=256 and set_inputs ~0.7 ms vs ~4 ms, but step time was not better"}} +{"run":14,"commit":"3f125a4","metric":8.572515,"metrics":{"tps":8.572515,"e2e_tps":1.275053,"spec_sec":25.097,"gen_tokens":32,"steps":17,"committed":33,"step_ms":219.58,"pack_ms":1.08,"draft_ms":149.11,"topk_ms":20.98,"exact_ms":48.38,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"replace small-K heap top-k with fixed array scan","timestamp":1777495681069,"segment":1,"confidence":6.592190362045204,"asi":{"hypothesis":"avoid per-row heap allocation/maintenance in extract_top_k_logprobs for K=8 using fixed stack arrays","rollback_reason":"topk time increased to 20.98 ms and primary TPS remained below the 128-window best; std heap path is not the bottleneck or is better optimized","next_action_hint":"future top-k gains likely need GPU-side logits processing or avoiding logsumexp/scanning, not a small CPU data structure tweak"}} +{"run":15,"commit":"3f125a4","metric":8.708952,"metrics":{"tps":8.708952,"e2e_tps":1.275764,"spec_sec":25.083,"gen_tokens":32,"steps":17,"committed":33,"step_ms":216.14,"pack_ms":1.09,"draft_ms":151.15,"topk_ms":16.14,"exact_ms":47.73,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark budget 14 with 128-window; lower top-k cost but no TPS win","timestamp":1777495763704,"segment":1,"confidence":7.965359774029058,"asi":{"hypothesis":"budget 14 triggers cheaper K=1/argmax proposal extraction while preserving the same accepted chain on this prompt","rollback_reason":"topk dropped 20.8->16.1 ms but total step time did not beat the 128-window budget-22 best; likely draft/noise dominates","next_action_hint":"test budget 8/1 only if looking for lower overhead mode; budget 14 is close but not a primary improvement"}} +{"run":16,"commit":"3f125a4","metric":8.433481,"metrics":{"tps":8.433481,"e2e_tps":1.267578,"spec_sec":25.245,"gen_tokens":32,"steps":17,"committed":33,"step_ms":223.2,"pack_ms":1.08,"draft_ms":156.77,"topk_ms":17.67,"exact_ms":47.64,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark budget 8 with 128-window; slower than budget 14/22","timestamp":1777495844676,"segment":1,"confidence":8.4658857012172,"asi":{"hypothesis":"smaller budget may reduce proposal overhead without reducing accepted chain length in exact-validation mode","rollback_reason":"budget 8 had same step count but slower draft/topk aggregate than budget 14 and below best TPS","next_action_hint":"budget tuning is not the next major path; keep budget 22 or 14 and focus draft/exact decode"}} +{"run":17,"commit":"3f125a4","metric":8.123665,"metrics":{"tps":8.123665,"e2e_tps":1.582756,"spec_sec":26.536,"gen_tokens":42,"steps":24,"committed":42,"step_ms":215.42,"pack_ms":1.02,"draft_ms":150.09,"topk_ms":20.84,"exact_ms":43.43,"exact_decode_ms":0,"acceptance":1.75},"status":"discard","description":"confirm 128-window on longer gen request; generation ended at 42 tokens and TPS lower","timestamp":1777495921126,"segment":1,"confidence":7.969102337827434,"asi":{"hypothesis":"longer generation should reduce noise and confirm whether 128-window remains best beyond a 32-token sample","rollback_reason":"request stopped after 42 generated tokens and acceptance fell to 1.75; decode TPS below current best","next_action_hint":"use a non-EOS prompt or different fixture for long-generation confirmation; current rendered prompt is not ideal for 64-token stability"}} +{"run":18,"commit":"c2c95e3","metric":9.112863,"metrics":{"tps":9.112863,"e2e_tps":1.2838,"spec_sec":24.926,"gen_tokens":32,"steps":17,"committed":33,"step_ms":206.56,"pack_ms":1.06,"draft_ms":137.65,"topk_ms":19.8,"exact_ms":48.01,"exact_decode_ms":0,"acceptance":1.941},"status":"keep","description":"change DDTree default target feature window to 128","timestamp":1777496003711,"segment":1,"confidence":8.71562209789579,"asi":{"hypothesis":"make the observed 128-token target feature window the default so server/e2e use the faster draft context without requiring an env override","result":"default 128 produced the best run so far: 9.11 decode TPS, draft 137.65 ms, pack 1.06 ms, acceptance unchanged at 1.94","next_action_hint":"validate on additional prompts before treating 128 as globally safe; env knob still allows raising to 512/1024 if acceptance drops"}} +{"run":19,"commit":"c2c95e3","metric":8.611341,"metrics":{"tps":8.611341,"e2e_tps":1.273784,"spec_sec":25.122,"gen_tokens":32,"steps":17,"committed":33,"step_ms":218.59,"pack_ms":1.07,"draft_ms":149.86,"topk_ms":19.62,"exact_ms":48.01,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"repeat default 128-window run for noise check","timestamp":1777496086262,"segment":1,"confidence":9.96049456675487,"asi":{"hypothesis":"re-run the kept 128 default to estimate noise and verify improvement is not a one-off","rollback_reason":"repeat was below the 9.11 best but still far above the original 1024-window baseline; no new code changes to keep","next_action_hint":"treat 128 default as real but noisy; future comparisons should beat about 8.6-9.1 decode TPS"}} +{"run":20,"commit":"c2c95e3","metric":8.819946,"metrics":{"tps":8.819946,"e2e_tps":1.273885,"spec_sec":25.12,"gen_tokens":32,"steps":17,"committed":33,"step_ms":213.42,"pack_ms":1.09,"draft_ms":144.99,"topk_ms":19.36,"exact_ms":47.94,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"disable DDTree micro-profile logging during benchmark","timestamp":1777496178429,"segment":1,"confidence":10.150139806908188,"asi":{"hypothesis":"LLAMA_DDTREE_PROFILE logging may add per-step overhead, so production benchmarking should disable it","rollback_reason":"no-profile run did not beat the kept 9.11 TPS best, though it remained above baseline","next_action_hint":"keep profiling disabled for user-facing server unless collecting diagnostics; script knob is preserved in autoresearch.sh"}} +{"run":21,"commit":"c2c95e3","metric":8.471256,"metrics":{"tps":8.471256,"e2e_tps":1.271052,"spec_sec":25.176,"gen_tokens":32,"steps":18,"committed":33,"step_ms":209.86,"pack_ms":1.08,"draft_ms":151.2,"topk_ms":12.38,"exact_ms":45.18,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"diagnostic dynamic DDTree block size 8","timestamp":1777496331268,"segment":1,"confidence":10.70842957726881,"asi":{"hypothesis":"smaller draft block size may reduce draft/top-k compute enough to offset lower speculative horizon","rollback_reason":"block size 8 reduced topk but acceptance dropped and an extra step was needed; primary TPS below current 128-window default","next_action_hint":"keep block size 16; reducing block size does not help this prompt"}} +{"run":22,"commit":"c2c95e3","metric":8.249464,"metrics":{"tps":8.249464,"e2e_tps":1.263823,"spec_sec":25.32,"gen_tokens":32,"steps":19,"committed":33,"step_ms":204.16,"pack_ms":0.23,"draft_ms":141.61,"topk_ms":19.45,"exact_ms":42.84,"exact_decode_ms":0,"acceptance":1.737},"status":"discard","description":"benchmark target feature window 32; too little context hurts acceptance","timestamp":1777496404861,"segment":1,"confidence":11.331709410339274,"asi":{"hypothesis":"very small target feature window may lower draft compute enough to compensate for weaker proposals","rollback_reason":"acceptance dropped to 1.74 and 19 steps were needed, so TPS stayed below the 128 default","next_action_hint":"do not go below 128 for this workload unless a broader prompt shows different behavior"}} +{"run":23,"commit":"c2c95e3","metric":8.67164,"metrics":{"tps":8.67164,"e2e_tps":1.279284,"spec_sec":25.014,"gen_tokens":32,"steps":17,"committed":33,"step_ms":217.07,"pack_ms":1.07,"draft_ms":151.99,"topk_ms":16.08,"exact_ms":47.9,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"use K=1 proposal extraction for default exact-validation path","timestamp":1777496508196,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"when batched tree verify is disabled, default exact validation can use a top-1 draft chain instead of K=8 tree proposals to reduce top-k overhead while remaining correct","rollback_reason":"topk improved by ~3.7 ms but total step time stayed below the 128-window best because draft compute/noise dominated","next_action_hint":"K=1 is a possible conservative mode but not a net win in this benchmark; do not keep as default"}} +{"run":24,"commit":"c2c95e3","metric":8.675637,"metrics":{"tps":8.675637,"e2e_tps":1.271102,"spec_sec":25.175,"gen_tokens":32,"steps":17,"committed":33,"step_ms":216.97,"pack_ms":1.06,"draft_ms":148.41,"topk_ms":19.69,"exact_ms":47.77,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"cache DFlash mask token embedding in DDTree driver","timestamp":1777496601681,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"avoid per-step mask embedding lookup/allocation when building draft noise embeddings","rollback_reason":"change was correct but primary TPS did not beat the 128-window best; mask embedding work is not a measurable bottleneck","next_action_hint":"focus on draft graph compute or exact target decode rather than small CPU setup around noise embeddings"}} +{"run":25,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.707,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.58,"draft_ms":152.52,"topk_ms":18.22,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast-batched fast-rollback at 128-window still OOMs persist on 64k full-draft","timestamp":1777496697475,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"smaller target feature window might free enough memory for fast rollback persist with -ngl65 -ngld6 -c65536","rollback_reason":"persist allocation still failed needing 1707.75 MiB; with SNAPSHOT_FALLBACK=0 driver returned empty output","next_action_hint":"fast rollback remains nonviable for full-draft 64k on this 24GB GPU unless memory is freed elsewhere; test q4/tq3 or reduced draft offload only if willing to trade draft speed"}} +{"run":26,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"try tq3 KV to fit fast rollback persist; e2e harness does not accept tq3_0","timestamp":1777496743075,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"TQ3 KV might free enough VRAM for fast rollback persist with full target/draft offload","rollback_reason":"test-speculative-tree-e2e only accepts its hardcoded KV type set and rejected --kv-type tq3_0","next_action_hint":"add tq3_0 support to e2e parser before testing this; do not infer runtime viability from this crash"}} +{"run":27,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":26.717,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.47,"draft_ms":153,"topk_ms":22.52,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"add tq3_0 e2e parser support and retry fast rollback; persist still OOMs","timestamp":1777496853671,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"with e2e tq3_0 support, smaller KV cache may fit the 1.7 GiB fast rollback persist buffer","rollback_reason":"tq3_0 run still failed CUDA persist allocation for 1707.75 MiB and returned empty output with snapshot fallback disabled","next_action_hint":"fast rollback requires a larger memory reduction than q4->tq3 KV provides in this harness; avoid more persist-fit tests unless freeing draft layers or compressing persist"}} +{"run":28,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.704,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.52,"draft_ms":279.6,"topk_ms":16.26,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast rollback with draft GPU layers 4 still OOMs persist","timestamp":1777496939181,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"offloading fewer draft layers may free VRAM for fast rollback persist while retaining enough draft speed","rollback_reason":"ngld4 still failed 1707.75 MiB persist allocation and draft time already rose to 280 ms","next_action_hint":"if testing persist-fit, drop to ngld3 or lower, but expected draft slowdown likely outweighs rollback savings"}} +{"run":29,"commit":"c2c95e3","metric":4.278956,"metrics":{"tps":4.278956,"e2e_tps":1.114323,"spec_sec":28.717,"gen_tokens":32,"steps":18,"committed":33,"step_ms":415.47,"pack_ms":1.09,"draft_ms":294.24,"topk_ms":17.17,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"fast rollback with draft GPU layers 3 fits but is slower than exact path","timestamp":1777497015792,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"ngld3 may free enough VRAM for fast rollback and remove exact decode cost","rollback_reason":"persist fit and exact cost dropped to zero, but draft slowed to 294 ms and target_tree/rollback path made step 415 ms; TPS far below exact default","next_action_hint":"do not trade draft GPU layers for fast rollback on 64k; full draft GPU exact path is better"}} +{"run":30,"commit":"c2c95e3","metric":8.740901,"metrics":{"tps":8.740901,"e2e_tps":1.289127,"spec_sec":24.823,"gen_tokens":32,"steps":17,"committed":33,"step_ms":215.35,"pack_ms":1.07,"draft_ms":147.03,"topk_ms":19.45,"exact_ms":47.77,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark 32k context for 20k prompt with 128 default","timestamp":1777497087757,"segment":1,"confidence":11.620240385859825,"asi":{"hypothesis":"if the prompt fits under 32k, lower target context may reduce memory/compute overhead versus 64k","rollback_reason":"32k context was above baseline but below the 64k 128-window best; no primary improvement in this e2e workload","next_action_hint":"context reduction may still help server prefill/memory, but decode step optimization should focus elsewhere"}} +{"run":31,"commit":"c2c95e3","metric":8.538297,"metrics":{"tps":8.538297,"e2e_tps":1.284779,"spec_sec":24.907,"gen_tokens":32,"steps":17,"committed":33,"step_ms":220.46,"pack_ms":1.09,"draft_ms":151.06,"topk_ms":20.14,"exact_ms":48.13,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark 24k context for 20k prompt","timestamp":1777497161415,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"context just above prompt length may reduce decode overhead further than 32k/64k","rollback_reason":"decode TPS below current best; smaller n_ctx did not improve step time on this harness","next_action_hint":"do not spend more iterations on n_ctx for decode TPS; use n_ctx tuning only for memory/server fit"}} +{"run":32,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.44,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.54,"draft_ms":146.12,"topk_ms":20.51,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"smaller draft context plus full-draft fast rollback still OOMs persist","timestamp":1777497248234,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"sizing draft n_ctx to the 128 target-feature window may free enough compute-buffer VRAM for full-draft fast rollback","rollback_reason":"fast rollback still failed persist allocation and returned empty output with snapshot fallback disabled","next_action_hint":"draft context sizing alone does not free the required 1.7 GiB; persist compression is needed for this route"}} +{"run":33,"commit":"c2c95e3","metric":8.497057,"metrics":{"tps":8.497057,"e2e_tps":1.266474,"spec_sec":25.267,"gen_tokens":32,"steps":17,"committed":33,"step_ms":221.53,"pack_ms":1.07,"draft_ms":153.17,"topk_ms":19.25,"exact_ms":48.01,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"skip exact validation seq_rm diagnostic","timestamp":1777497395516,"segment":1,"confidence":11.854680691870985,"asi":{"hypothesis":"default exact chain validation may be clearing an already-empty future range every step; skipping it could save overhead","rollback_reason":"correctness passed but primary TPS dropped; seq_rm is not a meaningful bottleneck","next_action_hint":"keep the safety clear in exact validation"}} +{"run":34,"commit":"c2c95e3","metric":2.679353,"metrics":{"tps":2.679353,"e2e_tps":0.960384,"spec_sec":33.32,"gen_tokens":32,"steps":17,"committed":33,"step_ms":702.54,"pack_ms":1.11,"draft_ms":160.23,"topk_ms":19.65,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"fast-batched with snapshot/replay fallback under 128 default","timestamp":1777497492163,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"trusting batched posterior with snapshot fallback might avoid exact decode while preserving correctness","rollback_reason":"snapshot/tree/replay path made step time 702 ms, far slower than exact chain default; fast-batched without persist rollback is not viable","next_action_hint":"only revisit fast-batched with a no-OOM fast rollback or conditional exact rule; snapshot fallback is too expensive"}} +{"run":35,"commit":"c2c95e3","metric":8.402611,"metrics":{"tps":8.402611,"e2e_tps":1.269942,"spec_sec":25.198,"gen_tokens":32,"steps":17,"committed":33,"step_ms":224.02,"pack_ms":1.09,"draft_ms":154.23,"topk_ms":20.73,"exact_ms":47.95,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark n_batch 2048 with 512 ubatch","timestamp":1777497577065,"segment":1,"confidence":11.829481672771372,"asi":{"hypothesis":"larger prompt batch may improve prefill or reserve behavior without changing decode acceptance","rollback_reason":"decode step time worsened; n_batch tuning is not a decode TPS improvement","next_action_hint":"keep n_batch=512 for this e2e benchmark unless optimizing prompt prefill separately"}} +{"run":36,"commit":"c2c95e3","metric":8.610159,"metrics":{"tps":8.610159,"e2e_tps":1.233046,"spec_sec":25.952,"gen_tokens":32,"steps":17,"committed":33,"step_ms":218.62,"pack_ms":1.06,"draft_ms":149.81,"topk_ms":19.88,"exact_ms":47.83,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark n_ubatch 256","timestamp":1777497653606,"segment":1,"confidence":12.65478788505223,"asi":{"hypothesis":"smaller ubatch may reduce memory pressure or improve graph reserve for decode","rollback_reason":"primary TPS below best and e2e time worsened due slower prompt processing; no decode win","next_action_hint":"do not tune ubatch for current decode target"}} +{"run":37,"commit":"c2c95e3","metric":8.817054,"metrics":{"tps":8.817054,"e2e_tps":1.209327,"spec_sec":26.461,"gen_tokens":32,"steps":17,"committed":33,"step_ms":213.49,"pack_ms":1.09,"draft_ms":145.34,"topk_ms":18.88,"exact_ms":48.15,"exact_decode_ms":47.76,"acceptance":1.941},"status":"discard","description":"print exact validation timing split in e2e harness","timestamp":1777497756241,"segment":1,"confidence":11.854680691870985,"asi":{"hypothesis":"surface exact_decode/exact_sample timing in e2e to guide next optimization decisions","rollback_reason":"instrumentation did not improve primary TPS and was reverted by experiment rules, though it confirmed exact_decode is essentially all exact cost","next_action_hint":"if more diagnostics are needed, re-add exact split intentionally; current useful fact: exact_decode 47.76 ms of exact 48.15 ms"}} +{"run":38,"commit":"c2c95e3","metric":7.97639,"metrics":{"tps":7.97639,"e2e_tps":1.251418,"spec_sec":25.571,"gen_tokens":32,"steps":18,"committed":33,"step_ms":222.88,"pack_ms":0.94,"draft_ms":156.66,"topk_ms":19.89,"exact_ms":45.36,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"benchmark target feature window 112","timestamp":1777497962865,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"112-token window might reduce draft context below 128 without dropping acceptance","rollback_reason":"acceptance dropped to 1.83 and an extra step was needed; 128 remains better","next_action_hint":"avoid target feature windows below 128 on this prompt"}} +{"run":39,"commit":"c2c95e3","metric":8.508964,"metrics":{"tps":8.508964,"e2e_tps":1.272163,"spec_sec":25.154,"gen_tokens":32,"steps":17,"committed":33,"step_ms":221.22,"pack_ms":1.21,"draft_ms":151.9,"topk_ms":20.01,"exact_ms":48.07,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark target feature window 144","timestamp":1777498045067,"segment":1,"confidence":11.854680691870985,"asi":{"hypothesis":"144-token window may be a safer slightly larger default than 128 with similar acceptance","rollback_reason":"acceptance matched 128 but draft/step time was worse and primary TPS below best","next_action_hint":"128 is still the best target feature default among tested windows"}} +{"run":40,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":33.226,"gen_tokens":32,"steps":32,"committed":32,"step_ms":212.78,"pack_ms":1.02,"draft_ms":160.08,"topk_ms":21.43,"exact_ms":30.22,"exact_decode_ms":0,"acceptance":1},"status":"crash","description":"tq3_0 KV exact default diverges from chain in e2e","timestamp":1777498174420,"segment":1,"confidence":11.854680691870985,"asi":{"hypothesis":"TQ3 KV may reduce memory or improve speed in exact default mode after adding e2e parser support","rollback_reason":"correctness failed at first generated token: chain[0]=8635, spec[0]=248069; TQ3 is not a safe drop-in for this correctness gate","next_action_hint":"do not use tq3_0 for DDTree correctness until target-only chain/spec equivalence is debugged"}} +{"run":41,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"q8 KV exact default OOMs with full draft at 64k","timestamp":1777498231508,"segment":1,"confidence":11.854680691870985,"asi":{"hypothesis":"q8 KV might speed exact decode enough to offset larger memory footprint","rollback_reason":"q8_0 KV caused draft context compute buffer allocation OOM with -ngl65 -ngld6 -c65536","next_action_hint":"q4_0 remains the viable full-draft 64k KV type on this GPU"}} +{"run":42,"commit":"c2c95e3","metric":9.049339,"metrics":{"tps":9.049339,"e2e_tps":1.28726,"spec_sec":24.859,"gen_tokens":32,"steps":17,"committed":33,"step_ms":208.01,"pack_ms":1.07,"draft_ms":140.46,"topk_ms":19.16,"exact_ms":47.28,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"q8 KV with 32k context fits and nearly matches best","timestamp":1777498306731,"segment":1,"confidence":11.723927501915169,"asi":{"hypothesis":"for prompts fitting 32k, q8 KV may fit and speed decode compared with q4 at 64k","rollback_reason":"q8_0 + 32k was close but still below the 64k q4 best run, so no primary improvement","next_action_hint":"q8_0 + 32k is a viable alternative for <=32k prompts; retest with server/prompt-prefill metric if memory quality matters"}} +{"run":43,"commit":"c2c95e3","metric":8.663259,"metrics":{"tps":8.663259,"e2e_tps":1.277853,"spec_sec":25.042,"gen_tokens":32,"steps":17,"committed":33,"step_ms":217.28,"pack_ms":1.07,"draft_ms":148.82,"topk_ms":19.59,"exact_ms":47.76,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"repeat q8 KV with 32k context","timestamp":1777498380449,"segment":1,"confidence":12.711493900280791,"asi":{"hypothesis":"confirm whether q8_0 + 32k close-to-best run was stable","rollback_reason":"repeat fell to 8.66 TPS, confirming the 9.05 run was noise/near-best but not better than 64k q4 default","next_action_hint":"do not switch primary config to q8_0 + 32k for decode TPS"}} +{"run":44,"commit":"c2c95e3","metric":8.561987,"metrics":{"tps":8.561987,"e2e_tps":1.273987,"spec_sec":25.118,"gen_tokens":32,"steps":17,"committed":33,"step_ms":219.85,"pack_ms":1.04,"draft_ms":150.19,"topk_ms":20.4,"exact_ms":48.2,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"benchmark target feature window 124","timestamp":1777498457519,"segment":1,"confidence":14.22995076181824,"asi":{"hypothesis":"124 may be just below 128 while preserving acceptance","rollback_reason":"acceptance preserved but step time was worse than 128 default; no primary improvement","next_action_hint":"stop fine-tuning target feature window around 128"}} +{"run":45,"commit":"c2c95e3","metric":5.277915,"metrics":{"tps":5.277915,"e2e_tps":1.165883,"spec_sec":27.447,"gen_tokens":32,"steps":30,"committed":32,"step_ms":202.1,"pack_ms":1,"draft_ms":153.89,"topk_ms":20.42,"exact_ms":26.76,"exact_decode_ms":0,"acceptance":1.067},"status":"discard","description":"do not request draft logits for root slot","timestamp":1777498552553,"segment":1,"confidence":12.711493900280791,"asi":{"hypothesis":"skip unused root-position draft logits to reduce lm_head/output work by one row per draft step","rollback_reason":"correctness remained bit-equal only because exact validation fell back, but draft proposal quality collapsed: acceptance 1.94->1.07 and steps 17->30; logits row indexing/output semantics do not support this simple change","next_action_hint":"do not skip root logits without first understanding llama_get_logits_ith output indexing for sparse logits flags"}} +{"run":46,"commit":"c2c95e3","metric":8.607403,"metrics":{"tps":8.607403,"e2e_tps":1.272113,"spec_sec":25.155,"gen_tokens":32,"steps":17,"committed":33,"step_ms":218.69,"pack_ms":1.07,"draft_ms":148.88,"topk_ms":20.44,"exact_ms":48.27,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"suppress root logits flag but keep original logits row indexing","timestamp":1777498637976,"segment":1,"confidence":14.22995076181824,"asi":{"hypothesis":"root logits can be disabled while keeping llama_get_logits_ith(1) indexing, reducing output work without changing proposals","rollback_reason":"proposal quality stayed intact but draft/topk/step time did not improve enough; primary below best","next_action_hint":"root logits flag is not a useful optimization in current llama.cpp output path"}} +{"run":47,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"disable Flash Attention diagnostic with q4 KV","timestamp":1777498814939,"segment":1,"confidence":14.22995076181824,"asi":{"hypothesis":"small one-token exact decode and draft attention might be faster without Flash Attention overhead","rollback_reason":"q4_0 V cache quantization requires flash_attn, so target context creation failed before benchmark","next_action_hint":"only test no-flash with f16 KV if memory allows; q4/q8 quantized V requires Flash Attention"}} +{"run":48,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"no-flash with f16 KV at 32k OOMs target compute buffer","timestamp":1777498859775,"segment":1,"confidence":14.22995076181824,"asi":{"hypothesis":"use f16 KV at 32k to allow no-flash attention and test whether one-token decode is faster without Flash Attention","rollback_reason":"context creation failed allocating 1739 MiB CUDA compute buffer; no benchmark ran","next_action_hint":"no-flash is not viable in current full-offload/full-draft memory envelope; abandon this path"}} +{"run":49,"commit":"c2c95e3","metric":8.674837,"metrics":{"tps":8.674837,"e2e_tps":1.274088,"spec_sec":25.116,"gen_tokens":32,"steps":17,"committed":33,"step_ms":216.99,"pack_ms":1.08,"draft_ms":146.47,"topk_ms":20.99,"exact_ms":48.42,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"force chain recurrent kernel diagnostic","timestamp":1777498962837,"segment":1,"confidence":16.05291169266154,"asi":{"hypothesis":"forcing chain recurrent kernels might reduce exact one-token decode overhead or avoid tree-kernel dispatch cost in exact validation","rollback_reason":"correctness passed but step time stayed below best; exact remained about 48 ms and topk worsened","next_action_hint":"leave chain-kernel force unset for normal runs"}} +{"run":50,"commit":"c2c95e3","metric":8.605435,"metrics":{"tps":8.605435,"e2e_tps":1.275154,"spec_sec":25.095,"gen_tokens":32,"steps":17,"committed":33,"step_ms":218.74,"pack_ms":0.73,"draft_ms":150.02,"topk_ms":19.64,"exact_ms":48.32,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"bulk-copy contiguous spans when packing target feature ring","timestamp":1777499134443,"segment":1,"confidence":16.106475474044874,"asi":{"hypothesis":"replace per-column target_feat ring packing with one or two bulk memcpy calls to reduce the remaining pack/upload overhead without changing model behavior","rollback_reason":"pack time improved from about 1.06 ms to 0.73 ms, but total step time and primary TPS stayed below the 128-window best; pack is now too small to move throughput alone","next_action_hint":"do not spend more iterations on CPU ring packing until larger draft/exact decode bottlenecks are addressed"}} +{"run":51,"commit":"c2c95e3","metric":8.567053,"metrics":{"tps":8.567053,"e2e_tps":1.269287,"spec_sec":25.211,"gen_tokens":32,"steps":17,"committed":33,"step_ms":219.72,"pack_ms":1.09,"draft_ms":151.19,"topk_ms":19.48,"exact_ms":47.93,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"allow DFlash target_feat graph input to reuse draft graph","timestamp":1777499394932,"segment":1,"confidence":18.011796096054017,"asi":{"hypothesis":"DFlash draft graph was not reused because target_feat input lacked can_reuse; allowing reuse should stabilize CUDA graph/build and reduce draft overhead without changing math","rollback_reason":"graph reuse worked after the first step (reused=1, build_alloc=0) but primary TPS and draft_ms were worse than the 128-window best; graph rebuild cost was not a bottleneck","next_action_hint":"do not prioritize graph reuse; focus on actual draft compute (fc/attention/lm_head) or exact target decode"}} +{"run":52,"commit":"c2c95e3","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.679,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.62,"draft_ms":154.37,"topk_ms":18.85,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast rollback with budget 14 to reduce persist footprint","timestamp":1777499530631,"segment":1,"confidence":18.011796096054017,"asi":{"hypothesis":"lowering DDTree budget from 22 to 14 reduces persist allocation from about 1.7 GiB to about 1.1 GiB and may allow full-draft fast rollback at 64k","rollback_reason":"persist allocation still OOMed at 1086.75 MiB; driver produced empty output with snapshot fallback disabled","next_action_hint":"full-draft fast rollback needs more than budget reduction; try lower target/draft compute buffers only if expected target_tree cost can still beat exact chain"}} +{"run":53,"commit":"c2c95e3","metric":7.570506,"metrics":{"tps":7.570506,"e2e_tps":1.246834,"spec_sec":25.665,"gen_tokens":32,"steps":19,"committed":34,"step_ms":222.47,"pack_ms":1.05,"draft_ms":154.78,"topk_ms":22.6,"exact_ms":43.99,"exact_decode_ms":0,"acceptance":1.789},"status":"discard","description":"diagnostic larger draft block size 24","timestamp":1777499663146,"segment":1,"confidence":16.106475474044874,"asi":{"hypothesis":"increasing DDTree draft block size beyond the trained/default 16 may expose a longer proposal horizon and reduce exact validation steps without changing correctness because exact validation remains authoritative","rollback_reason":"block size 24 passed bit-equal correctness but acceptance dropped to 1.79, steps rose to 19, and topk/draft costs increased; primary TPS below best","next_action_hint":"do not increase DFlash block size above 16 for this model; the draft quality beyond the default horizon is worse"}} +{"run":54,"commit":"abe969d","metric":9.820771,"metrics":{"tps":9.820771,"e2e_tps":1.295075,"spec_sec":24.709,"gen_tokens":32,"steps":16,"committed":33,"step_ms":203.65,"pack_ms":1.08,"draft_ms":134.43,"topk_ms":17.84,"exact_ms":50.27,"exact_decode_ms":0,"acceptance":2.062},"status":"keep","description":"increase DDTree budget to 32 for exact-validation path","timestamp":1777499764379,"segment":1,"confidence":19.93149922357974,"asi":{"hypothesis":"larger DDTree budget may include more alternative draft branches so exact chain validation can accept farther before falling off the proposal tree; top-k cost should stay similar because K remains 8","result":"budget 32 improved decode TPS to 9.82, reduced steps 17->16, and raised acceptance to 2.06 despite exact_ms rising per step","next_action_hint":"confirm with repeat and test budget 40/48; watch for prompt overfit because larger budgets may help only when alternate branches match the exact path"}} +{"run":55,"commit":"abe969d","metric":9.003331,"metrics":{"tps":9.003331,"e2e_tps":1.282822,"spec_sec":24.945,"gen_tokens":32,"steps":16,"committed":33,"step_ms":222.14,"pack_ms":1.08,"draft_ms":150.07,"topk_ms":19.8,"exact_ms":51.16,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 40 after budget 32 improvement","timestamp":1777499863976,"segment":1,"confidence":19.318235385991656,"asi":{"hypothesis":"budget 40 may further improve proposal coverage beyond budget 32 while keeping the same number of exact-validation steps","rollback_reason":"acceptance matched budget 32 but step time regressed to 222 ms, especially draft/topk, so primary TPS fell below the budget-32 best","next_action_hint":"budget 32 appears better than 40 on this prompt; test budget 28/36 or repeat 32 for noise before changing server defaults further"}} +{"run":56,"commit":"abe969d","metric":9.144111,"metrics":{"tps":9.144111,"e2e_tps":1.287467,"spec_sec":24.855,"gen_tokens":32,"steps":16,"committed":33,"step_ms":218.72,"pack_ms":1.09,"draft_ms":145.96,"topk_ms":20.45,"exact_ms":51.19,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 28 below budget-32 best","timestamp":1777499961123,"segment":1,"confidence":19.01142448327141,"asi":{"hypothesis":"budget 28 may retain the step-count win from budget 32 while reducing tree/proposal overhead","rollback_reason":"steps and acceptance matched budget 32, but step time was slower and primary TPS remained below 9.82","next_action_hint":"budget 32 remains the best tested budget; test 30/34 or repeat 32 to confirm noise"}} +{"run":57,"commit":"abe969d","metric":8.99928,"metrics":{"tps":8.99928,"e2e_tps":1.257022,"spec_sec":25.457,"gen_tokens":32,"steps":16,"committed":33,"step_ms":222.24,"pack_ms":1.08,"draft_ms":149.56,"topk_ms":20.23,"exact_ms":51.33,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"repeat DDTree budget 32 to confirm previous best against noise","timestamp":1777500057910,"segment":1,"confidence":19.196128504310348,"asi":{"hypothesis":"repeat budget 32 because the previous 9.82 TPS improvement may include noise; confirm whether the step-count/acceptance gain is stable","rollback_reason":"repeat preserved the 16-step and 2.06 acceptance behavior but step time regressed to 222 ms, so primary TPS was below the kept best","next_action_hint":"budget 32 still looks structurally useful via fewer steps, but compare future ideas against a noisy 9.0-9.8 TPS range rather than a single best run"}} +{"run":58,"commit":"abe969d","metric":9.0151,"metrics":{"tps":9.0151,"e2e_tps":1.280307,"spec_sec":24.994,"gen_tokens":32,"steps":16,"committed":33,"step_ms":221.85,"pack_ms":1.07,"draft_ms":148.88,"topk_ms":20.24,"exact_ms":51.62,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 30 between budget 28 and 32","timestamp":1777500149659,"segment":1,"confidence":19.077213862263655,"asi":{"hypothesis":"budget 30 may keep the 16-step acceptance gain from budget 32 while trimming proposal overhead versus larger budgets","rollback_reason":"budget 30 matched the 16-step/2.06 acceptance pattern but step time stayed around 222 ms and primary TPS remained below the kept budget-32 best","next_action_hint":"budget 30/28/40 do not beat budget 32; try budget 34/36 or switch away from budget tuning if repeats remain noisy"}} +{"run":59,"commit":"abe969d","metric":9.208951,"metrics":{"tps":9.208951,"e2e_tps":1.286587,"spec_sec":24.872,"gen_tokens":32,"steps":16,"committed":33,"step_ms":217.18,"pack_ms":1.1,"draft_ms":144.66,"topk_ms":19.97,"exact_ms":51.42,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 36 above budget-32 best","timestamp":1777500249227,"segment":1,"confidence":18.755483015403495,"asi":{"hypothesis":"budget 36 may preserve the 16-step acceptance gain from budget 32 while giving slightly better branch coverage than 30/32 without the overhead seen at 40","rollback_reason":"budget 36 passed and was better than 28/30/40, but still below the kept budget-32 best of 9.82 TPS","next_action_hint":"budget 34 is the last nearby budget worth testing; otherwise budget tuning appears saturated and noisy around 9.0-9.8 TPS"}} +{"run":60,"commit":"abe969d","metric":9.048545,"metrics":{"tps":9.048545,"e2e_tps":1.250684,"spec_sec":25.586,"gen_tokens":32,"steps":16,"committed":33,"step_ms":221.03,"pack_ms":1.08,"draft_ms":147.2,"topk_ms":21.15,"exact_ms":51.57,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 34 near budget-32 best","timestamp":1777500350130,"segment":1,"confidence":18.510168209448402,"asi":{"hypothesis":"budget 34 may retain the 16-step acceptance improvement from budget 32 while finding a lower-overhead point than 36/40","rollback_reason":"budget 34 matched the 16-step and 2.06 acceptance pattern, but topk/step time were worse and primary TPS remained below the kept budget-32 best","next_action_hint":"stop local budget sweep around 32; nearby budgets 28/30/34/36/40 all failed to beat 32, so switch to a different bottleneck such as exact target decode or draft compute"}} +{"run":61,"commit":"abe969d","metric":8.570174,"metrics":{"tps":8.570174,"e2e_tps":1.266524,"spec_sec":25.266,"gen_tokens":32,"steps":17,"committed":33,"step_ms":219.64,"pack_ms":1.35,"draft_ms":150.5,"topk_ms":19.51,"exact_ms":48.24,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with target feature window 160","timestamp":1777500462376,"segment":1,"confidence":18.755483015403495,"asi":{"hypothesis":"with the larger budget-32 proposal tree, a slightly longer target-feature window may improve draft proposal quality enough to offset extra pack/draft cost","rollback_reason":"target_feat_ctx=160 lost the budget-32 structural gain: steps rose from 16 to 17 and acceptance fell to 1.94, so TPS dropped","next_action_hint":"keep target feature window at 128 for budget-32; do not combine larger windows with larger budget on this workload"}} +{"run":62,"commit":"abe969d","metric":9.649133,"metrics":{"tps":9.649133,"e2e_tps":1.291572,"spec_sec":24.776,"gen_tokens":32,"steps":17,"committed":33,"step_ms":195.08,"pack_ms":0.95,"draft_ms":128.19,"topk_ms":18.6,"exact_ms":47.32,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with smaller target feature window 112","timestamp":1777500568011,"segment":1,"confidence":18.510168209448402,"asi":{"hypothesis":"with budget 32, a smaller 112-token target-feature window may reduce draft compute enough to offset losing some proposal quality","rollback_reason":"draft/step time improved significantly, but steps rose to 17 and acceptance fell to 1.94; primary TPS stayed below the kept budget-32 best of 9.82","next_action_hint":"try an intermediate window such as 120 with budget 32, or repeat 112 if prioritizing lower draft latency over best TPS"}} +{"run":63,"commit":"abe969d","metric":8.489392,"metrics":{"tps":8.489392,"e2e_tps":1.270648,"spec_sec":25.184,"gen_tokens":32,"steps":17,"committed":33,"step_ms":221.73,"pack_ms":1.02,"draft_ms":152.51,"topk_ms":20.3,"exact_ms":47.87,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with target feature window 120","timestamp":1777500665459,"segment":1,"confidence":18.755483015403495,"asi":{"hypothesis":"target_feat_ctx=120 may sit between 112 and 128, retaining some draft-cost savings while avoiding the acceptance loss seen at 112","rollback_reason":"120 did not recover budget-32 acceptance; steps stayed 17 and draft/step time was worse than 112 and 128, so primary TPS fell","next_action_hint":"do not continue fine-tuning between 112 and 128; use 128 for structural 16-step behavior or 112 only as lower-latency non-best variant"}} +{"run":64,"commit":"abe969d","metric":9.104151,"metrics":{"tps":9.104151,"e2e_tps":1.286174,"spec_sec":24.88,"gen_tokens":32,"steps":16,"committed":33,"step_ms":219.68,"pack_ms":1.07,"draft_ms":147.78,"topk_ms":19.46,"exact_ms":51.33,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test budget 32 with DDTree micro-profile logging disabled","timestamp":1777500762894,"segment":1,"confidence":18.510168209448402,"asi":{"hypothesis":"per-step DDTree micro-profile logging may add overhead in the budget-32 configuration, so disabling it could improve production-like decode TPS","rollback_reason":"profile-disabled run preserved the 16-step/2.06 acceptance behavior but did not beat the kept budget-32 best; primary TPS stayed in the noisy ~9.0 band","next_action_hint":"profiling can remain disabled in production, but it is not a major optimization lever; move to draft compute or exact decode changes"}} +{"run":65,"commit":"abe969d","metric":8.427062,"metrics":{"tps":8.427062,"e2e_tps":1.266675,"spec_sec":25.263,"gen_tokens":32,"steps":17,"committed":33,"step_ms":223.37,"pack_ms":1.06,"draft_ms":153.86,"topk_ms":20.47,"exact_ms":47.95,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with 32k context and q8 KV","timestamp":1777500859781,"segment":1,"confidence":18.755483015403495,"asi":{"hypothesis":"combine the budget-32 branch-coverage gain with the previously viable 32k/q8_0 context to see if q8 exact decode or lower context improves throughput","rollback_reason":"q8_0 + 32k lost the budget-32 structural gain: steps rose to 17 and acceptance fell to 1.94, with slower draft/step time","next_action_hint":"keep q4_0 64k for budget-32; q8_0/32k does not combine well with this proposal tree"}} +{"run":66,"commit":"abe969d","metric":8.867312,"metrics":{"tps":8.867312,"e2e_tps":1.283903,"spec_sec":24.924,"gen_tokens":32,"steps":17,"committed":33,"step_ms":212.28,"pack_ms":0.96,"draft_ms":143.28,"topk_ms":20.12,"exact_ms":47.89,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with target feature window 116","timestamp":1777500957905,"segment":1,"confidence":18.510168209448402,"asi":{"hypothesis":"target_feat_ctx=116 may keep most of the draft-cost reduction seen at 112 while recovering some proposal quality toward the 128-window budget-32 path","rollback_reason":"116 kept lower draft cost but still lost the budget-32 structural win: steps stayed 17 and acceptance stayed 1.94, so primary TPS remained below best","next_action_hint":"stop testing sub-128 windows with budget 32; they lower per-step cost but consistently lose the 16-step acceptance behavior"}} +{"run":67,"commit":"abe969d","metric":8.537135,"metrics":{"tps":8.537135,"e2e_tps":1.284986,"spec_sec":24.903,"gen_tokens":32,"steps":17,"committed":33,"step_ms":220.49,"pack_ms":1.07,"draft_ms":151.17,"topk_ms":20.16,"exact_ms":48.05,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"test budget 32 with 32k context and q4 KV","timestamp":1777501067545,"segment":1,"confidence":18.755483015403495,"asi":{"hypothesis":"for prompts fitting under 32k, lowering n_ctx while keeping q4 KV may reduce decode memory/compute overhead and combine with the budget-32 proposal tree","rollback_reason":"32k context lost the budget-32 structural gain: steps rose to 17 and acceptance fell to 1.94, with TPS well below the 64k budget-32 best","next_action_hint":"keep 64k q4 for budget-32 on this workload; context reduction changes proposal/target behavior enough to lose acceptance"}} +{"run":68,"commit":"abe969d","metric":9.041183,"metrics":{"tps":9.041183,"e2e_tps":1.284161,"spec_sec":24.919,"gen_tokens":32,"steps":16,"committed":33,"step_ms":221.21,"pack_ms":1.14,"draft_ms":148.76,"topk_ms":20.16,"exact_ms":51.12,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test budget 32 with target feature window 136","timestamp":1777501168604,"segment":1,"confidence":18.510168209448402,"asi":{"hypothesis":"target_feat_ctx=136 may preserve the 16-step budget-32 acceptance behavior while being close enough to 128 to avoid the overhead/regression seen at 160","rollback_reason":"136 preserved the 16-step and 2.06 acceptance behavior, but step time was around 221 ms and primary TPS stayed below the budget-32 best","next_action_hint":"128 remains the best window for budget-32; larger windows preserve acceptance but do not improve throughput"}} +{"run":69,"commit":"abe969d","metric":9.219989,"metrics":{"tps":9.219989,"e2e_tps":1.28236,"spec_sec":24.954,"gen_tokens":32,"steps":16,"committed":33,"step_ms":216.92,"pack_ms":1.09,"draft_ms":144.92,"topk_ms":19.97,"exact_ms":50.9,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test budget 32 with draft top-k capped at 4","timestamp":1777502096247,"segment":1,"confidence":15.883380263274644,"asi":{"hypothesis":"budget 32 may not need K=8 alternatives per row; capping draft top-k to 4 could preserve the 16-step acceptance behavior while reducing top-k/tree overhead","rollback_reason":"K=4 preserved 16 steps and 2.06 acceptance but primary TPS stayed below the kept budget-32 best; topk did not materially improve","next_action_hint":"top-k cardinality is not the main cost; do not add a TOP_K knob unless needed for diagnostics"}} +{"run":70,"commit":"abe969d","metric":9.147875,"metrics":{"tps":9.147875,"e2e_tps":1.279335,"spec_sec":25.013,"gen_tokens":32,"steps":16,"committed":34,"step_ms":218.63,"pack_ms":1.08,"draft_ms":147.84,"topk_ms":17.36,"exact_ms":52.32,"exact_decode_ms":0,"acceptance":2.125},"status":"discard","description":"cap DDTree proposal rows to 8 for budget 32","timestamp":1777502354709,"segment":1,"confidence":15.856576119639561,"asi":{"hypothesis":"exact-validation path rarely needs all 15 draft rows; capping DDTree proposal rows to 8 may preserve acceptance while reducing CPU top-k scan cost","rollback_reason":"topk improved to 17.36 ms and acceptance stayed strong, but exact cost rose and total step time remained below the kept best; also broad prompts could need deeper rows","next_action_hint":"tree-row capping can reduce topk but is not enough alone; avoid keeping because it risks hurting prompts with long accepted runs"}} +{"run":71,"commit":"abe969d","metric":8.451759,"metrics":{"tps":8.451759,"e2e_tps":1.584547,"spec_sec":26.506,"gen_tokens":42,"steps":23,"committed":42,"step_ms":216.06,"pack_ms":1.03,"draft_ms":149.5,"topk_ms":20.1,"exact_ms":45.4,"exact_decode_ms":0,"acceptance":1.826},"status":"discard","description":"validate budget 32 on longer generation request","timestamp":1777502458355,"segment":1,"confidence":15.883380263274644,"asi":{"hypothesis":"longer generation should check whether the budget-32 16-step gain generalizes beyond the short 32-token sample and reduce overfitting risk","rollback_reason":"the prompt ended at 42 generated tokens and acceptance fell to 1.83; decode TPS below budget-32 short-run best","next_action_hint":"use an additional non-EOS prompt fixture for robust validation; this prompt is useful but not sufficient for long-run budget tuning"}} +{"run":72,"commit":"abe969d","metric":9.227646,"metrics":{"tps":9.227646,"e2e_tps":1.283852,"spec_sec":24.925,"gen_tokens":32,"steps":16,"committed":33,"step_ms":216.74,"pack_ms":1.08,"draft_ms":144.28,"topk_ms":20.3,"exact_ms":51.03,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test pure best-first DDTree without greedy chain seed at budget 32","timestamp":1777502608147,"segment":1,"confidence":15.856576119639561,"asi":{"hypothesis":"chain_seed spends many budget-32 nodes on a deep greedy spine even though exact acceptance averages only about 2 tokens; pure best-first may allocate more shallow alternatives and improve proposal coverage","rollback_reason":"chain_seed=0 preserved the 16-step/2.06 acceptance behavior but did not beat the kept budget-32 best; no primary improvement and it may be less robust on prompts with long greedy runs","next_action_hint":"keep chain_seed enabled; if revisiting tree construction, use a hybrid depth cap rather than disabling the greedy spine entirely"}} +{"run":73,"commit":"abe969d","metric":9.040774,"metrics":{"tps":9.040774,"e2e_tps":1.281897,"spec_sec":24.963,"gen_tokens":32,"steps":16,"committed":33,"step_ms":221.22,"pack_ms":1.09,"draft_ms":148.96,"topk_ms":20.16,"exact_ms":50.98,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"hybrid DDTree chain seed capped to depth 4 at budget 32","timestamp":1777502734449,"segment":1,"confidence":15.776850112479453,"asi":{"hypothesis":"cap the pre-inserted greedy spine to depth 4 so budget-32 keeps robust shallow branch coverage while still preserving enough greedy-chain depth","rollback_reason":"chain-depth cap preserved the 16-step/2.06 acceptance pattern but step time stayed around 221 ms and did not beat the kept budget-32 best","next_action_hint":"hybrid chain-depth capping does not improve this workload; avoid further tree construction tweaks unless using a broader prompt corpus"}} +{"run":74,"commit":"abe969d","metric":9.173891,"metrics":{"tps":9.173891,"e2e_tps":1.284058,"spec_sec":24.921,"gen_tokens":32,"steps":16,"committed":33,"step_ms":218.01,"pack_ms":1.07,"draft_ms":146.71,"topk_ms":19.2,"exact_ms":51,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test budget 32 with target feature window 124","timestamp":1777502840662,"segment":1,"confidence":15.556627806464379,"asi":{"hypothesis":"target_feat_ctx=124 may be the smallest window that preserves the budget-32 16-step acceptance behavior while slightly reducing draft compute versus 128","rollback_reason":"124 preserved 16 steps and 2.06 acceptance, but primary TPS stayed below the kept budget-32 best and within the noisy 9.0-9.2 band","next_action_hint":"128 remains the best kept setting; 124 is viable but not a primary improvement, so move away from window tuning"}} +{"run":75,"commit":"abe969d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.522,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.54,"draft_ms":201.54,"topk_ms":15.37,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast rollback budget 14 with draft GPU layers 5","timestamp":1777502969042,"segment":1,"confidence":15.556627806464379,"asi":{"hypothesis":"budget 14 reduces persist to about 1.09 GiB; moving one draft layer off GPU may free enough VRAM for fast rollback while keeping draft faster than the known ngld3 path","rollback_reason":"persist allocation still OOMed at 1086.75 MiB, and draft time already rose to 201 ms with ngld5; no tokens generated","next_action_hint":"do not pursue small draft-offload reductions for persist fit; memory fragmentation or compute buffers still prevent budget14 persist allocation, and draft slowdown erodes expected gains"}} +{"run":76,"commit":"abe969d","metric":9.074822,"metrics":{"tps":9.074822,"e2e_tps":1.279232,"spec_sec":25.015,"gen_tokens":32,"steps":16,"committed":33,"step_ms":220.39,"pack_ms":1.09,"draft_ms":148.35,"topk_ms":19.94,"exact_ms":50.97,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"single memcpy fast path for one-token exact hidden-capture ingest","timestamp":1777503129647,"segment":1,"confidence":14.215306045687464,"asi":{"hypothesis":"exact validation ingests one hidden-capture token at a time; for n_tokens=1 the five layer slices are contiguous and can be copied into the target_feat ring with one memcpy instead of five","rollback_reason":"correctness passed but ingest stayed about 9.19 ms and primary TPS remained below the budget-32 best; hidden-capture retrieval/device transfer dominates, not the five small memcpy calls","next_action_hint":"do not optimize CPU memcpy in ingest further; a GPU-resident target_feat ring or capture path would be needed for meaningful ingest savings"}} +{"run":77,"commit":"abe969d","metric":8.99847,"metrics":{"tps":8.99847,"e2e_tps":1.275307,"spec_sec":25.092,"gen_tokens":32,"steps":16,"committed":34,"step_ms":222.26,"pack_ms":1.08,"draft_ms":150.91,"topk_ms":17.85,"exact_ms":52.38,"exact_decode_ms":0,"acceptance":2.125},"status":"discard","description":"cap draft logits outputs to tree rows for budget 32 row-8 proposal","timestamp":1777503251223,"segment":1,"confidence":13.227183816638313,"asi":{"hypothesis":"when capping proposal rows to 8, also avoid requesting draft logits for rows beyond that cap to reduce lm_head/logits work while keeping exact validation authoritative","rollback_reason":"correctness passed and topk remained lower, but draft time increased to 150.91 ms and total TPS fell below the budget-32 best; sparse logits flags do not reduce the dominant draft compute here","next_action_hint":"do not pursue output-row pruning in the draft path; focus on draft internal compute or exact target decode instead"}} +{"run":78,"commit":"abe969d","metric":9.495774,"metrics":{"tps":9.495774,"e2e_tps":1.288245,"spec_sec":24.84,"gen_tokens":32,"steps":16,"committed":33,"step_ms":210.62,"pack_ms":1.09,"draft_ms":138.52,"topk_ms":19.88,"exact_ms":51.09,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 26 below 28/32 threshold","timestamp":1777503357658,"segment":1,"confidence":12.929750236155215,"asi":{"hypothesis":"budget 26 may be the smallest budget that preserves the 16-step acceptance gain while reducing tree/proposal overhead versus budget 28/32","rollback_reason":"budget 26 preserved 16 steps and improved step time to 210.62 ms, but primary TPS remained below the kept budget-32 best of 9.82","next_action_hint":"budget 26 is promising and less noisy than nearby budgets; test budget 24 or repeat 26 before considering changing the configured default"}} +{"run":79,"commit":"abe969d","metric":9.011039,"metrics":{"tps":9.011039,"e2e_tps":1.284883,"spec_sec":24.905,"gen_tokens":32,"steps":16,"committed":33,"step_ms":221.95,"pack_ms":1.08,"draft_ms":149.2,"topk_ms":19.99,"exact_ms":51.65,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test DDTree budget 24 after promising budget 26","timestamp":1777503455626,"segment":1,"confidence":12.590074250645914,"asi":{"hypothesis":"budget 24 may be the smallest budget that retains the 16-step acceptance gain while lowering overhead more than budget 26","rollback_reason":"budget 24 preserved 16 steps but step time regressed to 221.95 ms and TPS fell well below budget 26 and the kept best","next_action_hint":"budget 26 looks more promising than 24; repeat 26 or test 25/27 if continuing budget tuning"}} +{"run":80,"commit":"abe969d","metric":8.965795,"metrics":{"tps":8.965795,"e2e_tps":1.280974,"spec_sec":24.981,"gen_tokens":32,"steps":16,"committed":33,"step_ms":223.07,"pack_ms":1.07,"draft_ms":149.46,"topk_ms":20.71,"exact_ms":51.79,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"repeat DDTree budget 26 to check noise","timestamp":1777503555898,"segment":1,"confidence":10.891133809356086,"asi":{"hypothesis":"repeat budget 26 because the previous 9.50 TPS result may indicate a smaller-budget sweet spot or may be noise","rollback_reason":"repeat preserved 16-step acceptance but regressed to 8.97 TPS, confirming the earlier 9.50 budget-26 run was mostly noise","next_action_hint":"do not keep budget 26; budget 32 remains the configured best despite noisy single-run highs"}} +{"run":81,"commit":"a4884f4","metric":10.147906,"metrics":{"tps":10.147906,"e2e_tps":1.29749,"spec_sec":24.663,"gen_tokens":32,"steps":14,"committed":33,"step_ms":225.24,"pack_ms":1.12,"draft_ms":145.35,"topk_ms":20.04,"exact_ms":58.69,"exact_decode_ms":0,"acceptance":2.357},"status":"keep","description":"add proposal-temperature override and test budget 32 at temp 1.0","timestamp":1777503715337,"segment":1,"confidence":11.576312390556494,"asi":{"hypothesis":"greedy sampling uses --temp 0, but DDTree proposal scoring should not necessarily use near-zero temperature; a separate proposal temperature can allocate budget to useful alternatives while exact validation preserves bit-equal output","result":"LLAMA_DDTREE_PROPOSAL_TEMP=1.0 with budget 32 reduced steps from 16 to 14 and improved decode TPS to 10.15 despite higher exact cost per step","next_action_hint":"confirm on repeat and sweep proposal_temp around 0.5/0.75/1.25; if robust, consider making proposal temperature distinct from sampler temperature by default"}} +{"run":82,"commit":"f282526","metric":10.443251,"metrics":{"tps":10.443251,"e2e_tps":1.302932,"spec_sec":24.56,"gen_tokens":32,"steps":14,"committed":33,"step_ms":218.87,"pack_ms":1.11,"draft_ms":140.22,"topk_ms":19.49,"exact_ms":58.02,"exact_decode_ms":0,"acceptance":2.357},"status":"keep","description":"tune DDTree proposal temperature to 0.75","timestamp":1777503819076,"segment":1,"confidence":12.243787036262265,"asi":{"hypothesis":"proposal_temp=0.75 may keep the useful branch diversity from temp 1.0 while weighting top draft candidates more strongly, reducing per-step overhead or improving acceptance stability","result":"proposal_temp=0.75 matched the 14-step/2.36 acceptance gain and improved decode TPS to 10.44 via lower step time than temp 1.0","next_action_hint":"sweep 0.5 and 0.9/1.25, then repeat the best to check noise; validate on a second prompt before hardcoding any default"}} +{"run":83,"commit":"f282526","metric":9.354263,"metrics":{"tps":9.354263,"e2e_tps":1.28123,"spec_sec":24.976,"gen_tokens":32,"steps":15,"committed":34,"step_ms":228.06,"pack_ms":1.09,"draft_ms":149.88,"topk_ms":20.81,"exact_ms":56.24,"exact_decode_ms":0,"acceptance":2.267},"status":"discard","description":"sweep DDTree proposal temperature 0.5","timestamp":1777503920054,"segment":1,"confidence":12.281436942986181,"asi":{"hypothesis":"proposal_temp=0.5 may focus tree budget on higher-confidence draft branches and reduce wasted alternatives versus 0.75/1.0","rollback_reason":"0.5 reduced acceptance relative to 0.75 and needed 15 steps; step time also worsened, so primary TPS dropped below current best","next_action_hint":"proposal_temp around 0.75 remains best; test 0.9 or 0.625 next rather than going lower"}} +{"run":84,"commit":"f282526","metric":10.070113,"metrics":{"tps":10.070113,"e2e_tps":1.299229,"spec_sec":24.63,"gen_tokens":32,"steps":14,"committed":33,"step_ms":226.98,"pack_ms":1.12,"draft_ms":146.35,"topk_ms":20.69,"exact_ms":58.8,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"sweep DDTree proposal temperature 0.9","timestamp":1777504018992,"segment":1,"confidence":12.265451874159854,"asi":{"hypothesis":"proposal_temp=0.9 may keep the 14-step branch-diversity gain of 0.75/1.0 while being less noisy than 0.75","rollback_reason":"0.9 preserved 14 steps and 2.36 acceptance but step time was worse than 0.75, so primary TPS fell below current best","next_action_hint":"0.75 remains the best tested proposal temperature; try 0.625/0.8 or repeat 0.75 for stability"}} +{"run":85,"commit":"f282526","metric":9.341975,"metrics":{"tps":9.341975,"e2e_tps":1.288712,"spec_sec":24.831,"gen_tokens":32,"steps":15,"committed":34,"step_ms":228.36,"pack_ms":1.11,"draft_ms":149.89,"topk_ms":21.18,"exact_ms":56.13,"exact_decode_ms":0,"acceptance":2.267},"status":"discard","description":"sweep DDTree proposal temperature 0.625","timestamp":1777504113549,"segment":1,"confidence":12.249508362405432,"asi":{"hypothesis":"proposal_temp=0.625 may sit between 0.5 and 0.75, preserving more branch diversity than 0.5 while focusing proposals more than 0.75","rollback_reason":"0.625 behaved like 0.5: 15 steps, 2.27 acceptance, and slower step time; primary TPS well below 0.75","next_action_hint":"lower side of temp sweep is done; 0.75 remains best, test 0.8/0.7 or repeat 0.75"}} +{"run":86,"commit":"f282526","metric":9.949568,"metrics":{"tps":9.949568,"e2e_tps":1.292303,"spec_sec":24.762,"gen_tokens":32,"steps":14,"committed":33,"step_ms":229.73,"pack_ms":1.1,"draft_ms":148.29,"topk_ms":21.43,"exact_ms":58.88,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"sweep DDTree proposal temperature 0.8","timestamp":1777504213478,"segment":1,"confidence":12.153903309794075,"asi":{"hypothesis":"proposal_temp=0.8 may be close to the observed 0.75 optimum while possibly more robust than 0.75","rollback_reason":"0.8 kept the 14-step/2.36 acceptance pattern but step time and topk were worse, so primary TPS stayed below 0.75","next_action_hint":"0.75 remains the best tested proposal temperature; test 0.7 or repeat 0.75 for stability"}} +{"run":87,"commit":"c499e6d","metric":10.507099,"metrics":{"tps":10.507099,"e2e_tps":1.308686,"spec_sec":24.452,"gen_tokens":32,"steps":14,"committed":33,"step_ms":217.54,"pack_ms":1.11,"draft_ms":138.33,"topk_ms":19.95,"exact_ms":58.11,"exact_decode_ms":0,"acceptance":2.357},"status":"keep","description":"sweep DDTree proposal temperature 0.7","timestamp":1777504313200,"segment":1,"confidence":12.22183088063815,"asi":{"hypothesis":"proposal_temp=0.7 may be slightly better than 0.75 by keeping the same 14-step branch diversity while weighting top candidates a little more strongly","result":"0.7 preserved the 14-step/2.36 acceptance behavior and improved decode TPS to 10.51 with lower draft/step time than 0.75","next_action_hint":"repeat 0.7 and test 0.68/0.72; validate on additional prompt before making proposal_temp=0.7 a hard default"}} +{"run":88,"commit":"c499e6d","metric":10.013205,"metrics":{"tps":10.013205,"e2e_tps":1.300654,"spec_sec":24.603,"gen_tokens":32,"steps":14,"committed":32,"step_ms":228.27,"pack_ms":1.11,"draft_ms":149.75,"topk_ms":20.67,"exact_ms":56.7,"exact_decode_ms":0,"acceptance":2.286},"status":"discard","description":"sweep DDTree proposal temperature 0.68","timestamp":1777504416257,"segment":1,"confidence":12.743318525446139,"asi":{"hypothesis":"proposal_temp=0.68 may refine the 0.7 optimum by slightly increasing confidence weighting without falling into the 0.625/0.5 lower-acceptance regime","rollback_reason":"0.68 kept 14 steps but committed/acceptance fell and draft/step time worsened, so primary TPS dropped below 0.7","next_action_hint":"0.7 remains better than the lower side; test 0.72 or repeat 0.7 for noise"}} +{"run":89,"commit":"c499e6d","metric":10.111096,"metrics":{"tps":10.111096,"e2e_tps":1.280974,"spec_sec":24.981,"gen_tokens":32,"steps":14,"committed":33,"step_ms":226.06,"pack_ms":1.11,"draft_ms":144.8,"topk_ms":21.34,"exact_ms":58.76,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"sweep DDTree proposal temperature 0.72","timestamp":1777504517855,"segment":1,"confidence":12.766361853364844,"asi":{"hypothesis":"proposal_temp=0.72 may sit close to the 0.7 optimum while preserving 0.75-like acceptance stability","rollback_reason":"0.72 preserved 14 steps and 2.36 acceptance but topk/step time were worse than 0.7; primary TPS below best","next_action_hint":"proposal_temp 0.7 remains best; repeat 0.7 for confirmation or test nearby 0.69/0.71 only if needed"}} +{"run":90,"commit":"c499e6d","metric":9.937888,"metrics":{"tps":9.937888,"e2e_tps":1.292146,"spec_sec":24.765,"gen_tokens":32,"steps":14,"committed":33,"step_ms":230,"pack_ms":1.11,"draft_ms":149.15,"topk_ms":21.08,"exact_ms":58.63,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"repeat DDTree proposal temperature 0.7","timestamp":1777504613299,"segment":1,"confidence":12.802747380091832,"asi":{"hypothesis":"repeat proposal_temp=0.7 to check whether the 10.51 TPS run was stable or mostly noise","rollback_reason":"repeat preserved 14-step/2.36 acceptance but step time regressed to 230 ms, so primary TPS was below the kept best; structural gain is stable but single-run TPS is noisy","next_action_hint":"treat proposal_temp=0.7 as useful for fewer steps but compare future changes against a 9.9-10.5 TPS noise range"}} +{"run":91,"commit":"c499e6d","metric":3.149191,"metrics":{"tps":3.149191,"e2e_tps":1.013588,"spec_sec":31.571,"gen_tokens":32,"steps":14,"committed":33,"step_ms":725.81,"pack_ms":1.16,"draft_ms":149.93,"topk_ms":20.87,"exact_ms":59.92,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"diagnostic batched-vs-exact trace with proposal temp 0.7","timestamp":1777505377073,"segment":1,"confidence":12.319068657292135,"asi":{"hypothesis":"enable batched tree verify trace to measure batched posterior vs exact validation agreement and margin for possible high-confidence exact-skip rule","rollback_reason":"diagnostic trace adds target-tree/snapshot overhead, so primary TPS is expectedly worse and not a runtime optimization","result":"for visible traced tail, batched_exact_same=14 and diff=0 in stats; parsed 3 trace lines all had batched_commit_n equal exact commit_n with min margin 2.12, avg margin 7.62","next_action_hint":"improve trace capture to emit compact aggregate metrics for all steps instead of relying on tailed logs; this run suggests current prompt/temp0.7 has no batched/exact divergence, but broader prompts are needed before trusting batched"}} +{"run":92,"commit":"c499e6d","metric":3.209461,"metrics":{"tps":3.209461,"e2e_tps":1.019303,"spec_sec":31.394,"gen_tokens":32,"steps":14,"committed":33,"step_ms":712.18,"pack_ms":1.16,"draft_ms":152.9,"topk_ms":20.95,"exact_ms":0,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"fast-batched correctness probe at proposal temp 0.7","timestamp":1777505519305,"segment":1,"confidence":11.723331153709331,"asi":{"hypothesis":"if proposal_temp=0.7 eliminates batched/exact divergence on this prompt, fast-batched can skip exact validation and remain bit-equal, clarifying whether exact skip is correctness-safe before optimizing rollback","rollback_reason":"correctness passed, but snapshot+replay fast-batched path is far slower: step=712 ms due snapshot 177.6 ms, target_tree 118.6 ms, replay 53.2 ms; primary TPS much worse","result":"batched path produced bit-equal output with budget32/temp0.7 on this prompt; stats show committed=33, batched_committed=33, max_commit=4, no divergence in final token sequence","next_action_hint":"exact-skip is only useful if paired with no-snapshot fast rollback or another cheap state-commit path; next root-cause work should quantify/compress persist buffers or add aggregate batched/exact margin diagnostics across prompts"}} +{"run":93,"commit":"c499e6d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.672,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.69,"draft_ms":154.36,"topk_ms":14.63,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"diagnose fast rollback persist memory split at budget 14","timestamp":1777505657370,"segment":1,"confidence":11.723331153709331,"asi":{"hypothesis":"quantify DFlash fast-rollback persist allocation into SSM vs conv components to see whether conv F32 is the main compressible memory target","rollback_reason":"diagnostic logging changed code and fast rollback still OOMed before generating tokens, so not a throughput improvement","result":"budget14 planned persist split: ssm=1008.00 MiB, conv=78.75 MiB, total=1086.75 MiB; CUDA allocation failed for the full 1086.75 MiB","next_action_hint":"conv F32 is only about 7.2% of budget14 persist; compressing conv alone cannot make fast rollback fit. The dominant target is SSM persist, already F16 on CUDA, or freeing/defragmenting ~1.1 GiB elsewhere"}} +{"run":94,"commit":"c499e6d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":21.723,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.53,"draft_ms":137.37,"topk_ms":13.7,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast rollback fit probe with budget 8 persist","timestamp":1777505801493,"segment":1,"confidence":11.723331153709331,"asi":{"hypothesis":"using a much smaller budget 8 should reduce DFlash fast-rollback persist to about 621 MiB; if it fits, fast-batched plus rollback can reveal whether cheap state commit can offset lower acceptance","rollback_reason":"budget8 persist still OOMed before any token was generated, so no throughput result; snapshot fallback was disabled intentionally to test true fast rollback fit","result":"CUDA allocation failed for 621.00 MiB persist at n_tokens=8; even the small-budget fast rollback path cannot fit in the current 64k/full-target/full-draft memory envelope","next_action_hint":"fast rollback needs either freeing hundreds of MiB elsewhere before persist allocation or changing allocation strategy/lifetime; smaller budget alone is not enough, so investigate target/draft compute buffers or lower n_ctx/offload as controlled diagnostics"}} +{"run":95,"commit":"c499e6d","metric":8.005844,"metrics":{"tps":8.005844,"e2e_tps":1.269589,"spec_sec":25.205,"gen_tokens":32,"steps":18,"committed":33,"step_ms":222.06,"pack_ms":1.08,"draft_ms":132.93,"topk_ms":15.31,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"fast rollback fit probe at 32k context with budget 8","timestamp":1777505918594,"segment":1,"confidence":11.195080202773184,"asi":{"hypothesis":"lowering n_ctx from 64k to 32k may free enough GPU memory for even budget8 fast rollback, revealing whether fast-batched plus cheap state commit can be viable when persist fits","rollback_reason":"32k/budget8 fast rollback fit and remained bit-equal, but acceptance fell to 1.83 over 18 steps and TPS stayed below the current exact-validation best","result":"persist allocated successfully: 621.00 MiB, fast_rollback=18, snapshot_replays=0, exact=0, target_tree=70.61 ms, rollback=0.65 ms, step=222.06 ms, correctness PASS","next_action_hint":"fast rollback itself is cheap when it fits; the remaining costs are target_tree and reduced acceptance at budget8/32k. Test 32k with larger budget if memory allows, or focus on freeing 64k memory for budget32 persist rather than optimizing rollback kernels."}} +{"run":96,"commit":"c499e6d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"fast rollback fit probe at 32k context with budget 14","timestamp":1777506020282,"segment":1,"confidence":11.195080202773184,"asi":{"hypothesis":"32k context may free enough GPU memory for budget14 fast rollback persist, improving acceptance versus the budget8 fast-batched rollback run while keeping cheap state commit","rollback_reason":"budget14 persist allocated at 32k, but subsequent CUDA graph compute hit out-of-memory before benchmark metrics were emitted, so no throughput result","result":"persist allocation succeeded for 64 layers, 14 tokens, 1086.75 MiB, then CUDA OOM occurred in ggml_backend_cuda_graph_compute during target tree decode","next_action_hint":"32k frees enough memory for budget14 persist allocation but not enough working-memory headroom for compute; test an intermediate smaller budget such as 10 or 12, or reduce graph compute buffer pressure before revisiting budget14"}} +{"run":97,"commit":"c499e6d","metric":6.498729,"metrics":{"tps":6.498729,"e2e_tps":1.223055,"spec_sec":26.164,"gen_tokens":32,"steps":19,"committed":33,"step_ms":259.16,"pack_ms":1.06,"draft_ms":156.53,"topk_ms":16.75,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.737},"status":"discard","description":"fast rollback fit probe at 32k context with budget 12","timestamp":1777506119588,"segment":1,"confidence":11.176576899258043,"asi":{"hypothesis":"32k context with budget12 may be the largest fast-rollback configuration that leaves enough compute headroom, improving acceptance versus budget8 without the budget14 CUDA OOM","rollback_reason":"budget12 fast rollback fit and stayed bit-equal, but acceptance fell to 1.74 over 19 steps and step time rose to 259 ms; primary TPS was far below the budget32 exact-validation best and below the 32k budget8 rollback probe","result":"persist allocated successfully at 931.50 MiB, fast_rollback=19, snapshot_replays=0, exact=0, target_tree=81.96 ms, rollback=0.66 ms, correctness PASS","next_action_hint":"budget12 worsens both acceptance and step time versus budget8 on this prompt; do not pursue 32k rollback budget sweep as an optimization. The useful finding is that rollback is cheap when persist fits; focus on freeing 64k memory or reducing target_tree cost."}} +{"run":98,"commit":"c499e6d","metric":7.723089,"metrics":{"tps":7.723089,"e2e_tps":1.206091,"spec_sec":26.532,"gen_tokens":32,"steps":18,"committed":33,"step_ms":230.19,"pack_ms":1.11,"draft_ms":137.84,"topk_ms":15.24,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.833},"status":"discard","description":"64k fast rollback fit probe with smaller batch buffers and budget 8","timestamp":1777506225210,"segment":1,"confidence":11.158134659584107,"asi":{"hypothesis":"reducing n_batch/n_ubatch from 512 to 256 may free enough 64k GPU working memory for budget8 fast rollback persist while preserving the 64k prompt/proposal behavior better than lowering n_ctx to 32k","rollback_reason":"fast rollback fit and correctness passed, but budget8 acceptance stayed low at 1.83 over 18 steps and TPS remained below the budget32 exact-validation best","result":"64k with n_batch=256,n_ubatch=256 allocated 621.00 MiB persist, fast_rollback=18, snapshot_replays=0, exact=0, target_tree=73.86 ms, rollback=0.62 ms, correctness PASS","next_action_hint":"smaller batch buffers solve the 64k budget8 persist fit, but budget8 acceptance is too low; test whether n_batch=256 allows budget14 or higher persist, otherwise freeing memory alone at tiny budget is not enough"}} +{"run":99,"commit":"c499e6d","metric":7.011148,"metrics":{"tps":7.011148,"e2e_tps":1.201021,"spec_sec":26.644,"gen_tokens":32,"steps":17,"committed":33,"step_ms":268.48,"pack_ms":1.09,"draft_ms":160.71,"topk_ms":16.82,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"64k fast rollback fit probe with smaller batch buffers and budget 14","timestamp":1777506326499,"segment":1,"confidence":10.999616549456363,"asi":{"hypothesis":"with n_batch/n_ubatch reduced to 256, 64k may have enough memory for budget14 fast rollback persist, improving acceptance over budget8 while keeping cheap rollback state commits","rollback_reason":"budget14 fast rollback fit and stayed bit-equal, but step time rose to 268 ms and acceptance only improved to 1.94, so primary TPS remained below both budget8 rollback and the budget32 exact-validation best","result":"64k n_batch=256,n_ubatch=256 allocated 1086.75 MiB persist, fast_rollback=17, snapshot_replays=0, exact=0, target_tree=86.66 ms, rollback=0.64 ms, correctness PASS","next_action_hint":"n_batch reduction frees enough memory for budget14, but target_tree+draft cost dominate and acceptance remains below the budget32 exact path; test whether budget22/32 fit with n_batch=256 only as a memory map, otherwise stop fast-rollback budget sweep and investigate target_tree cost or GPU target_feat ring."}} +{"run":100,"commit":"c499e6d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":22.457,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":3.57,"draft_ms":157.23,"topk_ms":20.94,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"64k fast rollback fit probe with smaller batch buffers and budget 22","timestamp":1777506423953,"segment":1,"confidence":10.999616549456363,"asi":{"hypothesis":"with n_batch/n_ubatch reduced to 256, 64k may have enough memory for budget22 fast rollback persist, restoring more acceptance than budget8/14 while avoiding exact validation","rollback_reason":"budget22 persist still OOMed before generating tokens, so no throughput result; snapshot fallback was disabled intentionally to test true fast rollback fit","result":"CUDA allocation failed for 1707.75 MiB persist at n_tokens=22 even with n_batch=256,n_ubatch=256; budget14 is the largest tested 64k rollback configuration that fits, but it is too slow","next_action_hint":"do not pursue higher budget fast rollback under current 64k/full-draft memory envelope. Fast rollback needs either a smaller persist representation or major memory freeing; shift to target_tree cost reduction or GPU-resident target_feat ring diagnostics."}} +{"run":101,"commit":"c499e6d","metric":0,"metrics":{"tps":0,"e2e_tps":0,"spec_sec":26.823,"gen_tokens":0,"steps":1,"committed":0,"step_ms":0,"pack_ms":1.57,"draft_ms":153.13,"topk_ms":21.01,"exact_ms":0,"exact_decode_ms":0,"acceptance":0},"status":"crash","description":"64k fast rollback fit probe with n_batch 128 and budget 22","timestamp":1777506525928,"segment":1,"confidence":10.999616549456363,"asi":{"hypothesis":"reducing n_batch/n_ubatch further from 256 to 128 may free enough 64k GPU memory for budget22 fast rollback persist, restoring acceptance while avoiding exact validation","rollback_reason":"budget22 persist still OOMed before generating tokens even with n_batch=128; snapshot fallback was disabled intentionally to test true fast rollback fit","result":"CUDA allocation failed for 1707.75 MiB persist at n_tokens=22; reducing n_batch below 256 did not free enough additional memory, while prompt_ingests doubled to 160","next_action_hint":"stop trying to fit budget22/32 fast rollback via n_batch reduction. Need different persist representation, moving buffers off GPU, or target_tree/GPU target_feat root-cause work."}} +{"run":102,"commit":"c499e6d","metric":3.404755,"metrics":{"tps":3.404755,"e2e_tps":1.040312,"spec_sec":30.76,"gen_tokens":32,"steps":14,"committed":33,"step_ms":671.33,"pack_ms":1.15,"draft_ms":144.48,"topk_ms":20.62,"exact_ms":59.5,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"aggregate batched-vs-exact margin diagnostics at proposal temp 0.7","timestamp":1777506711000,"segment":1,"confidence":10.845539326886065,"asi":{"hypothesis":"add compact aggregate margin metrics for all diagnostic batched tree steps to evaluate whether a high-confidence exact-skip rule could be safe without relying on truncated per-node trace logs","rollback_reason":"diagnostic target-tree verification and logging are expectedly much slower than the runtime exact-validation path, so primary TPS is below best and the code is not a production optimization","result":"on the current prompt with budget32/temp0.7, batched_exact_same=14 and diff=0; aggregate margin steps=14, batched_min_avg=9.087, batched_min_min=0.895, ge1 same=13/13, ge2 same=12/12, ge5 same=9/9","next_action_hint":"current prompt supports a confidence-gated exact skip with thresholds >=1 or >=2, but this is not enough to ship; validate aggregate margins on additional prompts before trusting batched posterior, and pair any exact skip with cheap state commit because snapshot+replay is slower."}} +{"run":103,"commit":"c499e6d","metric":6.832497,"metrics":{"tps":6.832497,"e2e_tps":1.205636,"spec_sec":26.542,"gen_tokens":32,"steps":17,"committed":33,"step_ms":275.5,"pack_ms":1.1,"draft_ms":157.54,"topk_ms":20.23,"exact_ms":0,"exact_decode_ms":0,"acceptance":1.941},"status":"discard","description":"64k fast rollback fit probe with smaller batch buffers and budget 16","timestamp":1777506853395,"segment":1,"confidence":10.715120218615038,"asi":{"hypothesis":"budget16 may be a 64k fast-rollback middle point that still fits with n_batch/n_ubatch=256 while improving acceptance over budget8/14 enough to offset target_tree cost","rollback_reason":"budget16 fast rollback fit and stayed bit-equal, but acceptance remained 1.94 like budget14 while step time worsened to 275.5 ms, so primary TPS stayed far below the budget32 exact-validation best","result":"64k n_batch=256,n_ubatch=256 allocated 1242.00 MiB persist, fast_rollback=17, snapshot_replays=0, exact=0, target_tree=93.03 ms, rollback=0.68 ms, correctness PASS","next_action_hint":"fast rollback parameter sweep is exhausted: budget8/14/16 fit but are slow, budget22 OOMs. Move to target_tree cost reduction, broader batched/exact confidence validation, or GPU-resident target_feat ring instead of more budget tuning."}} +{"run":104,"commit":"c499e6d","metric":3.414521,"metrics":{"tps":3.414521,"e2e_tps":1.03795,"spec_sec":30.83,"gen_tokens":32,"steps":14,"committed":33,"step_ms":669.41,"pack_ms":1.16,"draft_ms":146.87,"topk_ms":20.93,"exact_ms":60.27,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"target-only chain decode timing diagnostic attempt","timestamp":1777507043307,"segment":1,"confidence":10.587800455864077,"asi":{"hypothesis":"instrument the chain reference path to quantify target-only one-token decode time and compare it with DDTree exact validation, clarifying whether the target-only reference mismatch comes from target decode overhead","rollback_reason":"diagnostic output was not surfaced because autoresearch.sh only tails the last 120 lines, and a preserved TRACE env pass-through accidentally enabled expensive target-tree tracing for every run; primary TPS is therefore not comparable","result":"the run still passed correctness but showed trace overhead: target_tree=118.29 ms and snapshot=150.34 ms in the exact path; chain timing detail needs a rerun after fixing TRACE pass-through/output capture","next_action_hint":"fix autoresearch.sh so empty LLAMA_DDTREE_TRACE does not enable tracing, then rerun the chain timing diagnostic with enough output to capture chain timing detail"}} +{"run":105,"commit":"9fbce93","metric":11.924636,"metrics":{"tps":11.924636,"e2e_tps":1.328573,"spec_sec":24.086,"gen_tokens":32,"steps":14,"committed":33,"step_ms":191.68,"pack_ms":1.1,"draft_ms":114.78,"topk_ms":17.91,"exact_ms":57.86,"exact_decode_ms":0,"acceptance":2.357},"status":"keep","description":"fix empty trace pass-through and add chain timing diagnostic hook","timestamp":1777507150730,"segment":1,"confidence":13.950437243906348,"asi":{"hypothesis":"the preserved autoresearch trace env pass-through was accidentally enabling target-tree diagnostics even when LLAMA_DDTREE_TRACE was empty; fixing it should restore normal exact-validation timing and allow a target-only chain timing diagnostic hook","result":"normal exact-validation path restored: target_tree=0, snapshot=0, step=191.68 ms, draft=114.78 ms, exact=57.86 ms, acceptance=2.36, tps=11.92; this is a measurement/harness fix rather than a product-code speedup","next_action_hint":"rerun once to confirm the 11.9 TPS result is not noise, and increase autoresearch output capture further if chain timing detail is still needed"}} +{"run":106,"commit":"9fbce93","metric":10.063906,"metrics":{"tps":10.063906,"e2e_tps":1.307083,"spec_sec":24.482,"gen_tokens":32,"steps":14,"committed":33,"step_ms":227.12,"pack_ms":1.11,"draft_ms":146.94,"topk_ms":20.58,"exact_ms":58.45,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"repeat normal exact-validation path after trace pass-through fix","timestamp":1777507233994,"segment":1,"confidence":13.654565431988555,"asi":{"hypothesis":"repeat the restored no-trace exact-validation path to determine whether the 11.92 TPS run was a real speedup or timing noise","rollback_reason":"repeat preserved the 14-step/2.36 acceptance behavior and no target_tree overhead, but step time returned to 227 ms and primary TPS fell below the kept 11.92 run","result":"target_tree=0 and snapshot=0 confirm the trace pass-through fix works; the 11.92 TPS run was mostly draft/step-time noise, while normal performance remains around 10.0 TPS","next_action_hint":"continue comparing future changes against a noisy 10.0-11.9 TPS range; if chain timing detail is still needed, make autoresearch.sh print the chain timing lines explicitly instead of relying on tail length."}} +{"run":107,"commit":"9fbce93","metric":9.951734,"metrics":{"tps":9.951734,"e2e_tps":1.299756,"spec_sec":24.62,"gen_tokens":32,"steps":14,"committed":33,"step_ms":229.68,"pack_ms":1.12,"draft_ms":149.57,"topk_ms":20.36,"exact_ms":58.6,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"surface target-only chain decode timing diagnostic","timestamp":1777507363548,"segment":1,"confidence":13.340332316400888,"asi":{"hypothesis":"print the chain reference timing lines from the full benchmark log so we can compare target-only one-token decode against DDTree exact validation and identify whether target decode itself is anomalously slow","rollback_reason":"this was a diagnostic harness output change and primary TPS stayed below the kept best; no production code speedup","result":"target-only chain after prompt has decode_avg=24.79 ms over 32 tokens, while DDTree exact validation is 58.60 ms per speculative step plus draft/topk; the target-only server reference gap is therefore not explained by base one-token target decode alone","next_action_hint":"investigate why DDTree exact validation costs ~58 ms/step versus target-only 1-token decode ~25 ms, likely hidden capture, seq_rm/cache state, or validation loop overhead; profile exact validation internals before more proposal tuning"}} +{"run":108,"commit":"9fbce93","metric":10.245246,"metrics":{"tps":10.245246,"e2e_tps":1.309436,"spec_sec":24.438,"gen_tokens":32,"steps":14,"committed":33,"step_ms":223.1,"pack_ms":1.11,"draft_ms":143.19,"topk_ms":20.11,"exact_ms":58.66,"exact_decode_ms":58.19,"acceptance":2.357},"status":"discard","description":"profile exact validation decode/sample timing","timestamp":1777507489027,"segment":1,"confidence":12.942229525616327,"asi":{"hypothesis":"instrument exact validation internals to determine whether the ~58 ms exact cost comes from llama_decode, logits sampling, advance callbacks, or driver overhead","rollback_reason":"diagnostic timing output only; primary TPS stayed below the kept best and there is no product-code optimization in this change","result":"exact validation is almost entirely target llama_decode: exact=58.66 ms/step, exact_decode=58.19 ms/step, exact_sample=0.41 ms/step, exact_nodes=33 over 14 steps; target-only chain decode remains 24.69 ms/token","next_action_hint":"focus on why exact validation llama_decode is ~2.35x slower per token than target-only chain decode: hidden capture enabled during spec, seq_rm/KV state mutations, or context state after draft/tree operations. Run a controlled chain-with-hidden-capture timing diagnostic next."}} +{"run":109,"commit":"9fbce93","metric":10.373578,"metrics":{"tps":10.373578,"e2e_tps":1.315735,"spec_sec":24.321,"gen_tokens":32,"steps":14,"committed":33,"step_ms":220.34,"pack_ms":1.12,"draft_ms":139.62,"topk_ms":20.94,"exact_ms":58.62,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"chain decode timing with hidden capture enabled","timestamp":1777507637568,"segment":1,"confidence":12.882552326173203,"asi":{"hypothesis":"if hidden capture is the reason DDTree exact validation llama_decode is ~2.35x slower than target-only chain decode, enabling capture_hidden on the chain reference should raise chain decode_avg toward ~58 ms/token","rollback_reason":"diagnostic-only harness/env pass-through change; primary TPS did not improve over the kept best","result":"with LLAMA_DDTREE_CHAIN_CAPTURE=1, chain decode_avg remained 24.74 ms/token while DDTree exact stayed 58.62 ms/step; hidden capture alone does not explain the exact-validation decode slowdown","next_action_hint":"next isolate seq_rm/KV state mutation or repeated position rollback effects: run target-only chain with seq_rm before each decode, or instrument exact validation to measure llama_memory_seq_rm and per-node decode depth costs."}} +{"run":110,"commit":"9fbce93","metric":10.039594,"metrics":{"tps":10.039594,"e2e_tps":1.291312,"spec_sec":24.781,"gen_tokens":32,"steps":14,"committed":33,"step_ms":227.67,"pack_ms":1.11,"draft_ms":147.57,"topk_ms":20.34,"exact_ms":58.61,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"chain decode timing with hidden capture and seq_rm controls","timestamp":1777507811343,"segment":1,"confidence":12.847363213337257,"asi":{"hypothesis":"if DDTree exact validation is slower because of hidden capture or the seq_rm call before validation, then target-only chain decode with capture_hidden plus seq_rm before every decode should approach the ~58 ms exact-validation cost","rollback_reason":"diagnostic-only harness changes; primary TPS stayed below the kept best and no production optimization was introduced","result":"chain with capture+seq_rm still decoded at 24.81 ms/token; seq_rm itself averaged only 0.026 ms. DDTree exact remained 58.61 ms/step, so neither hidden capture nor a no-op seq_rm explains the slowdown","next_action_hint":"isolate context-state effects specific to DDTree exact validation: repeated decode of root/current positions after previous speculative commits, recurrent/SSM state mutation, or graph shape/reuse differences. Next useful diagnostic is per-depth exact decode timing or chain replay pattern that decodes root+accepted path chunks like validation."}} +{"run":111,"commit":"9fbce93","metric":10.079439,"metrics":{"tps":10.079439,"e2e_tps":1.297911,"spec_sec":24.655,"gen_tokens":32,"steps":14,"committed":33,"step_ms":226.77,"pack_ms":1.11,"draft_ms":146.48,"topk_ms":20.54,"exact_ms":58.6,"exact_decode_ms":58.14,"acceptance":2.357},"status":"discard","description":"normalize exact validation decode cost per node","timestamp":1777507941061,"segment":1,"confidence":12.752562835954159,"asi":{"hypothesis":"the apparent ~58 ms exact-validation cost may be per speculative step, not per target token; logging exact_nodes and exact_decode per node should clarify whether exact llama_decode is actually slower than target-only chain decode","rollback_reason":"diagnostic-only harness logging; primary TPS stayed below the kept best and no runtime optimization was introduced","result":"exact_nodes=33 over 14 steps; exact_decode=58.14 ms/step but exact_node=24.66 ms/token, matching target-only chain decode_avg=24.76 ms/token. Exact validation is not slower per token; it costs about accepted_nodes_per_step target decodes.","next_action_hint":"stop treating exact validation as a per-token slowdown. The remaining throughput gap is structural: each DDTree step pays draft/topk plus ~2.36 target decodes. Next focus on either trusting batched posterior with cheap state commit, reducing target decodes per step, or lowering draft/topk cost."}} +{"run":112,"commit":"9fbce93","metric":10.473398,"metrics":{"tps":10.473398,"e2e_tps":1.308312,"spec_sec":24.459,"gen_tokens":32,"steps":14,"committed":33,"step_ms":218.24,"pack_ms":0.95,"draft_ms":139.62,"topk_ms":19.74,"exact_ms":57.9,"exact_decode_ms":0,"acceptance":2.357},"status":"discard","description":"test smaller target feature window 112 with proposal temp 0.7","timestamp":1777508043255,"segment":1,"confidence":12.651180373150947,"asi":{"hypothesis":"with proposal_temp=0.7 preserving the 14-step acceptance pattern, reducing target_feat_ctx from 128 to 112 may lower draft/pack cost without losing proposal quality","rollback_reason":"target_feat_ctx=112 preserved 14 steps and 2.36 acceptance and reduced pack/draft versus typical 128 runs, but primary TPS did not beat the kept best and remains within the noisy 10.0-11.9 range","result":"112-window produced tps=10.47, step=218.24 ms, draft=139.62 ms, pack=0.95 ms, acceptance=2.36; proposal_temp=0.7 appears to recover the acceptance loss previously seen at 112 without a clear primary improvement","next_action_hint":"test 96 or 104 only if looking for more draft-cost reduction, but compare against noise; avoid declaring smaller windows best until repeated and validated on broader prompts"}} +{"run":113,"commit":"9fbce93","metric":9.938876,"metrics":{"tps":9.938876,"e2e_tps":1.298174,"spec_sec":24.65,"gen_tokens":32,"steps":16,"committed":33,"step_ms":201.23,"pack_ms":0.88,"draft_ms":131.18,"topk_ms":18.74,"exact_ms":50.38,"exact_decode_ms":0,"acceptance":2.062},"status":"discard","description":"test smaller target feature window 104 with proposal temp 0.7","timestamp":1777508138445,"segment":1,"confidence":13.074597480445423,"asi":{"hypothesis":"with proposal_temp=0.7, target_feat_ctx=104 may further reduce draft/pack cost while still preserving enough proposal quality to beat the 128-window exact-validation path","rollback_reason":"104 reduced draft/pack and per-step time, but lost the 14-step acceptance pattern: steps rose to 16 and acceptance fell to 2.06, so primary TPS stayed below the current best","result":"target_feat_ctx=104 produced step=201.23 ms, draft=131.18 ms, pack=0.88 ms, but required 16 steps; smaller windows below 112 start trading away the structural acceptance gain","next_action_hint":"avoid shrinking target_feat_ctx below 112 for this prompt/temp unless validating a latency-focused variant; next try broader-prompt batched/exact confidence validation or target_tree/state-commit work rather than more small window tuning"}} +{"type":"config","name":"Optimize DFlash DDTree decode TPS on Castle - Phase 2: Batched posterior + fast commit","metricName":"tps","metricUnit":"","bestDirection":"higher"} +{"run":114,"commit":"166b674","metric":9.411626,"metrics":{},"status":"keep","description":"Phase 2 baseline: proposal_temp=0.7, budget=32, target_feat_ctx=128, exact-validation path on real_rendered_prompt","timestamp":1777537489249,"segment":2,"confidence":null,"asi":{"hypothesis":"establish new segment baseline with current best config before pursuing batched posterior validation and fast commit optimizations","benchmark":"./autoresearch.sh on Castle; gen=32; prompt=/tmp/real_rendered_prompt.txt; LLAMA_DDTREE_PROPOSAL_TEMP=0.7; LLAMA_DDTREE_TARGET_FEAT_CTX=128","note":"steps=15 vs previous segment best 14-step runs; within noise band 9.4-10.5 TPS"}} +{"run":115,"commit":"9fbce93","metric":7.28014,"metrics":{},"status":"discard","description":"fast-batched + fast-rollback with budget 10 at 64k context","timestamp":1777541038238,"segment":2,"confidence":null,"asi":{"hypothesis":"test whether fast-batched path with small budget 10 can achieve acceptable TPS without exact validation","note":"fast rollback fits at budget10 with n_batch=256; acceptance=1.94, step=258ms, TPS=7.28 vs baseline 9.41","next_action_hint":"budget10 fast-batched is worse than exact path; try smaller budgets or combine with other optimizations"}} +{"run":116,"commit":"9fbce93","metric":7.136072,"metrics":{},"status":"discard","description":"fast-batched + fast-rollback with budget 14 at 64k context","timestamp":1777541046222,"segment":2,"confidence":null,"asi":{"hypothesis":"test whether budget 14 fast-batched improves over budget 10 with better acceptance","note":"budget 14 step time increased to 264ms but acceptance stayed 1.94; TPS=7.14 worse than budget 10","next_action_hint":"larger budgets increase target_tree cost without acceptance gain in fast-batched mode; smaller budgets may be better"}} +{"run":117,"commit":"9fbce93","metric":7.851335,"metrics":{},"status":"discard","description":"fast-batched + fast-rollback with budget 8 at 64k context","timestamp":1777541054565,"segment":2,"confidence":null,"asi":{"hypothesis":"test whether budget 8 fast-batched achieves best TPS among small-budget fast-rollback fits","next_action_hint":"budget 8 is the best fast-batched config so far (7.85 TPS) but still below exact path; target_tree decode dominates fast-batched cost","note":"budget 8 step=226ms, acceptance=1.83, draft=135ms; fast-batched path consistently slower than exact validation for this implementation"}} +{"run":118,"commit":"ade211a","metric":10.195434,"metrics":{"e2e_tps":1.306336,"spec_sec":24.496,"gen_tokens":32,"steps":14,"committed":33,"step_ms":224.19,"pack_ms":1.16,"draft_ms":143.91,"topk_ms":20.75,"exact_ms":58.32,"acceptance":2.357},"status":"keep","description":"test DDTree top_k=4 with proposal temp 0.7 exact-validation path","timestamp":1777543025604,"segment":2,"confidence":1.0958318828179303,"asi":{"benchmark":"./autoresearch.sh; prompt=/tmp/real_rendered_prompt.txt; LLAMA_DDTREE_PROPOSAL_TEMP=0.7; AUTORESEARCH_BUDGET=32; LLAMA_DDTREE_TOP_K=4","hypothesis":"reducing DDTree top-k cardinality from default 8 to 4 may lower CPU top-k scan cost without losing acceptance on this prompt","next_action_hint":"repeat K=4 and test K=2 to see if smaller K continues to help or hurts acceptance; if stable, consider defaulting to K=4","note":"first K=4 run gave 10.20 TPS, repeat gave 10.03; both above 9.41 baseline but within historical 9.4-10.5 noise band; draft/step time slightly lower than baseline"}} +{"run":119,"commit":"e385d50","metric":10.147005,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":145.55,"e2e_tps":1.302932,"exact_ms":58.27,"gen_tokens":32,"pack_ms":0.94,"spec_sec":24.56,"step_ms":225.26,"steps":14,"topk_ms":20.47},"status":"keep","description":"test target_feat_ctx=112 with top_k=4 and proposal_temp=0.7 exact path","timestamp":1777545728081,"segment":2,"confidence":0.550668850678441,"asi":{"hypothesis":"smaller target feature window 112 may reduce draft compute without losing acceptance versus 128","next_action_hint":"test ctx=96 or 104 to find the lower bound; if none beat 128, keep 128 as default","note":"ctx=112 produced 10.15 TPS, step=225ms, draft=145ms, acceptance=2.36; within noise of ctx=128 best"}} +{"run":120,"commit":"ff511ad","metric":9.892726,"metrics":{"acceptance":2.429,"committed":34,"draft_ms":147.88,"e2e_tps":0.951446,"exact_ms":61.21,"gen_tokens":32,"pack_ms":1.11,"spec_sec":33.633,"step_ms":231.05,"steps":14,"topk_ms":20.82},"status":"keep","description":"exact path with n_batch=64 to validate per-layer persist allocation does not regress baseline","timestamp":1777546971003,"segment":2,"confidence":1,"asi":{"hypothesis":"validate that per-layer persist allocation + n_batch=64 does not break exact path correctness or TPS","next_action_hint":"per-layer allocation is committed and safe; continue exploring fast-batched optimization or server integration","note":"n_batch=64 exact path achieved 9.89 TPS, slightly below n_batch=512 best (10.15) but within noise; per-layer allocation code is working"}} +{"run":121,"commit":"ff511ad","metric":10.188617,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":143.93,"e2e_tps":1.307029,"exact_ms":58.34,"gen_tokens":32,"pack_ms":1.11,"spec_sec":24.483,"step_ms":224.34,"steps":14,"topk_ms":20.92},"status":"discard","description":"exact path budget 40 with top_k=4 and proposal_temp=0.7 to test acceptance saturation beyond budget 32","timestamp":1777547541311,"segment":2,"confidence":1.451900946467489,"asi":{"hypothesis":"budget 40 provides more tree nodes than budget 32, potentially increasing acceptance per step enough to reduce total steps and improve overall TPS","rollback_reason":"budget 40 produced identical acceptance (2.357), steps (14), and committed (33) as budget 32, with nearly identical step time (~224ms); extra budget nodes do not improve proposal quality on this prompt","next_action_hint":"acceptance is saturated at ~2.36 for this prompt with budget >=32; try budget 28 or 30 to see if lower budget achieves same acceptance with less overhead, or switch to cross-prompt validation or server integration testing"}} +{"run":122,"commit":"974eb18","metric":11.808554,"metrics":{"acceptance":2.267,"committed":34,"draft_ms":106.3,"e2e_tps":1.32714,"exact_ms":55.5,"gen_tokens":32,"pack_ms":1.06,"spec_sec":24.112,"step_ms":180.66,"steps":15,"topk_ms":17.77},"status":"keep","description":"exact path budget 28 with top_k=4 and proposal_temp=0.7 — large draft speedup vs budget 32","timestamp":1777547640327,"segment":2,"confidence":4.982182498441078,"asi":{"hypothesis":"budget 28 may reduce draft tree-compute overhead more than it hurts acceptance, yielding better TPS than budget 32 despite slightly lower per-step acceptance","next_action_hint":"repeat budget 28 immediately to validate whether the 106ms draft time and 11.81 TPS are stable or noise; if repeatable, this is a major finding","note":"budget 28 step time dropped to 180.66ms (from ~224ms at budget 32), driven by draft_ms falling to 106.3ms (from ~144ms); acceptance slightly lower at 2.267 vs 2.357 but more than offset by per-step speedup"}} +{"run":123,"commit":"974eb18","metric":10.144721,"metrics":{"acceptance":2.267,"committed":34,"draft_ms":133.06,"e2e_tps":1.307831,"exact_ms":56.01,"gen_tokens":32,"pack_ms":1.07,"spec_sec":24.468,"step_ms":210.29,"steps":15,"topk_ms":20.12},"status":"discard","description":"repeat budget 28 to validate the 11.81 TPS result","timestamp":1777547746753,"segment":2,"confidence":6.1161100677716975,"asi":{"hypothesis":"repeat budget 28 to check whether the 11.81 TPS run was stable or noise","next_action_hint":"budget 28 repeat at 10.14 TPS is comparable to budget 32 (10.20 TPS); try budget 30 as the intermediate sweet spot, or investigate why draft time varies so widely (106-144ms)","note":"repeat budget 28 draft_ms=133.06ms, step_ms=210.29ms, TPS=10.14 — much closer to budget 32 than the first run's 106ms/180ms/11.81; first run was likely an outlier","rollback_reason":"10.14 TPS is slightly below the kept budget 32 best of 10.20 TPS; budget 28 is not clearly better"}} +{"run":124,"commit":"ff511ad","metric":10.068782,"metrics":{"acceptance":2.286,"committed":32,"draft_ms":148.42,"e2e_tps":1.304472,"exact_ms":56.42,"gen_tokens":32,"pack_ms":1.09,"spec_sec":24.531,"step_ms":227.01,"steps":14,"topk_ms":21.04},"status":"discard","description":"exact path budget 30 with top_k=4 and proposal_temp=0.7 — intermediate sweep between 28 and 32","timestamp":1777547834290,"segment":2,"confidence":13.614577179988123,"asi":{"hypothesis":"budget 30 may be the sweet spot between budget 28 and 32, preserving acceptance while reducing overhead","rollback_reason":"budget 30 produced lower acceptance (2.286 vs 2.357), higher draft time (148ms vs 144ms), and lower TPS (10.07 vs 10.20) than budget 32; budget 32 remains the best exact-path configuration","next_action_hint":"budget 32 is the exact-path sweet spot; validate on a different prompt to check for overfitting, or try top_k=6, or explore reducing exact validation cost per step"}} +{"run":125,"commit":"ff511ad","metric":14.684288,"metrics":{"acceptance":2.4,"committed":12,"draft_ms":84.62,"e2e_tps":11.472275,"exact_ms":57.61,"gen_tokens":12,"pack_ms":0.37,"spec_sec":1.046,"step_ms":163.44,"steps":5,"topk_ms":20.8},"status":"discard","description":"cross-prompt validation on qwen_rendered_prompt.txt with best config — only 12 tokens generated, not comparable to 32-token benchmark","timestamp":1777547880051,"segment":2,"confidence":5.272893061768601,"asi":{"hypothesis":"validate best config (budget 32, top_k=4, proposal_temp=0.7) on a different prompt to check for overfitting","next_action_hint":"this prompt only generated 12 tokens (EOS early), making it incomparable to the 32-token benchmark; try a different prompt that generates closer to 32 tokens, such as task761 or task0","note":"config works correctly on different prompt (bit-equal PASS) with acceptance=2.40, but short generation makes TPS incomparable"}} +{"run":126,"commit":"ff511ad","metric":7.524101,"metrics":{"acceptance":1.65,"committed":33,"draft_ms":149.57,"e2e_tps":1.181422,"exact_ms":41.2,"gen_tokens":32,"pack_ms":1.03,"spec_sec":27.086,"step_ms":212.65,"steps":20,"topk_ms":20.82},"status":"discard","description":"cross-prompt validation on task761 (21k tokens) with best config — acceptance drops to 1.65, TPS falls to 7.52","timestamp":1777547973644,"segment":2,"confidence":3.647426181911143,"asi":{"hypothesis":"validate best config on a different prompt to check for overfitting","next_action_hint":"acceptance varies significantly by prompt (1.65 vs 2.36); test whether higher budget or top_k=6 improves harder prompts without hurting the benchmark prompt","note":"task761 generated 32 tokens in 20 steps (vs 14 on real_rendered), acceptance 1.65; draft model effectiveness varies by prompt; config is correct (bit-equal PASS) but slower on this prompt","rollback_reason":"7.52 TPS is below baseline on this prompt, but this is a cross-validation run rather than a benchmark regression; the benchmark prompt remains the optimization target"}} +{"run":127,"commit":"ff511ad","metric":9.903871,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":149.32,"e2e_tps":1.293766,"exact_ms":58.77,"gen_tokens":32,"pack_ms":1.11,"spec_sec":24.734,"step_ms":230.79,"steps":14,"topk_ms":21.56},"status":"discard","description":"exact path top_k=2 with budget 32 and proposal_temp=0.7 — acceptance same as top_k=4 but slightly slower","timestamp":1777548247599,"segment":2,"confidence":6.1161100677716975,"asi":{"hypothesis":"top_k=2 may reduce topk overhead while maintaining acceptance if the extra K=4 candidates are unused on this prompt","next_action_hint":"top_k=2 acceptance is identical to top_k=4 on this prompt (2.357), but draft/topk noise makes it slightly slower; try n_ctx=32k to see if reduced KV cache pressure helps target decode, or test top_k=4 on task761 to check cross-prompt generalization","rollback_reason":"9.90 TPS is below the kept top_k=4 best of 10.20 TPS; top_k=2 is not better on the benchmark prompt","note":"acceptance unchanged at 2.357, confirming that extra K=4 candidates do not improve acceptance on this prompt; draft time 149ms vs 144ms is within noise"}} +{"run":128,"commit":"ff511ad","metric":10.212288,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":143.97,"e2e_tps":1.301236,"exact_ms":58.07,"gen_tokens":32,"pack_ms":1.09,"spec_sec":24.592,"step_ms":223.82,"steps":14,"topk_ms":20.66},"status":"discard","description":"exact path with n_ctx=32768 to test if reduced KV cache allocation improves TPS","timestamp":1777548345195,"segment":2,"confidence":13.614577179988123,"asi":{"hypothesis":"n_ctx=32768 reduces KV cache allocation overhead compared to 65536, potentially improving target decode speed","next_action_hint":"n_ctx has no meaningful effect on exact-path TPS; try draft GPU layers 5 or explore exact-validation batching/parallelization","note":"n_ctx=32k produced 10.21 TPS, essentially identical to 64k (10.20); KV cache pressure is not the bottleneck for this prompt length (20k tokens)","rollback_reason":"10.21 TPS is indistinguishable from 10.20 best; no improvement from halving n_ctx"}} +{"run":129,"commit":"ff511ad","metric":8.558163,"metrics":{"acceptance":2.429,"committed":34,"draft_ms":187.64,"e2e_tps":1.283028,"exact_ms":59.93,"gen_tokens":32,"pack_ms":1.1,"spec_sec":24.941,"step_ms":267.08,"steps":14,"topk_ms":18.36},"status":"discard","description":"exact path with draft GPU layers 5 — one layer off GPU hurts draft time significantly","timestamp":1777548448213,"segment":2,"confidence":5.987365455085925,"asi":{"hypothesis":"draft GPU layers 5 may reduce GPU memory contention or improve CPU-GPU balance","next_action_hint":"ngld6 is optimal for exact path; do not reduce draft offload further. Next explore exact-validation batching/parallelization or target decode call overhead reduction.","note":"draft_ms jumped to 187.64ms from ~144ms, step_ms to 267ms; moving draft layers off GPU hurts more than any acceptance gain","rollback_reason":"8.56 TPS is far below the 10.20 best; draft GPU offload must stay at 6 for exact path performance"}} +{"run":130,"commit":"974eb18","metric":9.884597,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":150.8,"e2e_tps":1.298069,"exact_ms":58.39,"gen_tokens":32,"pack_ms":1.11,"spec_sec":24.652,"step_ms":231.24,"steps":14,"topk_ms":20.91},"status":"discard","description":"exact path with chain_seed=0 (pure best-first) and optimized params — chain seed still helps","timestamp":1777548855529,"segment":2,"confidence":7.771711676074962,"asi":{"hypothesis":"with top_k=4 and proposal_temp=0.7, pure best-first might produce a better tree than chain seed","next_action_hint":"chain seed remains beneficial; do not disable it. Next try gen=64 diagnostic or explore exact-validation batching code optimization.","note":"chain_seed=0 produced 9.88 TPS vs 10.20 with default chain seed; draft_ms=150.8 vs ~144ms; chain seed improves tree quality enough to offset any overhead","rollback_reason":"9.88 TPS below best; chain seed default is correct"}} +{"run":131,"commit":"974eb18","metric":9.720872,"metrics":{"acceptance":2.1,"committed":42,"draft_ms":143.17,"e2e_tps":1.631702,"exact_ms":51.79,"gen_tokens":42,"pack_ms":1.03,"spec_sec":25.74,"step_ms":216.03,"steps":20,"topk_ms":20},"status":"discard","description":"diagnostic run with gen=64 to check TPS over longer generation","timestamp":1777548960687,"segment":2,"confidence":7.844313356514611,"asi":{"hypothesis":"longer generation may reveal acceptance degradation or step time changes not visible in the 32-token benchmark","next_action_hint":"acceptance degrades from 2.36 to 2.10 over longer generation; draft/step time stays stable. Focus on per-step optimizations rather than prompt-length tuning.","note":"gen_tokens=42 (EOS at 42), steps=20, acceptance=2.10, step_ms=216ms; TPS=9.72 vs 10.2 for gen=32; draft model accuracy degrades with longer context, as expected"}} +{"run":132,"commit":"974eb18","metric":9.979978,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":148.91,"e2e_tps":1.293191,"exact_ms":58.56,"gen_tokens":32,"pack_ms":1.1,"spec_sec":24.745,"step_ms":229.03,"steps":14,"topk_ms":20.43},"status":"discard","description":"exact path with K=1 to test whether topk_ms is CPU-bound or sync-bound","timestamp":1777549060523,"segment":2,"confidence":8.220960821503416,"asi":{"hypothesis":"K=1 argmax should be much faster than K=4 logsumexp+heap if CPU computation dominates topk_ms","next_action_hint":"topk_ms is dominated by GPU sync/wait, not CPU computation (K=1=20.4ms vs K=4=20.9ms). Do not optimize CPU top-k further. Focus on draft decode speed or exact-validation batching.","note":"K=1 topk_ms=20.43ms, K=4 topk_ms=20.92ms — difference is only 0.5ms. The ~20ms is mostly GPU synchronization tail from draft decode. Total draft+topk is ~169ms regardless of K.","rollback_reason":"K=1 does not improve topk_ms meaningfully; TPS=9.98 below best of 10.20"}} +{"run":133,"commit":"974eb18","metric":9.720387,"metrics":{"acceptance":2.267,"committed":34,"draft_ms":149.93,"e2e_tps":1.29728,"exact_ms":55.65,"gen_tokens":32,"pack_ms":1.08,"spec_sec":24.667,"step_ms":219.47,"steps":15,"topk_ms":12.79},"status":"discard","description":"exact path with block_size=8 override — draft time unchanged, acceptance drops","timestamp":1777549316541,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"reducing block_size from 16 to 8 should halve draft compute if draft time scales with tokens","next_action_hint":"draft time is NOT proportional to block_size; the SSM tree kernel likely has fixed overhead per step. Revert block_size override and try other optimizations.","note":"draft_ms=149.93 with block_size=8 vs ~144ms with block_size=16 — essentially the same. Acceptance dropped to 2.267 (from 2.357) because shallower tree has fewer candidate paths.","rollback_reason":"block_size=8 hurts acceptance without improving draft speed; not a useful optimization"}} +{"run":134,"commit":"974eb18","metric":10.074996,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":148.2,"e2e_tps":1.297859,"exact_ms":58.21,"gen_tokens":32,"pack_ms":1.09,"spec_sec":24.656,"step_ms":226.87,"steps":14,"topk_ms":19.34},"status":"discard","description":"exact path with KV type q8_0 — similar performance to q4_0","timestamp":1777549692319,"segment":2,"confidence":8.417775842329654,"asi":{"hypothesis":"q8_0 KV cache may improve cache locality or reduce dequantization overhead vs q4_0","next_action_hint":"KV cache type has no meaningful effect on TPS; try other optimizations like target GPU layers or draft model quantization","note":"q8_0 produced 10.07 TPS, essentially identical to q4_0 (10.20); correctness passes; KV type is not a bottleneck","rollback_reason":"no improvement over q4_0 baseline"}} +{"run":135,"commit":"b4033f4","metric":10.364171,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":142.29,"e2e_tps":1.309329,"exact_ms":57.78,"gen_tokens":32,"pack_ms":1.1,"spec_sec":24.44,"step_ms":220.54,"steps":14,"topk_ms":19.33},"status":"keep","description":"exact path with target GPU layers 80 — potential improvement over 65","timestamp":1777549798463,"segment":2,"confidence":9.583840128587989,"asi":{"hypothesis":"increasing target GPU layers from 65 to 80 may offload remaining CPU layers and improve target decode speed","next_action_hint":"repeat with n_gpu_layers=80 to confirm stability, and test n_gpu_layers=99 to check for further gains","note":"TPS=10.36, draft_ms=142.29, step_ms=220.54 — slightly better than n_gpu_layers=65 best of 10.20; could be noise or real improvement from better GPU utilization"}} +{"run":136,"commit":"b4033f4","metric":10.053725,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":147.13,"e2e_tps":1.302189,"exact_ms":58.7,"gen_tokens":32,"pack_ms":1.09,"spec_sec":24.574,"step_ms":227.35,"steps":14,"topk_ms":20.39},"status":"discard","description":"exact path with target GPU layers 99 — worse than 80, possibly noise or memory contention","timestamp":1777549879409,"segment":2,"confidence":10.317799492058029,"asi":{"hypothesis":"n_gpu_layers=99 may fully offload all layers, potentially improving over 80","next_action_hint":"n_gpu_layers=80 appears best among tested values; repeat n_gpu_layers=80 for stability","rollback_reason":"10.05 TPS below n_gpu_layers=80 best of 10.36; 99 may cause memory fragmentation or is within noise","note":"step_ms=227.35, draft_ms=147.13 at n_gpu_layers=99 vs step_ms=220.54, draft_ms=142.29 at n_gpu_layers=80"}} +{"run":137,"commit":"b4033f4","metric":9.495718,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":159.25,"e2e_tps":1.288297,"exact_ms":58.13,"gen_tokens":32,"pack_ms":1.11,"spec_sec":24.839,"step_ms":240.71,"steps":14,"topk_ms":22.19},"status":"discard","description":"repeat n_gpu_layers=80 to check stability — regressed to 9.50 TPS, confirming first run was noise","timestamp":1777549973455,"segment":2,"confidence":9.583840128587989,"asi":{"hypothesis":"repeat n_gpu_layers=80 to confirm the 10.36 TPS improvement was stable","next_action_hint":"n_gpu_layers parameter does not meaningfully affect TPS; the 10.36 run was noise. Return to exploring exact-validation batching or other code optimizations.","note":"repeat draft_ms=159.25, step_ms=240.71, TPS=9.50 — much worse than first run; GPU thermal/scheduling noise dominates small parameter changes","rollback_reason":"9.50 TPS well below best; n_gpu_layers=80 is not a real improvement"}} +{"run":138,"commit":"974eb18","metric":0,"metrics":{"acceptance":0,"committed":0,"draft_ms":0,"e2e_tps":0,"exact_ms":0,"gen_tokens":0,"pack_ms":0,"spec_sec":0,"step_ms":0,"steps":0,"topk_ms":0},"status":"crash","description":"exact path with no flash attention — fails because q4_0 KV cache requires flash_attn","timestamp":1777550079676,"segment":2,"confidence":9.583840128587989,"asi":{"hypothesis":"disabling flash attention might reduce overhead for small batch decodes","next_action_hint":"flash attention is required for this config; do not test without it","rollback_reason":"V cache quantization requires flash_attn; cannot disable flash attention with q4_0 KV cache"}} +{"run":139,"commit":"974eb18","metric":9.686461,"metrics":{"acceptance":2.429,"committed":34,"draft_ms":154.37,"e2e_tps":1.301607,"exact_ms":59.7,"gen_tokens":32,"pack_ms":1.14,"spec_sec":24.585,"step_ms":235.97,"steps":14,"topk_ms":20.73},"status":"discard","description":"exact path with draft_n_ctx capped to 512 — works correctly but slightly slower","timestamp":1777550273300,"segment":2,"confidence":8.417775842329654,"asi":{"hypothesis":"reducing draft_n_ctx from 4096 to 512 may reduce memory pressure and improve draft speed","next_action_hint":"draft_n_ctx can be as low as 512 without breaking correctness, but does not improve TPS. Revert and focus on other optimizations.","note":"draft_n_ctx=512 passed correctness but draft_ms=154.37 vs ~144ms baseline; no speedup","rollback_reason":"9.69 TPS below 10.20 best; draft_n_ctx reduction does not improve performance"}} +{"run":140,"commit":"974eb18","metric":9.452913,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":160.47,"e2e_tps":1.286225,"exact_ms":58.18,"gen_tokens":32,"pack_ms":1.09,"spec_sec":24.879,"step_ms":241.8,"steps":14,"topk_ms":22.02},"status":"discard","description":"exact path with block_size=4 — draft time increases, not helpful","timestamp":1777550372252,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"block_size=4 should reduce draft compute if proportional to tokens","next_action_hint":"draft time has fixed floor around 140-160ms regardless of block_size; do not change block_size","note":"block_size=4 draft_ms=160.47 vs ~144ms at 16; acceptance unchanged at 2.357","rollback_reason":"worse TPS and higher draft time"}} +{"run":141,"commit":"974eb18","metric":9.433795,"metrics":{"acceptance":2.429,"committed":34,"draft_ms":158.73,"e2e_tps":1.286691,"exact_ms":60.06,"gen_tokens":32,"pack_ms":1.38,"spec_sec":24.87,"step_ms":242.29,"steps":14,"topk_ms":22.09},"status":"discard","description":"exact path with target_feat_ctx=160 — slightly higher acceptance but much higher draft time","timestamp":1777550473024,"segment":2,"confidence":7.918284287167812,"asi":{"hypothesis":"larger target feature window 160 may improve draft model accuracy","next_action_hint":"target_feat_ctx=128 remains the sweet spot; 160 increases draft time more than acceptance gain","note":"acceptance=2.429 vs 2.357 but draft_ms=158.73 vs ~144ms; net TPS worse","rollback_reason":"9.43 TPS below 10.20 best"}} +{"run":142,"commit":"b4033f4","metric":10.018032,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":148.43,"e2e_tps":1.257664,"exact_ms":57.94,"gen_tokens":32,"pack_ms":1.1,"spec_sec":25.444,"step_ms":228.16,"steps":14,"topk_ms":20.66},"status":"discard","description":"exact path with actual top_k=4 — same as K=8 on this prompt","timestamp":1777550692217,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"actual top_k=4 may reduce topk overhead while maintaining acceptance","next_action_hint":"K=4 acceptance identical to K=8; test K=2 to find the lower bound","note":"actual K=4 produces same acceptance (2.357) and similar timing to K=8; extra candidates 5-8 are unused on this prompt"}} +{"run":143,"commit":"b4033f4","metric":9.697851,"metrics":{"acceptance":2.2,"committed":33,"draft_ms":145.8,"e2e_tps":1.295442,"exact_ms":53.86,"gen_tokens":32,"pack_ms":1.07,"spec_sec":24.702,"step_ms":219.98,"steps":15,"topk_ms":19.22},"status":"discard","description":"exact path with actual top_k=2 — acceptance drops to 2.20, requiring 15 steps","timestamp":1777550958043,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"actual top_k=2 may reduce overhead without losing too much acceptance","next_action_hint":"K=2 hurts acceptance (2.20 vs 2.36); K=4 matches K=8. Test K=4 again for confirmation.","note":"actual K=2: acceptance=2.20, steps=15, TPS=9.70; K=4 and K=8 both give 2.36/14 steps"}} +{"run":144,"commit":"74b01e6","metric":10.254438,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":144.06,"e2e_tps":1.306389,"exact_ms":57.84,"gen_tokens":32,"pack_ms":1.11,"spec_sec":24.495,"step_ms":222.9,"steps":14,"topk_ms":19.86},"status":"keep","description":"exact path with actual top_k=4 — matches K=8 acceptance with less CPU work, new best TPS","timestamp":1777551092551,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"actual top_k=4 matches K=8 acceptance while using less CPU and memory for top-k extraction","next_action_hint":"K=4 is validated as equivalent to K=8 on this prompt; continue exploring other optimizations like exact-validation batching or server integration","note":"code changes for --ddtree-top-k flag and top_k parameter are now committed; K=4 acceptance=2.357, steps=14, TPS=10.25"}} +{"run":145,"commit":"74b01e6","metric":9.812038,"metrics":{"acceptance":2.357,"committed":33,"draft_ms":152.96,"e2e_tps":1.287208,"exact_ms":58.2,"gen_tokens":32,"pack_ms":1.1,"spec_sec":24.86,"step_ms":232.95,"steps":14,"topk_ms":20.65},"status":"discard","description":"repeat best config to check noise — draft time 153ms, TPS drops to 9.81","timestamp":1777551308662,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"repeat best config to establish noise floor","next_action_hint":"run one more repeat to confirm noise band, then try a different approach if no stable improvement is found","note":"draft_ms varies from 144ms to 153ms between runs; TPS range 9.8-10.3; noise floor is about 0.5 TPS","rollback_reason":"9.81 TPS below best of 10.25; within noise"}} +{"run":146,"commit":"6a23e9f","metric":0,"metrics":{"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"acceptance":0},"status":"crash","description":"attempted repeat best config with inline env but harness rejected custom command form","timestamp":1777556393600,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"repeat pure best-first K=4 best config to confirm stability","rollback_reason":"run_experiment enforces direct ./autoresearch.sh invocation when autoresearch.sh exists and rejected inline environment assignment","next_action_hint":"temporarily edit autoresearch.sh defaults to the target config, then run exactly ./autoresearch.sh"}} +{"run":147,"commit":"6a23e9f","metric":0,"metrics":{"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"acceptance":0},"status":"crash","description":"attempted direct script run with cd prefix but harness still rejected command form","timestamp":1777556433889,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"run ./autoresearch.sh after encoding best defaults in the script for repeat validation","rollback_reason":"run_experiment requires the command string to be exactly ./autoresearch.sh or bash autoresearch.sh; cd prefix is rejected","next_action_hint":"invoke run_experiment with command exactly ./autoresearch.sh and rely on harness working directory, or add a root-level wrapper only if exact command cannot find the script"}} +{"run":148,"commit":"de24439","metric":10.589855,"metrics":{"e2e_tps":1.303038,"spec_sec":24.558,"gen_tokens":32,"steps":14,"committed":33,"step_ms":215.84,"pack_ms":1.14,"draft_ms":137.23,"topk_ms":19.59,"exact_ms":57.84,"acceptance":2.357},"status":"keep","description":"repeat pure best-first K=4 defaults in autoresearch.sh — stable new best","timestamp":1777556713769,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"encoding the current best exact-path configuration as harness defaults and repeating pure best-first K=4 should confirm whether the 10.33 TPS run was stable","result":"repeat improved to 10.59 TPS with same 14-step acceptance pattern; draft_ms dropped to 137ms and exact_ms remained 57.8ms","next_action_hint":"use these defaults as the new baseline; next test whether budget can be reduced under pure best-first K=4, or validate on task761 to avoid overfitting"}} +{"run":149,"commit":"de24439","metric":9.866251,"metrics":{"e2e_tps":1.293086,"spec_sec":24.747,"gen_tokens":32,"steps":14,"committed":32,"step_ms":231.67,"pack_ms":1.13,"draft_ms":153.82,"topk_ms":20.37,"exact_ms":56.32,"acceptance":2.286},"status":"discard","description":"pure best-first K=4 with budget 28 — acceptance drops and draft time rises","timestamp":1777556801426,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"with pure best-first K=4, budget 28 might preserve the 14-step path while reducing draft/tree overhead versus budget 32","rollback_reason":"budget 28 lost one committed token, lowered acceptance to 2.286, and draft_ms increased to 153.8ms; primary TPS fell to 9.87 below the 10.59 best","next_action_hint":"do not reduce budget below 32 for this pure best-first config; test budget 36/40 only if looking for acceptance gains, or validate best defaults on a harder prompt"}} +{"run":150,"commit":"de24439","metric":10.244327,"metrics":{"e2e_tps":1.298807,"spec_sec":24.638,"gen_tokens":32,"steps":14,"committed":33,"step_ms":223.12,"pack_ms":1.12,"draft_ms":144.18,"topk_ms":20.05,"exact_ms":57.73,"acceptance":2.357},"status":"discard","description":"pure best-first K=4 with budget 36 — no acceptance gain over budget 32","timestamp":1777556896855,"segment":2,"confidence":8.160393835091734,"asi":{"hypothesis":"budget 36 may add useful branches under pure best-first K=4 and improve acceptance enough to offset extra overhead","rollback_reason":"budget 36 produced the same 14 steps, 33 committed tokens, and 2.357 acceptance as budget 32, while step_ms rose to 223ms and TPS fell to 10.24","next_action_hint":"budget 32 remains the sweet spot on this prompt; reset the harness default to 32 before further experiments"}} +{"run":151,"commit":"de24439","metric":9.803621,"metrics":{"e2e_tps":1.275815,"spec_sec":25.082,"gen_tokens":32,"steps":14,"committed":33,"step_ms":233.15,"pack_ms":0.98,"draft_ms":153.76,"topk_ms":20.33,"exact_ms":58.05,"acceptance":2.357},"status":"discard","description":"pure best-first K=4 with target_feat_ctx 112 — pack lower but draft slower","timestamp":1777556989888,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"target_feat_ctx 112 might lower feature packing and draft compute while pure best-first K=4 preserves the 14-step acceptance pattern","rollback_reason":"ctx112 kept acceptance but draft_ms rose to 153.8ms and step_ms to 233ms, dropping TPS to 9.80; the smaller pack cost did not translate to faster draft decode","next_action_hint":"keep target_feat_ctx 128; avoid further small-window tuning unless measuring multiple repeats or a different prompt"}} +{"run":152,"commit":"de24439","metric":9.751044,"metrics":{"e2e_tps":1.293086,"spec_sec":24.747,"gen_tokens":32,"steps":15,"committed":34,"step_ms":218.78,"pack_ms":1.09,"draft_ms":141.95,"topk_ms":20.11,"exact_ms":55.6,"acceptance":2.267},"status":"discard","description":"pure best-first K=4 with proposal_temp 0.5 — fewer accepted per step","timestamp":1777557081428,"segment":2,"confidence":8.586538754323378,"asi":{"hypothesis":"lower proposal_temp 0.5 may sharpen best-first scores and favor higher-confidence paths, possibly reducing exact validation work","rollback_reason":"temp 0.5 required 15 steps with acceptance 2.267 versus 14 steps at temp 0.7; lower per-step exact cost did not offset the extra step","next_action_hint":"keep proposal_temp 0.7; if sweeping temp, try slightly higher 0.85/1.0 rather than lower"}} +{"run":153,"commit":"de24439","metric":10.500824,"metrics":{"e2e_tps":1.308472,"spec_sec":24.456,"gen_tokens":32,"steps":14,"committed":33,"step_ms":217.67,"pack_ms":1.12,"draft_ms":138.2,"topk_ms":20.33,"exact_ms":57.98,"acceptance":2.357},"status":"discard","description":"pure best-first K=4 with proposal_temp 0.85 — close but below 0.7 best","timestamp":1777557170324,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"slightly higher proposal_temp 0.85 may diversify pure best-first branches without changing the 14-step acceptance pattern","rollback_reason":"temp 0.85 preserved acceptance and was fast, but TPS 10.50 remained below the kept 10.59 run at temp 0.7","next_action_hint":"0.7 and 0.85 are within noise; keep 0.7 as current best and test a structurally different knob such as tree row count or exact validation batching"}} +{"run":154,"commit":"de24439","metric":10.038712,"metrics":{"e2e_tps":1.290895,"spec_sec":24.789,"gen_tokens":32,"steps":14,"committed":33,"step_ms":227.69,"pack_ms":1.12,"draft_ms":147.82,"topk_ms":20.69,"exact_ms":58.03,"acceptance":2.357},"status":"discard","description":"reuse draft llama_batch across steps — allocation removal does not improve TPS","timestamp":1777557336806,"segment":2,"confidence":8.8935689439265,"asi":{"hypothesis":"allocating and freeing the draft embedding llama_batch every speculative step may add overhead or disturb allocator/cache behavior; reusing one batch should reduce draft-step overhead","rollback_reason":"reusing draft_batch passed correctness but TPS fell to 10.04 and draft_ms rose to 147.8ms; batch allocation is not a meaningful bottleneck or the result is noise below best","next_action_hint":"avoid micro-optimizing host batch allocation; focus on reducing draft llama_decode time, exact validation calls, or proposal acceptance"}} +{"run":155,"commit":"de24439","metric":10.016276,"metrics":{"e2e_tps":1.293191,"spec_sec":24.745,"gen_tokens":32,"steps":14,"committed":33,"step_ms":228.2,"pack_ms":1.12,"draft_ms":147.99,"topk_ms":21.17,"exact_ms":57.89,"acceptance":2.357},"status":"discard","description":"skip root draft logits output — fewer logits rows but no speedup","timestamp":1777557465323,"segment":2,"confidence":9.858464879449215,"asi":{"hypothesis":"the draft root position logits are unused; disabling logits for batch position 0 should reduce one vocab row of output materialization and top-k sync work","rollback_reason":"correctness passed, but TPS fell to 10.02 and topk_ms increased to 21.17ms; outputting one fewer row does not reduce the dominant GPU synchronization cost","next_action_hint":"do not pursue small logits-row reductions; a real top-k improvement likely needs GPU-side argmax/top-k or eliminating the sync"}} +{"run":156,"commit":"de24439","metric":10.259041,"metrics":{"e2e_tps":1.295861,"spec_sec":24.694,"gen_tokens":32,"steps":14,"committed":33,"step_ms":222.8,"pack_ms":1.12,"draft_ms":142.96,"topk_ms":20.67,"exact_ms":58.01,"acceptance":2.357},"status":"discard","description":"capped chain seed depth 4 with K=4 — no gain over pure best-first","timestamp":1777557573722,"segment":2,"confidence":9.583840128587989,"asi":{"hypothesis":"a short greedy-chain seed may keep early top-1 continuity while leaving most of the budget for best-first branches, combining benefits of chain_seed and pure best-first","rollback_reason":"chain_seed cap 4 produced the same acceptance as pure best-first but slower step_ms and 10.26 TPS, below the 10.59 best","next_action_hint":"pure best-first remains better; if revisiting chain caps, test only if cross-prompt acceptance requires it"}} +{"run":157,"commit":"de24439","metric":9.168135,"metrics":{"e2e_tps":1.262178,"spec_sec":25.353,"gen_tokens":32,"steps":15,"committed":34,"step_ms":232.69,"pack_ms":1.13,"draft_ms":154.03,"topk_ms":21.51,"exact_ms":55.98,"acceptance":2.267},"status":"discard","description":"n_batch/n_ubatch 256 exact path — acceptance pattern worsens and TPS drops","timestamp":1777557665480,"segment":2,"confidence":8.417775842329654,"asi":{"hypothesis":"reducing n_batch/n_ubatch from 512 to 256 may reduce buffer pressure and improve decode latency without affecting outputs","rollback_reason":"n_batch 256 changed the speculative path to 15 steps with lower acceptance and slower draft/topk timing; TPS fell to 9.17","next_action_hint":"keep n_batch/n_ubatch at 512 for the benchmark; do not reduce batch sizing unless targeting memory footprint rather than TPS"}} +{"run":158,"commit":"de24439","metric":0,"metrics":{"e2e_tps":0,"spec_sec":0,"gen_tokens":0,"steps":0,"committed":0,"step_ms":0,"pack_ms":0,"draft_ms":0,"topk_ms":0,"exact_ms":0,"acceptance":0},"status":"crash","description":"n_batch/n_ubatch 1024 exact path — target context OOM","timestamp":1777557702217,"segment":2,"confidence":8.417775842329654,"asi":{"hypothesis":"increasing n_batch/n_ubatch to 1024 may improve graph efficiency or decode scheduling compared with 512","rollback_reason":"context initialization failed allocating a 3.6 GiB CUDA compute buffer for 1024-token graph reserve; no benchmark metrics produced","next_action_hint":"512 is the practical upper bound on the 24GB Castle GPU for this configuration; reset n_batch/n_ubatch to 512"}} +{"run":159,"commit":"de24439","metric":9.87179,"metrics":{"e2e_tps":1.293923,"spec_sec":24.731,"gen_tokens":32,"steps":14,"committed":33,"step_ms":231.54,"pack_ms":1.11,"draft_ms":151.29,"topk_ms":20.99,"exact_ms":58.12,"acceptance":2.357},"status":"discard","description":"reuse 1-token target validation batch — no host allocation speedup","timestamp":1777557879589,"segment":2,"confidence":8.8935689439265,"asi":{"hypothesis":"exact chain validation allocates a 1-token llama_batch for every validated node; reusing one batch may reduce per-step host overhead","rollback_reason":"reusable target batch passed correctness but TPS fell to 9.87 and exact_ms stayed ~58ms; target llama_decode dominates, not batch allocation","next_action_hint":"stop pursuing host allocation micro-optimizations; focus on reducing number of target validation decodes or draft GPU compute"}} +{"run":160,"commit":"de24439","metric":9.727272,"metrics":{"e2e_tps":1.286174,"spec_sec":24.88,"gen_tokens":32,"steps":14,"committed":33,"step_ms":234.98,"pack_ms":1.12,"draft_ms":155.15,"topk_ms":20.47,"exact_ms":58.22,"acceptance":2.357},"status":"discard","description":"pure best-first actual top_k=3 — same acceptance but slower timing","timestamp":1777557966866,"segment":2,"confidence":9.426370246854823,"asi":{"hypothesis":"K=3 may preserve the K=4 acceptance path while slightly reducing top-k work and tree branching","rollback_reason":"K=3 kept the 14-step acceptance pattern but draft/step timing was worse and TPS fell to 9.73; K=4 remains the best tested K above the K=2 acceptance cliff","next_action_hint":"keep top_k=4; K changes mostly affect acceptance, not the ~20ms sync-bound top-k cost"}} +{"run":161,"commit":"de24439","metric":9.447052,"metrics":{"e2e_tps":1.282977,"spec_sec":24.942,"gen_tokens":32,"steps":14,"committed":33,"step_ms":241.95,"pack_ms":1.13,"draft_ms":159.54,"topk_ms":22.89,"exact_ms":58.35,"acceptance":2.357},"status":"discard","description":"pure best-first K=4 with proposal_temp 1.0 — same acceptance but much slower","timestamp":1777558057033,"segment":2,"confidence":8.586538754323378,"asi":{"hypothesis":"proposal_temp 1.0 may diversify branch scores more than 0.7 and improve path quality","rollback_reason":"temp 1.0 did not improve acceptance and produced much slower draft/topk timing, dropping TPS to 9.45","next_action_hint":"keep proposal_temp 0.7; higher temperatures above 0.85 are not useful on this prompt"}} +{"run":162,"commit":"de24439","metric":10.433717,"metrics":{"e2e_tps":1.297806,"spec_sec":24.657,"gen_tokens":32,"steps":14,"committed":33,"step_ms":219.07,"pack_ms":1.13,"draft_ms":140.33,"topk_ms":19.74,"exact_ms":57.83,"acceptance":2.357},"status":"discard","description":"pure best-first K=4 with block_size 12 — close but below block_size 16 best","timestamp":1777558168159,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"block_size 12 may reduce draft work while retaining the same accepted path under pure best-first K=4","rollback_reason":"block_size 12 preserved acceptance and was faster than many noisy runs, but TPS 10.43 stayed below the kept 10.59 block_size 16 baseline","next_action_hint":"block_size 12 is close enough to revisit with repeats only if the 10.59 baseline proves noisy; for now keep default block_size 16"}} +{"run":163,"commit":"de24439","metric":9.508358,"metrics":{"e2e_tps":1.284934,"spec_sec":24.904,"gen_tokens":32,"steps":14,"committed":33,"step_ms":240.39,"pack_ms":1.12,"draft_ms":159.02,"topk_ms":22.22,"exact_ms":58,"acceptance":2.357},"status":"discard","description":"repeat current best defaults after sweep — noisy low baseline repeat","timestamp":1777558253018,"segment":2,"confidence":7.901161640801124,"asi":{"hypothesis":"repeat the current best defaults after restoring block_size 16 to measure noise and ensure the harness is back on the kept config","rollback_reason":"repeat preserved correctness and acceptance but landed in a slow-noise band at 9.51 TPS, below the 10.59 kept best","next_action_hint":"the metric is noisy with draft_ms spanning ~137-159ms; future keeps should be repeated, and block_size 12/0.85 temp close calls should not be kept without repeated wins"}} +{"run":164,"commit":"de24439","metric":10.074996,"metrics":{"e2e_tps":1.299123,"spec_sec":24.632,"gen_tokens":32,"steps":14,"committed":33,"step_ms":226.87,"pack_ms":1.11,"draft_ms":147.9,"topk_ms":20.15,"exact_ms":57.67,"acceptance":2.357},"status":"discard","description":"repeat current best defaults while refocusing on draft decode","timestamp":1777559045624,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"repeat current best pure best-first K=4 baseline to measure the current draft_ms noise band before changing draft decode","rollback_reason":"baseline repeat did not beat kept 10.59 TPS; no code/config change to keep","next_action_hint":"use draft_ms around 148ms as current reference; inspect full log for dflash_draft_ubatch_timing compute vs set_inputs breakdown"}} +{"run":165,"commit":"de24439","metric":9.79312,"metrics":{"e2e_tps":1.298122,"spec_sec":24.651,"gen_tokens":32,"steps":14,"committed":33,"step_ms":233.4,"pack_ms":1.13,"draft_ms":154,"topk_ms":20.15,"exact_ms":58.09,"acceptance":2.357},"status":"discard","description":"repeat block_size 12 draft decode optimization candidate","timestamp":1777559136372,"segment":2,"confidence":8.586538754323378,"asi":{"hypothesis":"block_size 12 may reduce draft graph work while preserving the same 14-step acceptance path, improving draft decode","rollback_reason":"repeat was slower: draft_ms 154ms and TPS 9.79, below kept block_size 16 best","next_action_hint":"do not keep block_size 12; draft compute does not scale down reliably with fewer block tokens"}} +{"run":166,"commit":"de24439","metric":8.835483,"metrics":{"e2e_tps":1.263923,"spec_sec":25.318,"gen_tokens":32,"steps":16,"committed":33,"step_ms":226.36,"pack_ms":0.82,"draft_ms":152.98,"topk_ms":21.5,"exact_ms":51.03,"acceptance":2.062},"status":"discard","description":"pure best-first K=4 with target_feat_ctx 96 to reduce draft FC/attention context","timestamp":1777559283158,"segment":2,"confidence":7.884112887310015,"asi":{"hypothesis":"shrinking target_feat_ctx from 128 to 96 should reduce draft target_feat FC/attention work enough to offset any small acceptance loss","rollback_reason":"ctx96 reduced pack only; draft_ms rose to 153ms and acceptance fell to 2.06, requiring 16 steps and lowering TPS to 8.84","next_action_hint":"target feature context below 128 damages proposal quality and does not reduce draft compute in practice; keep ctx128"}} +{"run":167,"commit":"de24439","metric":8.609927,"metrics":{"e2e_tps":1.273125,"spec_sec":25.135,"gen_tokens":32,"steps":16,"committed":33,"step_ms":232.29,"pack_ms":0.83,"draft_ms":158.47,"topk_ms":22.05,"exact_ms":50.9,"acceptance":2.062},"status":"discard","description":"F16 target_feat_raw draft input under accidental ctx96 — slower and lower acceptance","timestamp":1777559419904,"segment":2,"confidence":7.637615029697407,"asi":{"hypothesis":"using F16 target_feat_raw may reduce draft input bandwidth and FC compute cost","rollback_reason":"run used ctx96 left in autoresearch.sh from the previous config experiment, and F16 input was slower with lower acceptance; not a valid improvement","next_action_hint":"restore autoresearch.sh target_feat_ctx default to 128 before retrying any target_feat format experiment; be aware log_experiment preserves autoresearch.sh changes"}} +{"run":168,"commit":"de24439","metric":10.134408,"metrics":{"e2e_tps":1.297175,"spec_sec":24.669,"gen_tokens":32,"steps":14,"committed":33,"step_ms":225.54,"pack_ms":1.13,"draft_ms":146.19,"topk_ms":20.52,"exact_ms":57.67,"acceptance":2.357},"status":"discard","description":"F16 target_feat_raw draft input at ctx128","timestamp":1777559542537,"segment":2,"confidence":7.884112887310015,"asi":{"hypothesis":"using F16 target_feat_raw at the draft graph input should cut target_feat upload bandwidth and may speed the FC projection without changing exact correctness","rollback_reason":"correctness passed but TPS 10.13 stayed below kept 10.59 and draft_ms 146ms is within normal baseline noise; CPU F32→F16 conversion plus F16 input did not produce a real speedup","next_action_hint":"do not change target_feat_raw to F16 via host conversion; if revisiting, capture/pack directly in F16 or move compression to GPU to avoid CPU conversion cost"}} +{"run":169,"commit":"de24439","metric":6.018997,"metrics":{"e2e_tps":1.197694,"spec_sec":26.718,"gen_tokens":32,"steps":25,"committed":32,"step_ms":212.66,"pack_ms":1.04,"draft_ms":158.08,"topk_ms":21.8,"exact_ms":31.7,"acceptance":1.28},"status":"discard","description":"compute draft lm_head only for non-root output rows","timestamp":1777559673944,"segment":2,"confidence":7.637615029697407,"asi":{"hypothesis":"applying out_ids before the draft lm_head and omitting the unused root logits row should reduce draft output projection work","rollback_reason":"correctness passed but proposal alignment broke: acceptance fell to 1.28 and steps rose to 25; draft_ms also increased to 158ms","next_action_hint":"the dflash draft logits rows are position-sensitive; do not compact rows before lm_head without verifying row/order semantics against standalone DFlash"}} +{"run":170,"commit":"de24439","metric":10.329045,"metrics":{"e2e_tps":1.302613,"spec_sec":24.566,"gen_tokens":32,"steps":14,"committed":33,"step_ms":221.29,"pack_ms":1.11,"draft_ms":142,"topk_ms":20.07,"exact_ms":58.08,"acceptance":2.357},"status":"discard","description":"allow graph reuse for single-sequence DFlash draft ubatch view","timestamp":1777559819914,"segment":2,"confidence":7.711205551462638,"asi":{"hypothesis":"DFlash draft graph reuse may be blocked by non-owning single-sequence ubatch views; allowing reuse could enable CUDA graph replay and reduce draft decode overhead","rollback_reason":"profile still showed reused=0 on every draft ubatch, so this was not the reuse blocker; TPS 10.33 was below kept 10.59 and within noise","next_action_hint":"if pursuing graph reuse, instrument llm_graph_params::allow_reuse to print which field differs rather than guessing"}} +{"run":171,"commit":"de24439","metric":9.832298,"metrics":{"e2e_tps":1.296229,"spec_sec":24.687,"gen_tokens":32,"steps":14,"committed":33,"step_ms":232.47,"pack_ms":1.13,"draft_ms":152.38,"topk_ms":20.99,"exact_ms":57.94,"acceptance":2.357},"status":"discard","description":"force DFlash draft graph params reuse to test CUDA graph replay hypothesis","timestamp":1777559990277,"segment":2,"confidence":7.637615029697407,"asi":{"hypothesis":"if llm_graph_params comparison is the blocker, forcing DFlash draft allow_reuse true should make reused=1 and reduce draft compute overhead","rollback_reason":"profile still showed reused=0 for every draft ubatch, so the graph is reset before can_reuse matters, likely by the DFlash recurrent memory update path; TPS was worse","next_action_hint":"graph reuse optimization would need to avoid gf_res_prev reset in the memory update path or use a separate persistent draft graph; not a simple params comparison fix"}} +{"run":172,"commit":"de24439","metric":9.756752,"metrics":{"e2e_tps":1.294656,"spec_sec":24.717,"gen_tokens":32,"steps":14,"committed":33,"step_ms":234.27,"pack_ms":1.13,"draft_ms":154.23,"topk_ms":20.78,"exact_ms":58.09,"acceptance":2.357},"status":"discard","description":"skip gf_res_prev reset for DFlash draft memory update","timestamp":1777560098862,"segment":2,"confidence":7.565415826302686,"asi":{"hypothesis":"the DFlash recurrent memory update resets gf_res_prev before every draft decode; skipping that reset may permit graph/CUDA graph reuse and reduce draft decode overhead","rollback_reason":"profile still showed reused=0 for every draft step and TPS fell to 9.76; the reset was not the only blocker or reuse is disabled elsewhere","next_action_hint":"avoid more graph-reuse guesses; instrument process_ubatch/can_reuse if this path is revisited"}} +{"run":173,"commit":"de24439","metric":9.997438,"metrics":{"e2e_tps":1.300232,"spec_sec":24.611,"gen_tokens":32,"steps":14,"committed":33,"step_ms":228.63,"pack_ms":1.12,"draft_ms":149.42,"topk_ms":20.16,"exact_ms":57.9,"acceptance":2.357},"status":"discard","description":"draft context n_batch/n_ubatch 16 exactly matching DFlash block size","timestamp":1777560191687,"segment":2,"confidence":8.276184968104477,"asi":{"hypothesis":"setting draft n_batch/n_ubatch to the exact 16-token block size may reduce draft graph reserve/buffer overhead and improve draft decode","rollback_reason":"correctness passed but draft_ms 149ms and TPS 10.00 were below the kept best; smaller draft batch did not improve compute","next_action_hint":"keep draft n_batch cap at 64; batch-size memory savings are not a decode-speed path here"}} +{"run":174,"commit":"de24439","metric":10.155571,"metrics":{"e2e_tps":1.315249,"spec_sec":24.33,"gen_tokens":32,"steps":14,"committed":34,"step_ms":225.07,"pack_ms":1.12,"draft_ms":145.05,"topk_ms":19.36,"exact_ms":59.5,"acceptance":2.429},"status":"discard","description":"draft context n_ctx 256 to reduce draft buffer footprint","timestamp":1777560274649,"segment":2,"confidence":8.845601423014735,"asi":{"hypothesis":"shrinking draft n_ctx to 256 may reduce scheduler/compute-buffer overhead for the DFlash draft graph while retaining correctness","rollback_reason":"correctness passed and acceptance rose to 2.43, but TPS 10.16 stayed below kept 10.59; draft_ms 145ms is only within noise of baseline","next_action_hint":"draft n_ctx can be very small without breaking correctness, but it is not a clear speedup; if memory footprint matters, retest n_ctx=256 separately for server/fast-rollback fit"}} +{"run":175,"commit":"de24439","metric":10.232403,"metrics":{"e2e_tps":1.301766,"spec_sec":24.582,"gen_tokens":32,"steps":14,"committed":33,"step_ms":223.38,"pack_ms":1.16,"draft_ms":144.6,"topk_ms":19.79,"exact_ms":57.8,"acceptance":2.357},"status":"discard","description":"block_size 20 to test larger DFlash draft block without acceptance loss","timestamp":1777560437231,"segment":2,"confidence":8.45678600725038,"asi":{"hypothesis":"a larger draft block may expose more candidate positions for the same budget and improve acceptance without substantially increasing draft decode time","rollback_reason":"block_size 20 preserved the same 14-step acceptance pattern and did not beat the kept 10.59 TPS; no improvement despite draft_ms being in a good noise band","next_action_hint":"block size changes mostly do not affect draft_ms or acceptance on this prompt; reset block_size default and focus elsewhere"}} +{"run":176,"commit":"de24439","metric":10.295083,"metrics":{"e2e_tps":1.303144,"spec_sec":24.556,"gen_tokens":32,"steps":14,"committed":33,"step_ms":222.02,"pack_ms":1.11,"draft_ms":142.8,"topk_ms":20.29,"exact_ms":57.78,"acceptance":2.357},"status":"discard","description":"block_size 32 with budget 32 under pure best-first K=4","timestamp":1777560542708,"segment":2,"confidence":8.10071276246994,"asi":{"hypothesis":"larger DFlash block_size 32 may let the same budget explore deeper draft positions and improve acceptance while draft compute remains mostly fixed-overhead","rollback_reason":"acceptance stayed identical at 2.357 and TPS 10.30 remained below kept 10.59; larger block did not change the accepted path","next_action_hint":"block_size can vary widely without affecting this prompt's accepted path; no need for more block_size sweeps unless cross-prompt behavior differs"}} +{"run":177,"commit":"de24439","metric":9.860292,"metrics":{"e2e_tps":1.307831,"spec_sec":24.468,"gen_tokens":32,"steps":14,"committed":34,"step_ms":231.81,"pack_ms":1.09,"draft_ms":150.43,"topk_ms":20.62,"exact_ms":59.64,"acceptance":2.429},"status":"discard","description":"draft n_ctx 256 combined with block_size 32","timestamp":1777560663130,"segment":2,"confidence":8.45678600725038,"asi":{"hypothesis":"combining small draft n_ctx with a larger block may keep the higher 2.43 acceptance while reducing or amortizing draft decode overhead","rollback_reason":"acceptance stayed at 2.43 but draft_ms rose to 150ms and TPS fell to 9.86; the combination is worse than either isolated noisy candidate","next_action_hint":"do not combine n_ctx=256 and larger block_size for speed; reset block_size default before continuing"}} +{"run":178,"commit":"de24439","metric":6.159437,"metrics":{"e2e_tps":1.203188,"spec_sec":26.596,"gen_tokens":32,"steps":24,"committed":33,"step_ms":216.47,"pack_ms":1.03,"draft_ms":160.45,"topk_ms":21.04,"exact_ms":33.93,"acceptance":1.375},"status":"discard","description":"skip target K/V attention for last two DFlash draft layers","timestamp":1777560789901,"segment":2,"confidence":7.884112887310015,"asi":{"hypothesis":"dropping target-feature K/V matmuls in later draft layers may reduce draft compute while leaving enough proposal quality for exact validation","rollback_reason":"proposal quality collapsed: acceptance 1.38, 24 steps, and draft_ms increased to 160ms; the altered graph likely lost CUDA efficiency and model accuracy","next_action_hint":"do not prune target K/V attention inside trained DFlash layers; accuracy loss overwhelms any theoretical compute saving"}} +{"run":179,"commit":"de24439","metric":9.67981,"metrics":{"e2e_tps":1.291156,"spec_sec":24.784,"gen_tokens":32,"steps":15,"committed":32,"step_ms":220.39,"pack_ms":1.1,"draft_ms":146.41,"topk_ms":20.39,"exact_ms":52.45,"acceptance":2.133},"status":"discard","description":"use only last 32 target_feat positions for last two DFlash draft layers","timestamp":1777560915280,"segment":2,"confidence":8.155576462101964,"asi":{"hypothesis":"later DFlash draft layers may only need a short recent target-feature window; limiting them to 32 positions could reduce K/V matmul and attention compute with smaller accuracy loss than dropping target K/V entirely","rollback_reason":"draft_ms did not improve meaningfully and acceptance fell to 2.13, adding an extra step; TPS 9.68 is below baseline","next_action_hint":"per-layer target window truncation hurts proposal quality before it produces measurable draft speedup; avoid this approximation"}} +{"run":180,"commit":"de24439","metric":8.819071,"metrics":{"e2e_tps":1.275714,"spec_sec":25.084,"gen_tokens":32,"steps":15,"committed":33,"step_ms":241.9,"pack_ms":1.12,"draft_ms":162.84,"topk_ms":23.27,"exact_ms":54.64,"acceptance":2.2},"status":"discard","description":"use last 64 target_feat positions only for final DFlash draft layer","timestamp":1777561023846,"segment":2,"confidence":7.565415826302686,"asi":{"hypothesis":"truncating only the final draft layer to a 64-position target window may preserve most quality while reducing one layer's K/V and attention cost","rollback_reason":"acceptance dropped to 2.20 and draft_ms worsened to 163ms; the graph variant is both less accurate and slower","next_action_hint":"stop per-layer target window truncation experiments; the trained graph wants the full 128-window layout"}} +{"run":181,"commit":"de24439","metric":9.513107,"metrics":{"e2e_tps":1.266624,"spec_sec":25.264,"gen_tokens":32,"steps":14,"committed":33,"step_ms":240.27,"pack_ms":1.11,"draft_ms":158,"topk_ms":22.79,"exact_ms":58.34,"acceptance":2.357},"status":"discard","description":"disable flash_attn flag only for DFlash draft context","timestamp":1777561140520,"segment":2,"confidence":7.420714849615349,"asi":{"hypothesis":"for the small DFlash draft attention shape, disabling the draft context flash-attention flag might choose a lower-overhead kernel","rollback_reason":"draft_ms worsened to 158ms and TPS fell to 9.51; flash attention should stay enabled for the draft context","next_action_hint":"do not disable draft flash attention; attention kernel changes need a targeted custom kernel, not the generic non-flash path"}} +{"run":182,"commit":"de24439","metric":10.198618,"metrics":{"e2e_tps":1.300654,"spec_sec":24.603,"gen_tokens":32,"steps":14,"committed":33,"step_ms":224.12,"pack_ms":1.12,"draft_ms":145.1,"topk_ms":19.86,"exact_ms":58,"acceptance":2.357},"status":"discard","description":"disable detailed DFlash draft ubatch profiling logs in benchmark run","timestamp":1777561240315,"segment":2,"confidence":7.406063452435399,"asi":{"hypothesis":"LLAMA_DDTREE_PROFILE logging inside every draft llama_decode may add measurable draft overhead; disabling it should improve draft_ms without losing driver metrics","rollback_reason":"TPS 10.20 and draft_ms 145ms did not beat the kept 10.59 run; any logging overhead is smaller than run-to-run noise","next_action_hint":"profile logs are useful for draft diagnostics; keep the benchmark default profiling unless doing production latency measurements"}} +{"run":183,"commit":"de24439","metric":10.448024,"metrics":{"e2e_tps":1.300126,"spec_sec":24.613,"gen_tokens":32,"steps":14,"committed":33,"step_ms":218.77,"pack_ms":1.11,"draft_ms":139.63,"topk_ms":20.23,"exact_ms":57.77,"acceptance":2.357},"status":"discard","description":"truly omit LLAMA_DDTREE_PROFILE env to remove draft ubatch logging","timestamp":1777561340487,"segment":2,"confidence":7.5178283267860095,"asi":{"hypothesis":"the script was still setting an empty LLAMA_DDTREE_PROFILE env var, so profiling remained enabled; conditionally omitting it should reduce draft logging overhead","rollback_reason":"draft_ms improved to 139.6ms but TPS 10.45 still did not beat kept best 10.59; as a benchmark-control change it should not be kept unless it clearly improves primary","next_action_hint":"profile omission is a useful production-latency check; reset the benchmark to profiling-on for diagnostics or repeat no-profile if the target metric should exclude profiler overhead"}} +{"run":184,"commit":"de24439","metric":8.95165,"metrics":{"e2e_tps":1.280307,"spec_sec":24.994,"gen_tokens":32,"steps":17,"committed":33,"step_ms":210.28,"pack_ms":1.09,"draft_ms":146.42,"topk_ms":15.12,"exact_ms":47.62,"acceptance":1.941},"status":"discard","description":"skip logsumexp in top-k extraction using raw scaled logits","timestamp":1777561476120,"segment":2,"confidence":7.333912639063997,"asi":{"hypothesis":"removing per-vocab exp/logsumexp from draft top-k should cut topk_ms while exact validation preserves output correctness","rollback_reason":"topk_ms improved from about 20ms to 15ms, but proposal scoring changed enough to drop acceptance to 1.94 and require 17 steps; TPS fell to 8.95","next_action_hint":"try a cheaper row-relative score or approximate length penalty if pursuing top-k CPU cost, but raw logits alone distort tree ranking"}} +{"run":185,"commit":"de24439","metric":9.468352,"metrics":{"e2e_tps":1.284058,"spec_sec":24.921,"gen_tokens":32,"steps":16,"committed":33,"step_ms":211.23,"pack_ms":1.09,"draft_ms":144.33,"topk_ms":15.24,"exact_ms":50.53,"acceptance":2.062},"status":"discard","description":"top-k extraction with row-relative scores instead of logsumexp","timestamp":1777561568397,"segment":2,"confidence":7.124543556285478,"asi":{"hypothesis":"row-relative top-k scores may preserve within-row candidate ordering while removing expensive logsumexp, recovering acceptance compared with raw logits","rollback_reason":"topk_ms stayed low at 15.2ms but acceptance only recovered to 2.06, still requiring 16 steps; TPS 9.47 below best","next_action_hint":"the original logsumexp/row entropy term is important for tree depth ranking; any cheaper approximation needs an explicit depth/entropy penalty"}} +{"run":186,"commit":"de24439","metric":9.291117,"metrics":{"e2e_tps":1.260587,"spec_sec":25.385,"gen_tokens":32,"steps":15,"committed":34,"step_ms":229.61,"pack_ms":1.11,"draft_ms":156.59,"topk_ms":16.14,"exact_ms":55.73,"acceptance":2.267},"status":"discard","description":"top-k-only approximate logsumexp for proposal scores","timestamp":1777561666340,"segment":2,"confidence":6.926796845424066,"asi":{"hypothesis":"approximating logsumexp from only retained top-K logits may keep the row uncertainty penalty while avoiding full-vocab exp cost","rollback_reason":"topk_ms improved to 16ms, but acceptance dropped to 2.27 and draft_ms noise was high; TPS 9.29 is below best","next_action_hint":"top-K-only normalization is not close enough; the full-vocab tail mass affects DDTree depth ranking on this prompt"}} +{"run":187,"commit":"de24439","metric":10.548801,"metrics":{"e2e_tps":1.301766,"spec_sec":24.582,"gen_tokens":32,"steps":14,"committed":33,"step_ms":216.68,"pack_ms":1.12,"draft_ms":142.9,"topk_ms":14.74,"exact_ms":57.88,"acceptance":2.357},"status":"discard","description":"top-K-only normalization with larger budget 40 to recover acceptance","timestamp":1777561796176,"segment":2,"confidence":6.8563876878805425,"asi":{"hypothesis":"the cheap top-K-only logsumexp approximation saves about 5ms topk; increasing budget to 40 may recover the 14-step accepted path and beat the baseline","rollback_reason":"promising but TPS 10.55 was still just below kept 10.59, so not a primary improvement by the rules","next_action_hint":"repeat this exact combo (top-K-only normalization + budget 40) because it recovered acceptance and cut topk_ms to 14.7; a repeat may exceed the noisy best"}} diff --git a/autoresearch.md b/autoresearch.md new file mode 100644 index 000000000000..932b1a2489a3 --- /dev/null +++ b/autoresearch.md @@ -0,0 +1,62 @@ +# Autoresearch: DDTree DFlash throughput on Castle + +## Objective +Improve DFlash + DDTree decode throughput for the Qwen3.5-27B llama.cpp server port on Castle without sacrificing greedy bit-equal correctness in the e2e harness. The current user-visible TPS is far below the standalone DFlash baseline and below the target-only llama.cpp server reference. + +## Metrics +- **Primary**: `tps` (tok/s, higher is better) — decode-only speculative throughput, computed as `generated_tokens / (steps * step_ms)` from the harness timing breakdown. +- **Secondary**: `e2e_tps`, `step_ms`, `draft_ms`, `exact_decode_ms`, `pack_ms`, `topk_ms`, `acceptance`, `gen_tokens`, `spec_sec`. + +## How to Run +`./autoresearch.sh` + +The script syncs the local working tree source files to Castle, builds `test-speculative-tree-e2e` and `llama-server`, runs the Castle CUDA e2e benchmark, and prints `METRIC` lines. + +## Files in Scope +- `common/speculative-tree-driver.cpp` / `.h` — DDTree proposal, target feature ring, exact validation, timing counters. +- `src/llama-context.cpp` / `.h` — decode path, DFlash persist/rollback, draft profiling. +- `src/llama-graph.cpp` — graph inputs and draft positional behavior. +- `src/models/dflash-draft.cpp` — DFlash draft model graph/input integration. +- `tools/server/server-context.cpp` — DDTree server integration and prompt cache rebuild policy. +- `tests/test-speculative-tree-e2e.cpp` — benchmark/correctness harness and metrics output. +- `ggml/src/ggml-cuda/*` and relevant `ggml/src/ggml-cpu/*` files — only for narrowly scoped kernels needed by DDTree/DFlash. + +## Off Limits +- Do not change model files, benchmark prompts, or expected outputs. +- Do not relax correctness gates or remove bit-equal validation. +- Do not overfit to a single prompt by hard-coding token IDs, prompt lengths, or outputs. +- Do not touch unrelated llama.cpp features. + +## Constraints +- Benchmark must pass e2e bit-equal greedy correctness. +- Keep `--require-full-prompt-ingest`; use `--require-ddtree` only when the tested mode should run batched tree verify. +- No new external dependencies. +- Castle is the source of performance truth; local CPU/Metal builds are not enough for keeps. +- If a Castle server is occupying the GPU, benchmark runs may stop it to free VRAM for e2e testing. + +## Current Baseline Context +Status from `DDTREE_STATUS_2026-04-29.md`: +- Stable server config uses `LLAMA_DDTREE_TARGET_FEAT_CTX=1024`, `-ngl65 -ngld6 -c65536`, q4 KV, chain-only exact validation. +- Real task sample after 1024-window server: API TPS 6.38 tok/s, wall TPS 5.92 tok/s. +- Server logs imply raw DDTree decode around 7.7 tok/s with `exact_avg_commit ~= 3.5`, `step ~= 456 ms`. +- Main remaining costs: exact target 1-token decode and draft decode. Recent timing split showed `exact_decode` dominates exact validation. +- Fast rollback helps only when persist fits with full target offload; at 64k/full-draft on 24GB it currently OOMs, and reducing target offload loses the gain. + +## What's Been Tried +- Shared draft lm_head with target output weight: correct and frees about 1 GiB duplicate GPU allocation. +- Full draft GPU offload (`-ngld6`) after avoiding unused persist allocation: correct and substantially faster draft compute. +- Server prompt cache/checkpoint restored with DDTree rebuild window: repeated long prompts now reuse cache and rebuild only last target-feature window. +- `LLAMA_DDTREE_TARGET_FEAT_CTX=1024`: kept previously; reduced pack/draft costs and improved repeated-request latency. +- 2026-04-30 autoresearch found a stronger default: target feature window 128. On the 20,401-token rendered prompt, decode TPS improved from 6.42 to a best run of 9.11 tok/s. 64/96/192 were worse; 160 tied but did not beat 128. This is promising but still needs broader prompt validation. +- Removing default batched tree verify from exact correctness path: kept; avoids diagnostic target-tree/snapshot overhead when exact chain validation is final authority. +- Exact batched spine: discarded; snapshot/restore/replay cost made it slower. +- Skipping logsumexp in top-k scores: discarded; top-k dominated by vocab scan/heap, not score normalization. +- Fast rollback at 64k/full-draft: not viable on current 24GB GPU because persist allocation needs about 1.7 GiB extra VRAM. Retested after 128-window and with tq3/less draft offload: full/`ngld4` still OOM; `ngld3` fits but is slower because draft compute jumps to ~294 ms. +- Budget/K tuning: budget 14 lowered top-k cost but did not beat the best 128-window run; budget 8 was worse. +- Small CPU optimizations discarded: fixed-array top-k and cached mask embedding did not improve primary TPS. + +## Next Experiment Directions +- Lower target 1-token exact decode cost without changing outputs. +- Reduce draft decode graph compute or mixed backend overhead. +- Explore safe conditional batched acceptance only if correctness evidence supports it. +- Move or compress target feature data to reduce pack/upload cost after larger bottlenecks are addressed. diff --git a/autoresearch.sh b/autoresearch.sh new file mode 100755 index 000000000000..db0485e11df0 --- /dev/null +++ b/autoresearch.sh @@ -0,0 +1,201 @@ +#!/usr/bin/env bash +set -euo pipefail + +REMOTE=castle.local +REMOTE_DIR=/home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp +TARGET_MODEL=/home/leechael/workshop/lucebox-hub/dflash/models/Qwen3.5-27B-Q4_K_M.gguf +DRAFT_MODEL=/home/leechael/workshop/lucebox-hub/dflash/models/draft/model.gguf +PROMPT_TEXT=${AUTORESEARCH_PROMPT:-/tmp/real_rendered_prompt.txt} +GEN=${AUTORESEARCH_GEN:-32} +CTX=${AUTORESEARCH_CTX:-65536} +KV_TYPE=${AUTORESEARCH_KV_TYPE:-q4_0} +DRAFT_GPU_LAYERS=${AUTORESEARCH_DRAFT_GPU_LAYERS:-6} +N_BATCH=${AUTORESEARCH_N_BATCH:-512} +N_UBATCH=${AUTORESEARCH_N_UBATCH:-512} +BUDGET=${AUTORESEARCH_BUDGET:-40} +PROFILE=${LLAMA_DDTREE_PROFILE:-1} +BLOCK_SIZE=${LLAMA_DDTREE_BLOCK_SIZE:-} +TARGET_FEAT_CTX=${LLAMA_DDTREE_TARGET_FEAT_CTX:-128} +FAST_BATCHED=${LLAMA_DDTREE_FAST_BATCHED:-} +FAST_ROLLBACK=${LLAMA_DDTREE_FAST_ROLLBACK:-} +SNAPSHOT_FALLBACK=${LLAMA_DDTREE_SNAPSHOT_FALLBACK:-} +FORCE_CHAIN=${LLAMA_DDTREE_FORCE_CHAIN_KERNEL:-} +SKIP_EXACT_SEQ_RM=${LLAMA_DDTREE_SKIP_EXACT_SEQ_RM:-} +TOP_K=${LLAMA_DDTREE_TOP_K:-4} +TREE_ROWS=${LLAMA_DDTREE_TREE_ROWS:-} +CHAIN_SEED=${LLAMA_DDTREE_CHAIN_SEED:-0} +CHAIN_DEPTH_CAP=${LLAMA_DDTREE_CHAIN_DEPTH_CAP:-} +PROPOSAL_TEMP=${LLAMA_DDTREE_PROPOSAL_TEMP:-0.7} +TRACE=${LLAMA_DDTREE_TRACE:-} +CHAIN_CAPTURE=${LLAMA_DDTREE_CHAIN_CAPTURE:-} +CHAIN_SEQ_RM=${LLAMA_DDTREE_CHAIN_SEQ_RM:-} +NO_FLASH_ARG=${AUTORESEARCH_NO_FLASH_ATTN:+--no-flash-attn} +AGENT_TPS_GATE=${AUTORESEARCH_AGENT_TPS_GATE:-required} +AGENT_TPS_LOG=${AUTORESEARCH_AGENT_TPS_LOG:-} +AGENT_TPS_MIN_API=${AUTORESEARCH_AGENT_TPS_MIN_API:-} +AGENT_TPS_MIN_WALL=${AUTORESEARCH_AGENT_TPS_MIN_WALL:-} + +# Sync only source/control files needed for the benchmark. Avoid .git and build dirs. +rsync -az --delete \ + --exclude build --exclude build-server --exclude build-cpu --exclude .git \ + common include src tests tools ggml CMakeLists.txt cmake \ + "$REMOTE:$REMOTE_DIR/" >/dev/null + +ssh "$REMOTE" "pgrep -f '[b]uild-server/bin/llama-server' | xargs -r kill" + +ssh "$REMOTE" "cd '$REMOTE_DIR' && cmake --build build-server -j 16 --target test-speculative-tree-e2e llama-server" >/tmp/autoresearch_build.log 2>&1 || { + tail -80 /tmp/autoresearch_build.log + exit 1 +} + +out_file=$(mktemp /tmp/autoresearch_ddtree.XXXXXX) +ssh "$REMOTE" "cd '$REMOTE_DIR' && \ + LLAMA_DDTREE_PROFILE='$PROFILE' \ + LLAMA_DDTREE_BLOCK_SIZE='$BLOCK_SIZE' \ + LLAMA_DDTREE_TARGET_FEAT_CTX='$TARGET_FEAT_CTX' \ + LLAMA_DDTREE_FAST_BATCHED='$FAST_BATCHED' \ + LLAMA_DDTREE_FAST_ROLLBACK='$FAST_ROLLBACK' \ + LLAMA_DDTREE_SNAPSHOT_FALLBACK='$SNAPSHOT_FALLBACK' \ + LLAMA_DDTREE_FORCE_CHAIN_KERNEL='$FORCE_CHAIN' \ + LLAMA_DDTREE_SKIP_EXACT_SEQ_RM='$SKIP_EXACT_SEQ_RM' \ + LLAMA_DDTREE_TOP_K='$TOP_K' \ + LLAMA_DDTREE_TREE_ROWS='$TREE_ROWS' \ + LLAMA_DDTREE_CHAIN_SEED='$CHAIN_SEED' \ + LLAMA_DDTREE_CHAIN_DEPTH_CAP='$CHAIN_DEPTH_CAP' \ + LLAMA_DDTREE_PROPOSAL_TEMP='$PROPOSAL_TEMP' \ + ${TRACE:+LLAMA_DDTREE_TRACE='$TRACE'} \ + ${CHAIN_CAPTURE:+LLAMA_DDTREE_CHAIN_CAPTURE='$CHAIN_CAPTURE'} \ + ${CHAIN_SEQ_RM:+LLAMA_DDTREE_CHAIN_SEQ_RM='$CHAIN_SEQ_RM'} \ + ./build-server/bin/test-speculative-tree-e2e \ + --target-model '$TARGET_MODEL' \ + --draft-model '$DRAFT_MODEL' \ + --prompt-text '$PROMPT_TEXT' \ + --gen '$GEN' \ + --out-spec /tmp/autoresearch_spec.bin \ + --out-chain /tmp/autoresearch_chain.bin \ + --ddtree-budget '$BUDGET' \ + ${TOP_K:+--ddtree-top-k '$TOP_K'} \ + --require-full-prompt-ingest \ + --temp 0 \ + --n-gpu-layers ${AUTORESEARCH_N_GPU_LAYERS:-65} \ + --draft-gpu-layers '$DRAFT_GPU_LAYERS' \ + --n-ctx '$CTX' \ + --n-batch '$N_BATCH' \ + --n-ubatch '$N_UBATCH' \ + --kv-type '$KV_TYPE' \ + $NO_FLASH_ARG" >"$out_file" 2>&1 || { + tail -120 "$out_file" + exit 1 +} + +cat "$out_file" | tail -220 +grep -E 'chain timing detail|chain timing:' "$out_file" || true + +python3 - "$out_file" <<'PY' +import re, sys +text = open(sys.argv[1], 'r', errors='replace').read() + +def last_float(pattern, default=0.0): + vals = re.findall(pattern, text) + return float(vals[-1]) if vals else default + +def last_int(pattern, default=0): + vals = re.findall(pattern, text) + return int(vals[-1]) if vals else default + +spec_sec = last_float(r"spec timing:\s*([0-9.]+)\s*sec") +gen_tokens = last_int(r"spec:\s*generated\s+(\d+)\s+tokens") +# Committed can be > requested generation because one speculative step may validate beyond the requested output. +steps = last_int(r"steps=(\d+)") +committed = last_int(r"(?:^|\s)committed=(\d+)") +step_ms = last_float(r"spec timing avg:.*?step=([0-9.]+)") +pack_ms = last_float(r"spec timing avg:.*?pack=([0-9.]+)") +draft_ms = last_float(r"spec timing avg:.*?draft=([0-9.]+)") +topk_ms = last_float(r"spec timing avg:.*?topk=([0-9.]+)") +exact_ms = last_float(r"spec timing avg:.*?exact=([0-9.]+)") +exact_decode_ms = last_float(r"spec timing avg:.*?exact_decode=([0-9.]+)") +acceptance = last_float(r"exact_avg_commit_per_step=([0-9.]+)") +if not spec_sec or not gen_tokens or not step_ms or not steps: + print("Failed to parse spec timing/generated tokens/decode step timing", file=sys.stderr) + sys.exit(2) +e2e_tps = gen_tokens / spec_sec +decode_tps = gen_tokens / (steps * step_ms / 1000.0) +print(f"METRIC tps={decode_tps:.6f}") +print(f"METRIC e2e_tps={e2e_tps:.6f}") +print(f"METRIC spec_sec={spec_sec:.6f}") +print(f"METRIC gen_tokens={gen_tokens}") +print(f"METRIC steps={steps}") +print(f"METRIC committed={committed}") +print(f"METRIC step_ms={step_ms:.6f}") +print(f"METRIC pack_ms={pack_ms:.6f}") +print(f"METRIC draft_ms={draft_ms:.6f}") +print(f"METRIC topk_ms={topk_ms:.6f}") +print(f"METRIC exact_ms={exact_ms:.6f}") +print(f"METRIC exact_decode_ms={exact_decode_ms:.6f}") +print(f"METRIC acceptance={acceptance:.6f}") +PY + +if [[ "$AGENT_TPS_GATE" != "off" ]]; then + if [[ -z "$AGENT_TPS_LOG" || ! -f "$AGENT_TPS_LOG" ]]; then + echo "agent TPS gate failed: set AUTORESEARCH_AGENT_TPS_LOG=/path/to/agent-run.log, or AUTORESEARCH_AGENT_TPS_GATE=off for compile-only runs" >&2 + rm -f "$out_file" + exit 3 + fi + python3 - "$AGENT_TPS_LOG" "$AGENT_TPS_MIN_API" "$AGENT_TPS_MIN_WALL" <<'PY' +import re, sys +path, min_api_s, min_wall_s = sys.argv[1:4] +text = open(path, 'r', errors='replace').read() + +def f(pattern, required=True, default=0.0): + vals = re.findall(pattern, text, re.I | re.S) + if vals: + v = vals[-1] + if isinstance(v, tuple): + v = v[0] + return float(str(v).replace(',', '')) + if required: + print(f"agent TPS gate failed: missing pattern {pattern!r}", file=sys.stderr) + sys.exit(4) + return default + +def i(pattern, required=False, default=0): + vals = re.findall(pattern, text, re.I | re.S) + if vals: + v = vals[-1] + if isinstance(v, tuple): + v = v[0] + return int(str(v).replace(',', '')) + if required: + print(f"agent TPS gate failed: missing pattern {pattern!r}", file=sys.stderr) + sys.exit(4) + return default + +m = re.findall(r"(?:TASK\s+)?TPS\s+([0-9.]+)\s*tok/s\s*API,\s*([0-9.]+)\s*tok/s\s*wall", text, re.I) +if not m: + m = re.findall(r"agent[_ ]TPS:\s*([0-9.]+)\s*tok/s\s*API,\s*([0-9.]+)\s*tok/s\s*wall", text, re.I) +if not m: + print("agent TPS gate failed: missing 'TPS tok/s API, tok/s wall'", file=sys.stderr) + sys.exit(4) +api_tps, wall_tps = map(float, m[-1]) +requests = i(r"(?:req|requests)\s*[:=]?\s*([0-9,]+)") +output_tokens = i(r"(?:out|output(?:_tokens)?)\s*[:=]?\s*([0-9,]+)") +input_tokens = i(r"(?:in|input(?:_tokens)?)\s*[:=]?\s*([0-9,]+)") +api_sec = f(r"api\s*[:=]?\s*([0-9.]+)s", required=False) +wall_sec = f(r"wall\s*[:=]?\s*([0-9.]+)s", required=False) +if min_api_s and api_tps < float(min_api_s): + print(f"agent TPS gate failed: api_tps {api_tps:.6f} < {float(min_api_s):.6f}", file=sys.stderr) + sys.exit(5) +if min_wall_s and wall_tps < float(min_wall_s): + print(f"agent TPS gate failed: wall_tps {wall_tps:.6f} < {float(min_wall_s):.6f}", file=sys.stderr) + sys.exit(5) +print(f"METRIC agent_tps_api={api_tps:.6f}") +print(f"METRIC agent_tps_wall={wall_tps:.6f}") +print(f"METRIC agent_requests={requests}") +print(f"METRIC agent_output_tokens={output_tokens}") +print(f"METRIC agent_input_tokens={input_tokens}") +print(f"METRIC agent_api_sec={api_sec:.6f}") +print(f"METRIC agent_wall_sec={wall_sec:.6f}") +PY +fi + +rm -f "$out_file" diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index b313a7320e56..b9878b34182c 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -90,6 +90,12 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h + speculative-draft-backend.cpp + speculative-draft-backend.h + speculative-tree.cpp + speculative-tree.h + speculative-tree-driver.cpp + speculative-tree-driver.h unicode.cpp unicode.h jinja/lexer.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 3d0183ed7026..c6acebc36f62 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -390,6 +390,7 @@ const std::vector kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_TQ3_0, }; static ggml_type kv_cache_type_from_str(const std::string & s) { @@ -600,6 +601,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --model is required\n"); } + // DDTree mode requires a draft model + if (params.speculative.ddtree_mode && !params.speculative.has_dft()) { + throw std::invalid_argument("error: --speculative-mode ddtree requires -md/--model-draft\n"); + } + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); @@ -3553,6 +3559,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_min_hits = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--speculative-mode"}, "[chain|ddtree]", + "speculative decoding mode: 'chain' = standard draft-model chain (default), " + "'ddtree' = DDTree dflash-draft speculative decoding (requires -md)", + [](common_params & params, const std::string & value) { + if (value == "chain") { + params.speculative.ddtree_mode = false; + } else if (value == "ddtree") { + params.speculative.ddtree_mode = true; + } else { + throw std::invalid_argument("--speculative-mode must be 'chain' or 'ddtree'"); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPECULATIVE_MODE")); + add_opt(common_arg( + {"--ddtree-budget"}, "N", + string_format("DDTree: total tree node budget per spec step (default: %d)", params.speculative.ddtree_budget), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("--ddtree-budget must be >= 1"); + } + params.speculative.ddtree_budget = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_BUDGET")); + add_opt(common_arg( + {"--ddtree-temp"}, "F", + string_format("DDTree: temperature for draft log-prob extraction (default: %.1f)", (double)params.speculative.ddtree_temp), + [](common_params & params, const std::string & value) { + params.speculative.ddtree_temp = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_TEMP")); + add_opt(common_arg( + {"--ddtree-no-chain-seed"}, + "DDTree: disable chain-seed greedy initialization (enabled by default)", + [](common_params & params) { + params.speculative.ddtree_chain_seed = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 020b6a721ff9..d0287865ca35 100644 --- a/common/common.h +++ b/common/common.h @@ -355,6 +355,12 @@ struct common_params_speculative { bool has_dft() const { return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); } + + // DDTree speculative decoding parameters (Phase 5) + bool ddtree_mode = false; // true when --speculative-mode ddtree is set + int32_t ddtree_budget = 22; // tree node budget (matches dflash default) + float ddtree_temp = 1.0f; // temperature for draft log-prob extraction + bool ddtree_chain_seed = true; // seed the tree heap with greedy chain (recommended) }; struct common_params_vocoder { diff --git a/common/speculative-draft-backend.cpp b/common/speculative-draft-backend.cpp new file mode 100644 index 000000000000..0b0d2af95545 --- /dev/null +++ b/common/speculative-draft-backend.cpp @@ -0,0 +1,326 @@ +#include "speculative-draft-backend.h" + +#include "log.h" + +#include +#include +#include +#include +#include +#include + +using ddtree_draft_clock = std::chrono::steady_clock; + +static double draft_elapsed_ms(ddtree_draft_clock::time_point t0) { + return std::chrono::duration(ddtree_draft_clock::now() - t0).count(); +} + +int llama_speculative_draft_top_k_width(int block_size, const llama_ddtree_params & params) { + const int L = block_size - 1; + return (params.top_k > 0) ? params.top_k : ((params.budget > L) ? 8 : 1); +} + +bool llama_speculative_draft_pack_target_feat(const llama_speculative_draft_target_feat_view & view, + std::vector & out, + int64_t & ctx_len) { + ctx_len = 0; + if (view.ring == nullptr || view.n_committed <= 0 || view.cap <= 0 || view.n_embd_fc <= 0) { + return false; + } + + ctx_len = std::min(view.n_committed, view.cap); + const int64_t ring_start = view.n_committed - ctx_len; + + out.resize((size_t) view.n_embd_fc * ctx_len); + for (int64_t t = 0; t < ctx_len; ++t) { + const int64_t ring_col = (ring_start + t) % view.cap; + const float * src = view.ring + ring_col * view.n_embd_fc; + float * dst = out.data() + t * view.n_embd_fc; + memcpy(dst, src, (size_t) view.n_embd_fc * sizeof(float)); + } + return true; +} + +class llama_speculative_llama_draft_backend final : public llama_speculative_draft_backend { + public: + llama_speculative_llama_draft_backend(llama_context * draft_ctx, + const llama_model * target_model, + int64_t n_embd, + int64_t n_vocab, + int64_t block_size, + llama_token mask_token_id, + const llama_ddtree_params & params) : + draft_ctx(draft_ctx), + target_model(target_model), + n_embd(n_embd), + n_vocab(n_vocab), + block_size(block_size), + mask_token_id(mask_token_id), + params(params) { + mask_embd.resize((size_t) n_embd); + noise_embd.resize((size_t) block_size * n_embd); + pos.resize((size_t) block_size); + n_seq_id.assign((size_t) block_size, 1); + seq_id_values.assign((size_t) block_size, 0); + seq_id_ptrs.resize((size_t) block_size); + logits.assign((size_t) block_size, 1); + for (int64_t i = 0; i < block_size; ++i) { + seq_id_ptrs[(size_t) i] = &seq_id_values[(size_t) i]; + } + } + + bool init() { + if (draft_ctx == nullptr || target_model == nullptr || n_embd <= 0 || n_vocab <= 0 || block_size <= 1) { + return false; + } + if (llama_model_token_embd_lookup(target_model, mask_token_id, mask_embd.data(), n_embd) != 0) { + LOG_ERR("%s: token_embd_lookup failed for mask_token=%d\n", __func__, (int) mask_token_id); + return false; + } + for (int64_t i = 1; i < block_size; ++i) { + memcpy(noise_embd.data() + i * n_embd, mask_embd.data(), (size_t) n_embd * sizeof(float)); + } + llama_set_dflash_draft_top_k(draft_ctx, + std::min(llama_speculative_draft_top_k_width((int) block_size, params), + n_vocab)); + return true; + } + + const char * name() const override { return "dflash-topk"; } + + bool ingest_target_capture(llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap, + double & elapsed_ms) override { + const auto t0 = ddtree_draft_clock::now(); + elapsed_ms = 0.0; + if (target_ctx == nullptr || n_dfs <= 0 || cap <= 0) { + return false; + } + const int ret = llama_dflash_draft_update_fused_cache_from_capture(draft_ctx, target_ctx, dfs_indices, + n_dfs, first_pos, cap); + elapsed_ms = draft_elapsed_ms(t0); + if (ret != 0) { + return false; + } + fused_target_feat_cap = cap; + fused_target_feat_n_embd = n_embd; + fused_target_feat_n_committed = first_pos + n_dfs; + return true; + } + + bool decode_topk(llama_token root_token, + llama_pos committed_pos, + const llama_speculative_draft_target_feat_view & target_feat, + std::vector & top_log_probs, + std::vector & top_token_ids, + llama_speculative_draft_decode_info & info) override { + info = {}; + info.L = (int) block_size - 1; + info.K = std::min(llama_speculative_draft_top_k_width((int) block_size, params), n_vocab); + + if (target_feat.ring == nullptr || target_feat.n_committed <= 0 || target_feat.cap <= 0 || + target_feat.n_embd_fc <= 0 || target_feat.n_embd_fc % 5 != 0) { + LOG_ERR( + "%s: target_feat ring is empty; call llama_speculative_tree_driver_ingest_prompt_capture first\n", + __func__); + return false; + } + + const int64_t fused_n_embd = target_feat.n_embd_fc / 5; + { + const auto t0 = ddtree_draft_clock::now(); + if (!ensure_fused_target_feat(target_feat, fused_n_embd)) { + info.t_draft_decode_ms += draft_elapsed_ms(t0); + return false; + } + info.t_draft_decode_ms += draft_elapsed_ms(t0); + } + + info.ctx_len = std::min(target_feat.n_committed, target_feat.cap); + const int64_t ring_start = target_feat.n_committed - info.ctx_len; + + { + const auto t0 = ddtree_draft_clock::now(); + + if (llama_model_token_embd_lookup(target_model, root_token, noise_embd.data(), n_embd) != 0) { + LOG_ERR("%s: token_embd_lookup failed for root_token=%d\n", __func__, (int) root_token); + info.t_draft_decode_ms += draft_elapsed_ms(t0); + return false; + } + for (int32_t i = 0; i < (int32_t) block_size; ++i) { + pos[(size_t) i] = committed_pos + i; + } + + llama_batch draft_batch{}; + draft_batch.n_tokens = (int32_t) block_size; + draft_batch.token = nullptr; + draft_batch.embd = noise_embd.data(); + draft_batch.pos = pos.data(); + draft_batch.n_seq_id = n_seq_id.data(); + draft_batch.seq_id = seq_id_ptrs.data(); + draft_batch.logits = logits.data(); + draft_batch.parent_id = nullptr; + + const int ret = llama_dflash_draft_encode_top_k_cached(draft_ctx, draft_batch, + fused_n_embd, info.ctx_len, + ring_start, target_feat.cap, + committed_pos, info.K); + if (ret != 0) { + LOG_ERR("%s: dflash draft encode-topK failed: %d\n", __func__, ret); + info.t_draft_decode_ms += draft_elapsed_ms(t0); + return false; + } + info.t_draft_decode_ms += draft_elapsed_ms(t0); + } + + top_log_probs.resize((size_t) info.L * info.K); + top_token_ids.resize((size_t) info.L * info.K); + + { + const auto t0 = ddtree_draft_clock::now(); + + float proposal_temp = params.temp; + if (const char * e = std::getenv("LLAMA_DDTREE_PROPOSAL_TEMP")) { + char * end = nullptr; + const float v = std::strtof(e, &end); + if (end != e && v > 0.0f) { + proposal_temp = v; + } + } + const float inv_t = 1.0f / std::max(1e-6f, proposal_temp); + + const float * draft_top_logits = nullptr; + const llama_token * draft_top_tokens = nullptr; + int32_t top_rows = 0; + int32_t top_k = 0; + if (!llama_get_dflash_draft_top_k(draft_ctx, &draft_top_logits, &draft_top_tokens, &top_rows, &top_k) || + draft_top_logits == nullptr || draft_top_tokens == nullptr || top_rows < (int32_t) block_size || + top_k < info.K) { + LOG_ERR("%s: dflash draft top-K unavailable\n", __func__); + info.t_topk_ms += draft_elapsed_ms(t0); + return false; + } + + struct Entry { + float logit; + llama_token token; + }; + std::vector row_top((size_t) info.K); + + for (int i = 0; i < info.L; ++i) { + const int row_idx = i + 1; + for (int k = 0; k < info.K; ++k) { + row_top[(size_t) k] = { + draft_top_logits[(size_t) row_idx * top_k + k], + draft_top_tokens[(size_t) row_idx * top_k + k], + }; + } + std::sort(row_top.begin(), row_top.end(), [](const Entry & a, const Entry & b) { + return a.logit > b.logit; + }); + + if (info.K == 1) { + top_log_probs[(size_t) i] = 0.0f; + top_token_ids[(size_t) i] = row_top[0].token; + continue; + } + + const float row_best = row_top[0].logit * inv_t; + float sum_exp_top = 0.0f; + for (int k = 0; k < info.K; ++k) { + sum_exp_top += std::exp(row_top[(size_t) k].logit * inv_t - row_best); + } + const float log_z_approx = row_best + std::log(sum_exp_top); + for (int k = 0; k < info.K; ++k) { + top_log_probs[(size_t) i * info.K + k] = row_top[(size_t) k].logit * inv_t - log_z_approx; + top_token_ids[(size_t) i * info.K + k] = row_top[(size_t) k].token; + } + } + info.t_topk_ms += draft_elapsed_ms(t0); + } + + return true; + } + + private: + bool ensure_fused_target_feat(const llama_speculative_draft_target_feat_view & target_feat, + int64_t fused_n_embd) { + const int64_t ctx_len = std::min(target_feat.n_committed, target_feat.cap); + const int64_t ring_start = target_feat.n_committed - ctx_len; + + if (fused_target_feat_cap != target_feat.cap || fused_target_feat_n_embd != fused_n_embd) { + fused_target_feat_cap = target_feat.cap; + fused_target_feat_n_embd = fused_n_embd; + fused_target_feat_n_committed = ring_start; + } + + if (fused_target_feat_n_committed < ring_start || fused_target_feat_n_committed > target_feat.n_committed) { + fused_target_feat_n_committed = ring_start; + } + + const int64_t missing = target_feat.n_committed - fused_target_feat_n_committed; + if (missing <= 0) { + return true; + } + + raw_fuse_buf.resize((size_t) target_feat.n_embd_fc * missing); + for (int64_t t = 0; t < missing; ++t) { + const int64_t logical_col = fused_target_feat_n_committed + t; + const int64_t ring_col = logical_col % target_feat.cap; + const float * src = target_feat.ring + ring_col * target_feat.n_embd_fc; + float * dst = raw_fuse_buf.data() + t * target_feat.n_embd_fc; + memcpy(dst, src, (size_t) target_feat.n_embd_fc * sizeof(float)); + } + + const int ret = llama_dflash_draft_update_fused_cache(draft_ctx, raw_fuse_buf.data(), target_feat.n_embd_fc, + missing, fused_target_feat_n_committed, + target_feat.cap); + if (ret != 0) { + LOG_ERR("%s: dflash target_feat cache update failed: %d\n", __func__, ret); + return false; + } + + fused_target_feat_n_committed = target_feat.n_committed; + return true; + } + + llama_context * draft_ctx = nullptr; + const llama_model * target_model = nullptr; + int64_t n_embd = 0; + int64_t n_vocab = 0; + int64_t block_size = 0; + llama_token mask_token_id = 0; + llama_ddtree_params params; + + std::vector target_feat_buf; + std::vector raw_fuse_buf; + int64_t fused_target_feat_n_committed = 0; + int64_t fused_target_feat_n_embd = 0; + int64_t fused_target_feat_cap = 0; + std::vector mask_embd; + std::vector noise_embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_values; + std::vector seq_id_ptrs; + std::vector logits; +}; + +std::unique_ptr llama_speculative_draft_backend_init_llama( + llama_context * draft_ctx, + const llama_model * target_model, + int64_t n_embd, + int64_t n_vocab, + int64_t block_size, + llama_token mask_token_id, + const llama_ddtree_params & params) { + auto backend = std::make_unique(draft_ctx, target_model, n_embd, n_vocab, + block_size, mask_token_id, params); + if (!backend->init()) { + return nullptr; + } + return backend; +} diff --git a/common/speculative-draft-backend.h b/common/speculative-draft-backend.h new file mode 100644 index 000000000000..dbc732d0d013 --- /dev/null +++ b/common/speculative-draft-backend.h @@ -0,0 +1,69 @@ +#pragma once + +#include "llama.h" +#include "speculative-tree.h" + +#include +#include +#include + +struct llama_speculative_draft_target_feat_view { + const float * ring = nullptr; + int64_t n_committed = 0; + int64_t cap = 0; + int64_t n_embd_fc = 0; +}; + +struct llama_speculative_draft_decode_info { + int L = 0; + int K = 0; + int64_t ctx_len = 0; + + double t_target_feat_pack_ms = 0.0; + double t_draft_decode_ms = 0.0; + double t_topk_ms = 0.0; +}; + +int llama_speculative_draft_top_k_width(int block_size, const llama_ddtree_params & params); + +bool llama_speculative_draft_pack_target_feat(const llama_speculative_draft_target_feat_view & view, + std::vector & out, + int64_t & ctx_len); + +class llama_speculative_draft_backend { + public: + virtual ~llama_speculative_draft_backend() = default; + + virtual const char * name() const = 0; + + virtual bool ingest_target_capture(llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap, + double & elapsed_ms) { + GGML_UNUSED(target_ctx); + GGML_UNUSED(dfs_indices); + GGML_UNUSED(n_dfs); + GGML_UNUSED(first_pos); + GGML_UNUSED(cap); + elapsed_ms = 0.0; + return false; + } + + virtual bool decode_topk(llama_token root_token, + llama_pos committed_pos, + const llama_speculative_draft_target_feat_view & target_feat, + std::vector & top_log_probs, + std::vector & top_token_ids, + llama_speculative_draft_decode_info & info) = 0; +}; + +std::unique_ptr llama_speculative_draft_backend_init_llama( + llama_context * draft_ctx, + const llama_model * target_model, + int64_t n_embd, + int64_t n_vocab, + int64_t block_size, + llama_token mask_token_id, + const llama_ddtree_params & params); diff --git a/common/speculative-tree-driver.cpp b/common/speculative-tree-driver.cpp new file mode 100644 index 000000000000..2df290a84f28 --- /dev/null +++ b/common/speculative-tree-driver.cpp @@ -0,0 +1,942 @@ +// speculative-tree-driver.cpp — Phase 4 DDTree spec-decode step coordinator. +// +// Mirrors the main loop in test_dflash.cpp:1070-1500 using the llama.cpp public API. +// +// Target_feat layout (from qwen35.cpp capture): +// The hidden capture tensor is [n_embd, 5*n_tokens] (column-major ggml / row-major C). +// Layer k's hidden for all decoded positions occupies columns [k*n_total .. (k+1)*n_total). +// For the draft window, we take the most recent ctx_len positions per layer and pack +// them into [5*n_embd, ctx_len]: +// row l*n_embd .. (l+1)*n_embd - 1 = layer l's hidden across ctx_len positions. +// This matches what dflash-draft.cpp's fc projection expects. +// +// SSM rollback strategy (Phase 2.4): +// If the accepted node is DFS-last, the live recurrent state already points at the accepted +// state and copying from persist buffers only adds risk. Non-DFS-last persist rollback is +// not yet proven correct for DFlash, so the driver falls back to +// snapshot+restore+chain-replay for that case. + +#include "speculative-tree-driver.h" +#include "speculative-draft-backend.h" +#include "speculative-tree.h" +#include "log.h" + +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include + +// Maximum target-context window that the draft can attend over. +// Matches test_dflash.cpp:1086 DRAFT_CTX_MAX. The server-port default is +// smaller; raise it with LLAMA_DDTREE_TARGET_FEAT_CTX when needed. +static constexpr int DRAFT_CTX_MAX = 2048; +static constexpr int DRAFT_CTX_DEFAULT = 128; + +// EOS token for Qwen3.5 family. +static constexpr llama_token QWEN35_EOS = 248045; + +using ddtree_clock = std::chrono::steady_clock; + +static double elapsed_ms(ddtree_clock::time_point t0) { + return std::chrono::duration(ddtree_clock::now() - t0).count(); +} + +struct llama_speculative_tree_driver { + llama_context * target_ctx = nullptr; + llama_context * draft_ctx = nullptr; + llama_ddtree_params params; + std::unique_ptr draft_backend; + + // n_embd from the target model (for hidden capture slicing). + int64_t n_embd = 0; + // vocabulary size (for logit indexing). + int64_t n_vocab = 0; + // draft block_size (number of noise tokens per step, typically 16). + int64_t block_size = 0; + // mask token id used to fill noise positions in the draft input. + llama_token mask_token_id = 0; + + // Scratch buffer for packed target_feat: [5*n_embd, ctx_len] + std::vector target_feat_buf; + + // Cumulative target_feat sliding ring buffer, [5*n_embd, target_feat_cap] + // Stored column-major: column t = position t, rows = [l*n_embd .. (l+1)*n_embd) for layer l. + // i.e. ring[(logical_col % cap) * target_feat_n_embd_fc + l*n_embd .. +n_embd] = + // layer l at committed pos logical_col. + std::vector target_feat_ring; // size = target_feat_n_embd_fc * target_feat_cap + int64_t target_feat_n_committed = 0; // total committed positions appended to the ring, not capped + int64_t target_feat_n_embd_fc = 0; // = 5 * n_embd + int64_t target_feat_cap = DRAFT_CTX_DEFAULT; // target feature context retained for draft + + // Scratch buffers + std::vector top_log_probs; // [block_size-1, K] + std::vector top_token_ids; // [block_size-1, K] + std::vector noise_embd_buf; // [block_size * n_embd] + std::vector posterior; // [N] argmax per tree node + + llama_speculative_tree_driver_stats stats; + bool fast_rollback_unavailable = false; +}; + +static bool ddtree_fast_batched_enabled() { + const char * e = std::getenv("LLAMA_DDTREE_FAST_BATCHED"); + return e != nullptr && e[0] == '1'; +} + +static bool ddtree_fast_rollback_enabled() { + const char * e = std::getenv("LLAMA_DDTREE_FAST_ROLLBACK"); + return e != nullptr && e[0] == '1'; +} + +static bool ddtree_snapshot_fallback_enabled() { + const char * e = std::getenv("LLAMA_DDTREE_SNAPSHOT_FALLBACK"); + return e == nullptr || e[0] != '0'; +} + +static bool ddtree_capture_direct_enabled() { + const char * e = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT"); + return e != nullptr && e[0] == '1'; +} + +static int64_t ddtree_target_feat_cap() { + const char * e = std::getenv("LLAMA_DDTREE_TARGET_FEAT_CTX"); + if (!e || e[0] == '\0') { + return DRAFT_CTX_DEFAULT; + } + + char * end = nullptr; + const long v = std::strtol(e, &end, 10); + if (end == e || v <= 0) { + return DRAFT_CTX_DEFAULT; + } + + return std::min(DRAFT_CTX_MAX, std::max(1, (int64_t)v)); +} + +llama_speculative_tree_driver * llama_speculative_tree_driver_init( + llama_context * target_ctx, + llama_context * draft_ctx, + const llama_ddtree_params & params) { + + const llama_model * target_model = llama_get_model(target_ctx); + + if (!target_model || !llama_get_model(draft_ctx)) { + LOG_ERR("%s: null model pointer\n", __func__); + return nullptr; + } + + auto * d = new llama_speculative_tree_driver; + d->target_ctx = target_ctx; + d->draft_ctx = draft_ctx; + d->params = params; + + d->n_embd = llama_model_n_embd(target_model); + // n_vocab: use target vocab; draft shares the same lm_head. + const llama_vocab * target_vocab = llama_model_get_vocab(target_model); + d->n_vocab = (target_vocab != nullptr) ? llama_vocab_n_tokens(target_vocab) : 0; + if (d->n_embd <= 0 || d->n_vocab <= 0) { + LOG_ERR("%s: invalid model dimensions n_embd=%lld n_vocab=%lld\n", + __func__, (long long)d->n_embd, (long long)d->n_vocab); + delete d; + return nullptr; + } + // block_size and mask_token_id are constants baked into the dflash-draft checkpoint. + // Qwen3.5-27B-DFlash always uses block_size=16 and mask_token_id=248070. + d->block_size = (int64_t)params.block_size; // from llama_ddtree_params (default 16) + d->mask_token_id = 248070; // dflash-draft MASK token + + // Initialize cumulative target_feat ring buffer. + d->target_feat_n_embd_fc = 5 * d->n_embd; + d->target_feat_n_committed = 0; + d->target_feat_cap = ddtree_target_feat_cap(); + d->target_feat_ring.assign((size_t)d->target_feat_n_embd_fc * d->target_feat_cap, 0.0f); + + d->draft_backend = llama_speculative_draft_backend_init_llama( + draft_ctx, target_model, d->n_embd, d->n_vocab, d->block_size, d->mask_token_id, d->params); + if (!d->draft_backend) { + LOG_ERR("%s: failed to initialize draft backend\n", __func__); + delete d; + return nullptr; + } + + return d; +} + +void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d) { + delete d; +} + +llama_speculative_tree_driver_stats llama_speculative_tree_driver_get_stats( + const llama_speculative_tree_driver * d) { + return d ? d->stats : llama_speculative_tree_driver_stats{}; +} + +int32_t llama_speculative_tree_driver_context_window() { + return (int32_t)ddtree_target_feat_cap(); +} + +// Pack the hidden capture buffer into [5*n_embd, ctx_len] F32. +// capture : [n_embd, 5*n_total] — ggml column-major, row-major in C means ne[0]=n_embd. +// n_total : total positions in the capture tensor (= ne[1] / 5). +// ctx_len : number of most-recent positions to pack (ctx_len <= n_total). +// out : caller-allocated [5*n_embd * ctx_len] F32. +static void pack_target_feat( + const float * capture, int64_t n_embd, int64_t n_total, int64_t ctx_len, float * out) { + const int64_t start = n_total - ctx_len; // first column to include + for (int64_t l = 0; l < 5; ++l) { + for (int64_t t = 0; t < ctx_len; ++t) { + // Source: capture column (start+t) in layer l's block. + // In C memory (n_embd fastest): capture[(l*n_total + start + t) * n_embd .. +n_embd) + const float * src = capture + (l * n_total + start + t) * n_embd; + // Destination: row l*n_embd in the output, column t. + // Output is [5*n_embd, ctx_len] → in C: out[t * 5*n_embd + l*n_embd .. +n_embd) + float * dst = out + (size_t)t * 5 * n_embd + (size_t)l * n_embd; + memcpy(dst, src, (size_t)n_embd * sizeof(float)); + } + } +} + +// Append hidden capture data from target_ctx into the driver's ring buffer. +// dfs_indices: if non-NULL, selects which capture columns to ingest (the DFS accepted indices). +// if NULL, ingest the first n_dfs columns linearly (prompt prefill path). +// n_dfs: number of columns to ingest. +enum class ingest_source { + prompt, + tree, + replay, +}; + +static int32_t driver_ingest_capture(llama_speculative_tree_driver * d, + const int32_t * dfs_indices, + int32_t n_dfs, + ingest_source source) { + const auto t0 = ddtree_clock::now(); + ggml_tensor * t_capture = llama_get_hidden_capture(d->target_ctx); + int64_t ne0 = t_capture != nullptr ? t_capture->ne[0] : 0; + int64_t ne1 = t_capture != nullptr ? t_capture->ne[1] : 0; + if (t_capture == nullptr || ne0 == 0 || ne1 == 0) { + LOG_ERR("%s: no hidden capture data available\n", __func__); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + // capture layout: [n_embd, 5*n_tokens] → ne0=n_embd, ne1=5*n_tokens + const int64_t n_embd = ne0; + const int64_t n_tokens = ne1 / 5; // number of decoded positions in this capture + + if (n_embd != d->n_embd) { + LOG_ERR("%s: capture n_embd=%lld != driver n_embd=%lld\n", + __func__, (long long)n_embd, (long long)d->n_embd); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + + // Clamp n_dfs to what the capture actually contains. The server may call + // ingest_prompt_capture(slot.prompt_size) when only the new (uncached) tail + // of the prompt actually went through llama_decode — the capture only holds + // the most recent decode's columns. Out-of-range reads here would be UB. + int32_t n_to_ingest = n_dfs; + if (dfs_indices == nullptr && n_to_ingest > (int32_t)n_tokens) { + d->stats.n_capture_clamps++; + LOG_WRN("%s: requested n_dfs=%d but capture only has n_tokens=%lld; clamping (ring will be incomplete)\n", + __func__, n_dfs, (long long)n_tokens); + n_to_ingest = (int32_t)n_tokens; + } + + if (ddtree_capture_direct_enabled() && d->draft_backend) { + double direct_ms = 0.0; + if (d->draft_backend->ingest_target_capture(d->target_ctx, dfs_indices, n_to_ingest, + d->target_feat_n_committed, d->target_feat_cap, + direct_ms)) { + d->target_feat_n_committed += (int64_t)n_to_ingest; + switch (source) { + case ingest_source::prompt: + d->stats.n_prompt_ingest_calls++; + d->stats.n_prompt_ingested_tokens += n_to_ingest; + d->stats.t_prompt_ingest_ms += direct_ms; + break; + case ingest_source::tree: + d->stats.n_tree_ingested_tokens += n_to_ingest; + d->stats.t_tree_ingest_ms += direct_ms; + break; + case ingest_source::replay: + d->stats.n_replay_ingested_tokens += n_to_ingest; + d->stats.t_replay_ingest_ms += direct_ms; + break; + } + d->stats.t_ingest_capture_ms += direct_ms; + return n_to_ingest; + } + } + + const float * capture = llama_get_hidden_capture_data(d->target_ctx, &ne0, &ne1); + if (!capture || ne0 == 0 || ne1 == 0) { + LOG_ERR("%s: no hidden capture data available after direct-ingest fallback\n", __func__); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + + for (int32_t i = 0; i < n_to_ingest; ++i) { + // Source column index in the capture buffer (within each layer's block). + const int64_t src_col = (dfs_indices != nullptr) ? (int64_t)dfs_indices[i] : (int64_t)i; + if (src_col < 0 || src_col >= n_tokens) { + LOG_ERR("%s: src_col=%lld out of capture range [0, %lld)\n", + __func__, (long long)src_col, (long long)n_tokens); + break; + } + + const int64_t logical_col = d->target_feat_n_committed + (int64_t)i; + const int64_t dst_col = logical_col % d->target_feat_cap; + + for (int64_t l = 0; l < 5; ++l) { + // Source: layer l's block starts at column l*n_tokens; pick column src_col within it. + const float * src = capture + (l * n_tokens + src_col) * n_embd; + // Destination: ring column dst_col, row l*n_embd. + float * dst = d->target_feat_ring.data() + dst_col * d->target_feat_n_embd_fc + l * n_embd; + memcpy(dst, src, (size_t)n_embd * sizeof(float)); + } + } + + d->target_feat_n_committed += (int64_t)n_to_ingest; + const double ingest_ms = elapsed_ms(t0); + switch (source) { + case ingest_source::prompt: + d->stats.n_prompt_ingest_calls++; + d->stats.n_prompt_ingested_tokens += n_to_ingest; + d->stats.t_prompt_ingest_ms += ingest_ms; + break; + case ingest_source::tree: + d->stats.n_tree_ingested_tokens += n_to_ingest; + d->stats.t_tree_ingest_ms += ingest_ms; + break; + case ingest_source::replay: + d->stats.n_replay_ingested_tokens += n_to_ingest; + d->stats.t_replay_ingest_ms += ingest_ms; + break; + } + d->stats.t_ingest_capture_ms += ingest_ms; + return n_to_ingest; +} + +static bool replay_committed_chain(llama_speculative_tree_driver * d, + const llama_ddtree & tree, + const int32_t * accepted_dfs, + int32_t commit_n, + llama_pos committed_pos) { + if (commit_n <= 0) { + return true; + } + + llama_memory_t mem = llama_get_memory(d->target_ctx); + if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) { + LOG_ERR("%s: failed to remove tree KV/recurrent range at pos >= %d\n", + __func__, (int)committed_pos); + return false; + } + + llama_batch replay = llama_batch_init(commit_n, /*embd=*/0, /*n_seq_max=*/1); + replay.n_tokens = commit_n; + for (int32_t i = 0; i < commit_n; ++i) { + replay.token[i] = tree.nodes[accepted_dfs[i]].token_id; + replay.pos[i] = committed_pos + i; + replay.n_seq_id[i] = 1; + replay.seq_id[i][0] = 0; + replay.logits[i] = 0; + } + + const int ret = llama_decode(d->target_ctx, replay); + llama_batch_free(replay); + if (ret != 0) { + LOG_ERR("%s: chain replay llama_decode failed: %d\n", __func__, ret); + return false; + } + + driver_ingest_capture(d, nullptr, commit_n, ingest_source::replay); + return true; +} + +static int32_t find_child_token(const llama_ddtree & tree, int32_t parent, llama_token token) { + for (int32_t i = 1; i < (int32_t) tree.nodes.size(); ++i) { + if (tree.nodes[i].parent_idx == parent && tree.nodes[i].token_id == token) { + return i; + } + } + return -1; +} + +static llama_token pick_current_logits(llama_speculative_tree_driver * d, + const llama_speculative_tree_verify_cbs * verify_cbs) { + if (verify_cbs != nullptr && verify_cbs->sample_cb != nullptr) { + return (llama_token) verify_cbs->sample_cb(verify_cbs->user_data, /*logits_row_idx=*/0); + } + + const float * row = llama_get_logits_ith(d->target_ctx, 0); + if (!row) { + return LLAMA_TOKEN_NULL; + } + int32_t best = 0; + float best_val = row[0]; + for (int64_t v = 1; v < d->n_vocab; ++v) { + if (row[v] > best_val) { + best_val = row[v]; + best = (int32_t) v; + } + } + return (llama_token) best; +} + +static bool validate_tree_with_chain(llama_speculative_tree_driver * d, + const llama_ddtree & tree, + llama_pos committed_pos, + const llama_speculative_tree_verify_cbs * verify_cbs, + std::vector & accepted_dfs, + llama_token & next_token) { + accepted_dfs.clear(); + accepted_dfs.push_back(0); + next_token = LLAMA_TOKEN_NULL; + + llama_memory_t mem = llama_get_memory(d->target_ctx); + if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) { + LOG_ERR("%s: failed to remove tree KV/recurrent range at pos >= %d\n", + __func__, (int)committed_pos); + return false; + } + + int32_t current = 0; + for (int32_t depth = 0; depth < (int32_t) tree.nodes.size(); ++depth) { + llama_batch b = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1); + b.n_tokens = 1; + b.token[0] = tree.nodes[current].token_id; + b.pos[0] = committed_pos + depth; + b.n_seq_id[0] = 1; + b.seq_id[0][0] = 0; + b.logits[0] = 1; + + const auto t_decode0 = ddtree_clock::now(); + const int ret = llama_decode(d->target_ctx, b); + d->stats.t_exact_decode_ms += elapsed_ms(t_decode0); + d->stats.n_exact_validate_nodes++; + llama_batch_free(b); + if (ret != 0) { + LOG_ERR("%s: chain validation llama_decode failed at depth %d: %d\n", + __func__, (int) depth, ret); + return false; + } + + driver_ingest_capture(d, nullptr, 1, ingest_source::replay); + + const auto t_sample0 = ddtree_clock::now(); + const llama_token picked = pick_current_logits(d, verify_cbs); + d->stats.t_exact_sample_ms += elapsed_ms(t_sample0); + if (picked == LLAMA_TOKEN_NULL) { + LOG_ERR("%s: failed to pick from chain validation logits\n", __func__); + return false; + } + + const int32_t child = find_child_token(tree, current, picked); + if (child < 0) { + next_token = picked; + return true; + } + + if (verify_cbs != nullptr && verify_cbs->advance_cb != nullptr) { + const auto t_advance0 = ddtree_clock::now(); + verify_cbs->advance_cb(verify_cbs->user_data, picked); + d->stats.t_exact_advance_ms += elapsed_ms(t_advance0); + } + + accepted_dfs.push_back(child); + current = child; + } + + const auto t_sample0 = ddtree_clock::now(); + const llama_token picked = pick_current_logits(d, verify_cbs); + d->stats.t_exact_sample_ms += elapsed_ms(t_sample0); + if (picked == LLAMA_TOKEN_NULL) { + LOG_ERR("%s: failed to pick final chain validation token\n", __func__); + return false; + } + next_token = picked; + return true; +} + +static llama_token diagnose_chain_root_argmax(llama_speculative_tree_driver * d, + llama_token root_token, + llama_pos committed_pos) { + llama_mem_snapshot_id snap = llama_seq_snapshot(d->target_ctx, /*seq_id=*/0); + if (snap == LLAMA_MEM_SNAPSHOT_INVALID) { + return LLAMA_TOKEN_NULL; + } + + llama_batch b = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1); + b.n_tokens = 1; + b.token[0] = root_token; + b.pos[0] = committed_pos; + b.n_seq_id[0] = 1; + b.seq_id[0][0] = 0; + b.logits[0] = 1; + + llama_token best = LLAMA_TOKEN_NULL; + if (llama_decode(d->target_ctx, b) == 0) { + const float * row = llama_get_logits_ith(d->target_ctx, 0); + if (row) { + best = 0; + float best_val = row[0]; + for (int64_t v = 1; v < d->n_vocab; ++v) { + if (row[v] > best_val) { best_val = row[v]; best = (llama_token)v; } + } + } + } + llama_batch_free(b); + + llama_memory_t mem = llama_get_memory(d->target_ctx); + llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1); + llama_seq_restore(d->target_ctx, snap); + llama_seq_release(d->target_ctx, snap); + return best; +} + +void llama_speculative_tree_driver_ingest_prompt_capture( + llama_speculative_tree_driver * d, + int32_t n_prompt_tokens) { + // Prompt prefill capture is laid out linearly; ingest columns 0..n_prompt_tokens-1. + driver_ingest_capture(d, nullptr, n_prompt_tokens, ingest_source::prompt); +} + +std::vector llama_speculative_tree_driver_step( + llama_speculative_tree_driver * d, + llama_token root_token, + llama_pos committed_pos, + const llama_speculative_tree_verify_cbs * verify_cbs) { + + if (!d) { + return {}; + } + const auto t_step0 = ddtree_clock::now(); + + const int64_t n_vocab = d->n_vocab; + + llama_speculative_draft_decode_info draft_info; + llama_speculative_draft_target_feat_view target_feat_view { + d->target_feat_ring.data(), + d->target_feat_n_committed, + d->target_feat_cap, + d->target_feat_n_embd_fc, + }; + if (!d->draft_backend->decode_topk( + root_token, committed_pos, target_feat_view, + d->top_log_probs, d->top_token_ids, draft_info)) { + return {}; + } + d->stats.t_target_feat_pack_ms += draft_info.t_target_feat_pack_ms; + d->stats.t_draft_decode_ms += draft_info.t_draft_decode_ms; + d->stats.t_topk_ms += draft_info.t_topk_ms; + + const int L = draft_info.L; + const int K = draft_info.K; + const int64_t ctx_len = draft_info.ctx_len; + + if (std::getenv("LLAMA_DDTREE_DUMP_DRAFT_TOP") != nullptr && d->stats.n_steps == 0) { + LOG_INF("draft_top port: step=%lld committed=%d ctx_len=%lld root=%d K=%d backend=%s\n", + (long long)d->stats.n_steps, + (int)committed_pos, + (long long)ctx_len, + (int)root_token, + K, + d->draft_backend->name()); + LOG_INF("draft_top port: top1:"); + for (int i = 0; i < L; ++i) { + LOG_INF(" %d", (int)d->top_token_ids[(size_t)i * K]); + } + LOG_INF("\n"); + if (K > 1) { + const int rows = std::min(4, L); + for (int r = 0; r < rows; ++r) { + LOG_INF("draft_top port: row%d:", r + 1); + for (int k = 0; k < K; ++k) { + LOG_INF(" %d", (int)d->top_token_ids[(size_t)r * K + k]); + } + LOG_INF("\n"); + } + } + } + + // ── Step 5: build DDTree ────────────────────────────────────────────────── + llama_ddtree tree; + { + const auto t0 = ddtree_clock::now(); + tree = build_ddtree( + d->top_log_probs.data(), d->top_token_ids.data(), + L, K, root_token, d->params); + d->stats.t_build_tree_ms += elapsed_ms(t0); + } + + const int N = (int)tree.nodes.size(); // includes root node at index 0 + d->stats.n_steps++; + d->stats.n_tree_verifies++; + d->stats.n_tree_nodes_total += N; + d->stats.max_tree_nodes = std::max(d->stats.max_tree_nodes, N); + + const bool fast_batched = ddtree_fast_batched_enabled(); + const bool trace_batched = std::getenv("LLAMA_DDTREE_TRACE") != nullptr || + std::getenv("LLAMA_DDTREE_TRACE_CHAIN_ROOT") != nullptr; + const bool need_batched_tree = fast_batched || trace_batched; + + if (!need_batched_tree) { + std::vector accepted_dfs; + llama_token next_token = LLAMA_TOKEN_NULL; + { + const auto t0 = ddtree_clock::now(); + if (!validate_tree_with_chain(d, tree, committed_pos, verify_cbs, accepted_dfs, next_token)) { + return {}; + } + d->stats.t_exact_validate_ms += elapsed_ms(t0); + } + + const int commit_n = (int)accepted_dfs.size(); + d->stats.n_committed_tokens += commit_n; + d->stats.max_committed_tokens_per_step = + std::max(d->stats.max_committed_tokens_per_step, commit_n); + + std::vector result; + result.reserve(commit_n + 1); + for (int i = 0; i < commit_n; ++i) { + result.push_back(tree.nodes[accepted_dfs[i]].token_id); + } + result.push_back(next_token); + + d->stats.t_step_ms += elapsed_ms(t_step0); + return result; + } + + const bool fast_rollback = fast_batched && ddtree_fast_rollback_enabled() && !d->fast_rollback_unavailable; + const bool keep_snapshot = !fast_batched || !fast_rollback || ddtree_snapshot_fallback_enabled(); + + // ── Step 6: snapshot before target verify ──────────────────────────────── + // By default fast-rollback mode still keeps a snapshot as a safety net. + // Set LLAMA_DDTREE_SNAPSHOT_FALLBACK=0 to remove this per-step host bounce + // after validating that persist allocation succeeds in the target runtime. + llama_mem_snapshot_id snap = LLAMA_MEM_SNAPSHOT_INVALID; + auto release_snap = [&]() { + if (snap != LLAMA_MEM_SNAPSHOT_INVALID) { + llama_seq_release(d->target_ctx, snap); + snap = LLAMA_MEM_SNAPSHOT_INVALID; + } + }; + if (N > 1) { + const auto t0 = ddtree_clock::now(); + llama_memory_t mem = llama_get_memory(d->target_ctx); + if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) { + LOG_ERR("%s: failed to clear target future range before tree verify at pos %d\n", + __func__, (int)committed_pos); + return {}; + } + + if (keep_snapshot) { + snap = llama_seq_snapshot(d->target_ctx, /*seq_id=*/0); + if (snap == LLAMA_MEM_SNAPSHOT_INVALID) { + LOG_ERR("%s: llama_seq_snapshot failed before tree verify\n", __func__); + return {}; + } + } + d->stats.t_snapshot_ms += elapsed_ms(t0); + } + + llama_token diag_chain_root = LLAMA_TOKEN_NULL; + if (std::getenv("LLAMA_DDTREE_TRACE_CHAIN_ROOT") != nullptr) { + diag_chain_root = diagnose_chain_root_argmax(d, root_token, committed_pos); + } + + // ── Step 7: target verify (tree-mode forward) ───────────────────────────── + // Build a tree batch of N tokens and run target decode. + { + const auto t0 = ddtree_clock::now(); + llama_batch tree_batch = llama_batch_init_tree(N, 0, 1); + tree_batch.n_tokens = N; + for (int i = 0; i < N; ++i) { + tree_batch.token[i] = tree.nodes[i].token_id; + tree_batch.pos[i] = committed_pos + tree.nodes[i].depth; + tree_batch.n_seq_id[i] = 1; + tree_batch.seq_id[i][0] = 0; + tree_batch.logits[i] = 1; // output logits for all nodes + tree_batch.parent_id[i] = tree.nodes[i].parent_idx; // -1 for root + } + int ret = llama_decode(d->target_ctx, tree_batch); + llama_batch_free(tree_batch); + if (ret != 0) { + LOG_ERR("%s: target tree llama_decode failed: %d\n", __func__, ret); + release_snap(); + return {}; + } + d->stats.t_target_tree_decode_ms += elapsed_ms(t0); + } + + // ── Step 8: pick verify chain ───────────────────────────────────────────── + // Keep the batched tree posterior for diagnostics, but do not trust it for + // final acceptance. Quantized batched tree logits can drift from one-token + // AR logits enough to flip argmax on close rows. + d->posterior.resize(N); + std::vector posterior_margins; + posterior_margins.resize(N); + { + const auto t0 = ddtree_clock::now(); + for (int i = 0; i < N; ++i) { + const float * row = llama_get_logits_ith(d->target_ctx, i); + if (!row) { + LOG_ERR("%s: target logits[%d] unavailable\n", __func__, i); + release_snap(); + return {}; + } + int32_t best = 0; + float best_val = row[0]; + int32_t second = 0; + float second_val = row[0]; + if (n_vocab > 1) { + second = 1; + second_val = row[1]; + if (second_val > best_val) { + std::swap(best, second); + std::swap(best_val, second_val); + } + } + for (int64_t v = 2; v < n_vocab; ++v) { + const float val = row[v]; + if (val > best_val) { + second = best; + second_val = best_val; + best = (int32_t)v; + best_val = val; + } else if (val > second_val) { + second = (int32_t)v; + second_val = val; + } + } + d->posterior[i] = best; + posterior_margins[i] = best_val - second_val; + } + d->stats.t_posterior_scan_ms += elapsed_ms(t0); + } + + std::vector batched_accepted_dfs; + llama_token batched_next_token = LLAMA_TOKEN_NULL; + { + const auto t0 = ddtree_clock::now(); + follow_verified_tree(tree, d->posterior.data(), batched_accepted_dfs, batched_next_token); + d->stats.t_accept_path_ms += elapsed_ms(t0); + } + + const int batched_commit_n = (int)batched_accepted_dfs.size(); + d->stats.n_batched_posterior_committed_tokens += batched_commit_n; + d->stats.max_batched_posterior_committed_tokens_per_step = + std::max(d->stats.max_batched_posterior_committed_tokens_per_step, batched_commit_n); + + std::vector accepted_dfs; + llama_token next_token = LLAMA_TOKEN_NULL; + + if (fast_batched) { + if (verify_cbs != nullptr && verify_cbs->sample_cb != nullptr) { + const auto t0 = ddtree_clock::now(); + follow_verified_tree_cb( + tree, + verify_cbs->sample_cb, + verify_cbs->advance_cb, + verify_cbs->user_data, + accepted_dfs, + next_token); + d->stats.t_accept_path_ms += elapsed_ms(t0); + d->stats.n_fast_batched_callback_steps++; + } else { + accepted_dfs = batched_accepted_dfs; + next_token = batched_next_token; + } + + const int accept_depth = (int)accepted_dfs.size(); // includes root node (index 0) + const int commit_n = accept_depth; + + bool did_commit_state = false; + if (fast_rollback && N > 1) { + { + const auto t0 = ddtree_clock::now(); + llama_kv_cache_seq_compact_tree( + d->target_ctx, + /*seq_id=*/0, + accepted_dfs.data(), + (int32_t)accepted_dfs.size(), + commit_n, + (int32_t)committed_pos); + d->stats.t_kv_compact_ms += elapsed_ms(t0); + } + + const int32_t rollback_node = commit_n > 0 ? accepted_dfs[commit_n - 1] : 0; + bool rollback_ok = false; + { + const auto t0 = ddtree_clock::now(); + rollback_ok = llama_dflash_rollback_ssm_to_dfs(d->target_ctx, /*seq_id=*/0, rollback_node); + d->stats.t_ssm_rollback_ms += elapsed_ms(t0); + } + if (!rollback_ok) { + LOG_WRN("%s: fast rollback failed at DFS node %d; falling back to snapshot replay\n", + __func__, (int)rollback_node); + d->fast_rollback_unavailable = true; + } else { + const llama_pos recurrent_tail_pos = committed_pos + commit_n - 1; + bool tail_ok = false; + { + const auto t0 = ddtree_clock::now(); + tail_ok = llama_dflash_set_recurrent_tail_pos(d->target_ctx, /*seq_id=*/0, recurrent_tail_pos); + d->stats.t_ssm_rollback_ms += elapsed_ms(t0); + } + if (!tail_ok) { + LOG_WRN("%s: failed to set recurrent tail pos to %d after fast rollback; falling back to snapshot replay\n", + __func__, (int)recurrent_tail_pos); + d->fast_rollback_unavailable = true; + } else { + driver_ingest_capture(d, accepted_dfs.data(), commit_n, ingest_source::tree); + d->stats.n_fast_rollback_steps++; + did_commit_state = true; + } + } + } + + if (N > 1 && !did_commit_state) { + if (snap == LLAMA_MEM_SNAPSHOT_INVALID || !llama_seq_restore(d->target_ctx, snap)) { + LOG_ERR("%s: fast rollback failed and snapshot fallback is unavailable\n", __func__); + release_snap(); + return {}; + } + d->stats.n_snapshot_replays++; + d->stats.n_fast_batched_replays++; + { + const auto t0 = ddtree_clock::now(); + if (!replay_committed_chain(d, tree, accepted_dfs.data(), commit_n, committed_pos)) { + release_snap(); + return {}; + } + d->stats.t_replay_ms += elapsed_ms(t0); + } + } else { + // Root-only verify is already a normal one-token forward. Keep its + // live target state and ingest its hidden capture for the next draft. + driver_ingest_capture(d, nullptr, commit_n, ingest_source::replay); + } + + d->stats.n_committed_tokens += commit_n; + d->stats.max_committed_tokens_per_step = + std::max(d->stats.max_committed_tokens_per_step, commit_n); + + release_snap(); + + std::vector result; + result.reserve(commit_n + 1); + for (int i = 0; i < commit_n; ++i) { + result.push_back(tree.nodes[accepted_dfs[i]].token_id); + } + result.push_back(next_token); + + d->stats.t_step_ms += elapsed_ms(t_step0); + return result; + } + + if (snap == LLAMA_MEM_SNAPSHOT_INVALID || !llama_seq_restore(d->target_ctx, snap)) { + LOG_ERR("%s: llama_seq_restore failed before exact chain validation\n", __func__); + release_snap(); + return {}; + } + d->stats.n_snapshot_replays++; + { + const auto t0 = ddtree_clock::now(); + if (!validate_tree_with_chain(d, tree, committed_pos, verify_cbs, accepted_dfs, next_token)) { + release_snap(); + return {}; + } + d->stats.t_exact_validate_ms += elapsed_ms(t0); + } + + const int accept_depth = (int)accepted_dfs.size(); // includes root node (index 0) + const int commit_n = accept_depth; // root is always committed + if (batched_accepted_dfs == accepted_dfs && batched_next_token == next_token) { + d->stats.n_batched_exact_same++; + } else { + d->stats.n_batched_exact_diff++; + } + if (batched_commit_n > commit_n) { + d->stats.n_batched_exact_longer++; + } else if (batched_commit_n < commit_n) { + d->stats.n_batched_exact_shorter++; + } + d->stats.n_committed_tokens += commit_n; + d->stats.max_committed_tokens_per_step = + std::max(d->stats.max_committed_tokens_per_step, commit_n); + + if (std::getenv("LLAMA_DDTREE_TRACE") != nullptr) { + const int32_t rollback_node = (commit_n > 0) ? accepted_dfs[commit_n - 1] : 0; + const int32_t posterior0 = d->posterior.empty() ? -1 : d->posterior[0]; + float exact_min_margin = 0.0f; + for (int i = 0; i < (int)accepted_dfs.size(); ++i) { + const int32_t idx = accepted_dfs[i]; + const float margin = (idx >= 0 && idx < (int)posterior_margins.size()) ? posterior_margins[idx] : 0.0f; + exact_min_margin = (i == 0) ? margin : std::min(exact_min_margin, margin); + } + float batched_min_margin = 0.0f; + for (int i = 0; i < (int)batched_accepted_dfs.size(); ++i) { + const int32_t idx = batched_accepted_dfs[i]; + const float margin = (idx >= 0 && idx < (int)posterior_margins.size()) ? posterior_margins[idx] : 0.0f; + batched_min_margin = (i == 0) ? margin : std::min(batched_min_margin, margin); + } + LOG_INF("ddtree_trace: step=%lld pos=%d root=%d N=%d budget=%d posterior0=%d next=%d commit_n=%d batched_commit_n=%d batched_next=%d exact_min_margin=%.6g batched_min_margin=%.6g rollback_node=%d\n", + (long long)d->stats.n_steps, + (int)committed_pos, + (int)root_token, + N, + d->params.budget, + (int)posterior0, + (int)next_token, + commit_n, + batched_commit_n, + (int)batched_next_token, + (double)exact_min_margin, + (double)batched_min_margin, + (int)rollback_node); + if (diag_chain_root != LLAMA_TOKEN_NULL) { + LOG_INF("ddtree_trace: chain_pre_argmax=%d tree_root_argmax=%d\n", + (int)diag_chain_root, (int)posterior0); + } + LOG_INF("ddtree_trace: accepted="); + for (int i = 0; i < (int)accepted_dfs.size(); ++i) { + LOG_INF("%s%d", i == 0 ? "" : ",", (int)accepted_dfs[i]); + } + LOG_INF("\n"); + for (int i = 0; i < N; ++i) { + const int32_t post = (i < (int)d->posterior.size()) ? d->posterior[i] : -1; + LOG_INF("ddtree_trace: node=%d parent=%d depth=%d tok=%d posterior=%d\n", + i, + (int)tree.nodes[i].parent_idx, + (int)tree.nodes[i].depth, + (int)tree.nodes[i].token_id, + (int)post); + } + } + + // Step 9 is handled by validate_tree_with_chain(): after restoring the + // snapshot, it decodes the exact accepted path one token at a time and + // ingests the corresponding hidden captures into the draft feature ring. + release_snap(); + + // ── Step 10: assemble output ────────────────────────────────────────────── + // accepted[0] = root_token (always, the input token echoed back). + // accepted[1..accept_depth-1] = newly accepted draft tokens. + // accepted[accept_depth] = bonus token (next_token). + std::vector result; + result.reserve(commit_n + 1); + for (int i = 0; i < commit_n; ++i) { + result.push_back(tree.nodes[accepted_dfs[i]].token_id); + } + result.push_back(next_token); + + d->stats.t_step_ms += elapsed_ms(t_step0); + return result; +} diff --git a/common/speculative-tree-driver.h b/common/speculative-tree-driver.h new file mode 100644 index 000000000000..676d2a03b0e0 --- /dev/null +++ b/common/speculative-tree-driver.h @@ -0,0 +1,124 @@ +#pragma once + +// speculative-tree-driver.h — Phase 4 DDTree spec-decode coordinator. +// +// Binds a target context (Qwen3.5-27B with capture_hidden) and a draft context +// (LLM_ARCH_DFLASH_DRAFT) and implements one spec-decode step per call. +// +// Lifecycle: +// driver = llama_speculative_tree_driver_init(target_ctx, draft_ctx, params) +// while (not done): +// accepted = llama_speculative_tree_driver_step(driver, root_token, committed_pos) +// ... append accepted to output ... +// root_token = accepted.back(); committed_pos += accepted.size() - 1 +// llama_speculative_tree_driver_free(driver) + +#include "llama.h" +#include "speculative-tree.h" + +#include +#include + +struct llama_speculative_tree_driver; + +struct llama_speculative_tree_driver_stats { + int64_t n_steps = 0; + int64_t n_tree_verifies = 0; + int64_t n_tree_nodes_total = 0; + int32_t max_tree_nodes = 0; + int64_t n_dfs_last_commits = 0; + int64_t n_snapshot_replays = 0; + int64_t n_committed_tokens = 0; + int32_t max_committed_tokens_per_step = 0; + int64_t n_batched_posterior_committed_tokens = 0; + int32_t max_batched_posterior_committed_tokens_per_step = 0; + int64_t n_batched_exact_same = 0; + int64_t n_batched_exact_diff = 0; + int64_t n_batched_exact_longer = 0; + int64_t n_batched_exact_shorter = 0; + int64_t n_fast_batched_replays = 0; + int64_t n_fast_batched_callback_steps = 0; + int64_t n_fast_rollback_steps = 0; + int64_t n_prompt_ingest_calls = 0; + int64_t n_prompt_ingested_tokens = 0; + int64_t n_tree_ingested_tokens = 0; + int64_t n_replay_ingested_tokens = 0; + int64_t n_capture_clamps = 0; + int64_t n_exact_validate_nodes = 0; + + double t_step_ms = 0.0; + double t_target_feat_pack_ms = 0.0; + double t_draft_decode_ms = 0.0; + double t_topk_ms = 0.0; + double t_build_tree_ms = 0.0; + double t_snapshot_ms = 0.0; + double t_target_tree_decode_ms = 0.0; + double t_posterior_scan_ms = 0.0; + double t_accept_path_ms = 0.0; + double t_kv_compact_ms = 0.0; + double t_ssm_rollback_ms = 0.0; + double t_ingest_capture_ms = 0.0; + double t_prompt_ingest_ms = 0.0; + double t_tree_ingest_ms = 0.0; + double t_replay_ingest_ms = 0.0; + double t_replay_ms = 0.0; + double t_exact_validate_ms = 0.0; + double t_exact_decode_ms = 0.0; + double t_exact_sample_ms = 0.0; + double t_exact_advance_ms = 0.0; +}; + +// Allocate a driver. target_ctx must have capture_hidden enabled before any +// llama_decode() calls that prime the context. draft_ctx must use the +// LLM_ARCH_DFLASH_DRAFT architecture. +llama_speculative_tree_driver * llama_speculative_tree_driver_init( + llama_context * target_ctx, + llama_context * draft_ctx, + const llama_ddtree_params & params); + +void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d); + +// Run one spec-decode step. +// +// root_token — the last committed token (bonus token from the previous step, +// or the last prompt token on the very first call). +// committed_pos — number of KV positions committed in the target context so far +// (i.e. seq_pos_max + 1 for the next token to be placed). +// +// Returns accepted tokens in chronological order (length >= 1): +// accepted[0] = root_token (the input, echoed for convenience) +// accepted[1..] = newly accepted draft tokens +// accepted.back() = bonus token from the target (the next root for the next step) +// +// The KV cache of target_ctx is compacted to hold only the accepted path after +// each step. The SSM/conv state is snapshot-before and restore-on-mismatch. +// +// Optional verify callbacks. If non-null, the driver picks the next token at +// each verify-chain step via sample_cb instead of internal argmax. advance_cb +// is invoked whenever the chain accepts a child, so callers can advance their +// sampler/grammar state to mirror the chain. +struct llama_speculative_tree_verify_cbs { + llama_speculative_pick_cb sample_cb = nullptr; + llama_speculative_advance_cb advance_cb = nullptr; + void * user_data = nullptr; +}; + +// Returns an empty vector on internal failure. +std::vector llama_speculative_tree_driver_step( + llama_speculative_tree_driver * d, + llama_token root_token, + llama_pos committed_pos, + const llama_speculative_tree_verify_cbs * verify_cbs = nullptr); + +// Ingest the most recent target_ctx capture as the initial ring contents. +// Call this AFTER the chain-mode prompt prefill that primed target capture, +// BEFORE the first spec step. +// n_prompt_tokens: number of tokens in the prompt that were decoded in the prefill batch. +void llama_speculative_tree_driver_ingest_prompt_capture( + llama_speculative_tree_driver * d, + int32_t n_prompt_tokens); + +int32_t llama_speculative_tree_driver_context_window(); + +llama_speculative_tree_driver_stats llama_speculative_tree_driver_get_stats( + const llama_speculative_tree_driver * d); diff --git a/common/speculative-tree.cpp b/common/speculative-tree.cpp new file mode 100644 index 000000000000..57786421ca36 --- /dev/null +++ b/common/speculative-tree.cpp @@ -0,0 +1,291 @@ +#include "speculative-tree.h" + +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// build_ddtree +// --------------------------------------------------------------------------- + +llama_ddtree build_ddtree( + const float * top_log_probs, + const int32_t * top_token_ids, + int L, + int K, + llama_token root_token, + const llama_ddtree_params & p) { + + llama_ddtree tree; + + // Node 0 is always the root (last committed token). + tree.nodes.push_back({root_token, /*parent_idx*/ -1, /*depth*/ 0}); + + // budget = total node cap including root. budget=1 means root-only. + const int budget = (p.budget < 1) ? 1 : p.budget; + + if (budget <= 1 || L <= 0) { + tree.visibility.assign(1, 1); + return tree; + } + + // child_maps[flat_index] maps token_id → child flat index. + std::vector> child_maps; + child_maps.emplace_back(); // root's children (index 0) + + // Heap entry: a candidate node waiting to be inserted. + struct HeapEntry { + float neg_logw; // stored as negative so max-heap by logw + int parent_index; // flat index of already-inserted parent + int depth; // absolute depth (1..L) + int rank; // rank within top_token_ids row (depth-1) + float logw; // cumulative path log-prob from root to candidate + }; + struct HeapCmp { + bool operator()(const HeapEntry & a, const HeapEntry & b) const { + return a.neg_logw > b.neg_logw; // pop smallest neg_logw = highest logw + } + }; + std::priority_queue, HeapCmp> heap; + + if (p.chain_seed) { + // Pre-insert the top-1 greedy chain to depth min(L, budget-1). + // Guarantees the tree always contains at least the greedy chain path. + const int chain_depth = std::min(L, budget - 1); + float cum_logw = 0.0f; + int prev_idx = 0; + + for (int d = 1; d <= chain_depth; d++) { + const int32_t tok_id = top_token_ids[(size_t)(d - 1) * K + 0]; + cum_logw += top_log_probs[(size_t)(d - 1) * K + 0]; + + const int cur_idx = (int)tree.nodes.size(); + tree.nodes.push_back({tok_id, prev_idx, d}); + child_maps.emplace_back(); + child_maps[prev_idx][tok_id] = cur_idx; + + // Queue rank-1 sibling so best-first can branch off the chain. + if (K > 1) { + const float sib_logw = cum_logw + - top_log_probs[(size_t)(d - 1) * K + 0] + + top_log_probs[(size_t)(d - 1) * K + 1]; + heap.push({-sib_logw, prev_idx, d, 1, sib_logw}); + } + + prev_idx = cur_idx; + } + } else { + // Pure best-first: seed with depth-1 top-1 candidate only. + const float logw0 = top_log_probs[0]; + heap.push({-logw0, 0, 1, 0, logw0}); + } + + // Expand candidates in log-prob order until budget is reached. + while (!heap.empty() && (int)tree.nodes.size() < budget) { + const HeapEntry top = heap.top(); + heap.pop(); + + const int dm1 = top.depth - 1; + const int rank = top.rank; + const int32_t tok_id = top_token_ids[(size_t)dm1 * K + rank]; + + // Skip duplicates (chain_seed may have already inserted this token + // under the same parent). + if (child_maps[top.parent_index].count(tok_id)) { + continue; + } + + const int cur_idx = (int)tree.nodes.size(); + tree.nodes.push_back({tok_id, top.parent_index, top.depth}); + child_maps.emplace_back(); + child_maps[top.parent_index][tok_id] = cur_idx; + + // Next sibling (same depth, rank+1). + if (rank + 1 < K) { + const float sib_logw = top.logw + - top_log_probs[(size_t)dm1 * K + rank] + + top_log_probs[(size_t)dm1 * K + rank + 1]; + heap.push({-sib_logw, top.parent_index, top.depth, rank + 1, sib_logw}); + } + + // First child (depth+1, top-1 under this node). + if (top.depth < L) { + const float child_logw = top.logw + + top_log_probs[(size_t)top.depth * K + 0]; + heap.push({-child_logw, cur_idx, top.depth + 1, 0, child_logw}); + } + } + + // Build ancestor-only visibility mask for attention masking. + const int N = (int)tree.nodes.size(); + tree.visibility.assign((size_t)N * N, 0); + build_tree_visibility(tree.nodes, tree.visibility.data()); + + return tree; +} + +// --------------------------------------------------------------------------- +// follow_verified_tree +// --------------------------------------------------------------------------- + +void follow_verified_tree( + const llama_ddtree & tree, + const int32_t * posterior, + std::vector & accepted, + llama_token & next_token) { + + const int N = (int)tree.nodes.size(); + + // Build per-node child maps from parent_idx links. + std::vector> child_maps(N); + for (int i = 1; i < N; i++) { + const int p = tree.nodes[i].parent_idx; + child_maps[p][tree.nodes[i].token_id] = i; + } + + accepted.clear(); + accepted.reserve(N); + accepted.push_back(0); // root is always accepted + + int current = 0; + while (true) { + // posterior[current] is the target model's argmax at this tree position. + const auto it = child_maps[current].find((llama_token)posterior[current]); + if (it == child_maps[current].end()) { + break; + } + current = it->second; + accepted.push_back(current); + } + + // Bonus token: the target's argmax at the deepest accepted node. + next_token = (llama_token)posterior[current]; +} + +// follow_verified_tree_cb: same chain-walk semantics as follow_verified_tree +// but the picked token at each step comes from sample_cb (caller-side +// grammar/sampler), and chain advances notify the caller via advance_cb. +void follow_verified_tree_cb( + const llama_ddtree & tree, + llama_speculative_pick_cb sample_cb, + llama_speculative_advance_cb advance_cb, + void * user_data, + std::vector & accepted, + llama_token & next_token) { + const int N = (int)tree.nodes.size(); + + // Build per-node child maps from parent_idx links. + std::vector> child_maps(N); + for (int i = 1; i < N; i++) { + const int p = tree.nodes[i].parent_idx; + child_maps[p][tree.nodes[i].token_id] = i; + } + + accepted.clear(); + accepted.reserve(N); + accepted.push_back(0); // root is always accepted + + int current = 0; + while (true) { + const int32_t picked = sample_cb(user_data, current); + const auto it = child_maps[current].find(picked); + if (it == child_maps[current].end()) { + next_token = (llama_token)picked; + break; + } + if (advance_cb != nullptr) { + advance_cb(user_data, (llama_token)tree.nodes[it->second].token_id); + } + current = it->second; + accepted.push_back(current); + } +} + +// --------------------------------------------------------------------------- +// build_tree_visibility +// --------------------------------------------------------------------------- + +void build_tree_visibility( + const std::vector & nodes, + uint8_t * dst) { + + const int N = (int)nodes.size(); + + // Root only sees itself. + dst[0 * N + 0] = 1; + + for (int i = 1; i < N; i++) { + const int p = nodes[i].parent_idx; + // DFS order guarantees p < i, so row p is already complete. + // Inherit the parent's visibility row, then mark self. + for (int j = 0; j < i; j++) { + dst[(size_t)i * N + j] = dst[(size_t)p * N + j]; + } + dst[(size_t)i * N + i] = 1; + } +} + +// --------------------------------------------------------------------------- +// extract_top_k_logprobs +// --------------------------------------------------------------------------- + +void extract_top_k_logprobs( + const float * logits, + int L, + int V, + int K, + float temp, + float * out_log_probs, + int32_t * out_token_ids) { + + const float inv_t = 1.0f / std::max(1e-6f, temp); + + struct Entry { + float logit; // temperature-scaled + int32_t id; + }; + // Min-heap: smallest scaled logit at top, evicted when a larger one arrives. + auto cmp_min = [](const Entry & a, const Entry & b) { + return a.logit > b.logit; + }; + + for (int i = 0; i < L; i++) { + const float * row = logits + (size_t)i * V; + + std::vector heap; + heap.reserve(K + 1); + + // Single pass: top-K min-heap. Approximate row normalization from + // the retained top-K only to avoid a full-vocab exp/logsumexp pass. + for (int j = 0; j < V; j++) { + const float l = row[j] * inv_t; + + // Maintain top-K min-heap. + if ((int)heap.size() < K) { + heap.push_back({l, (int32_t)j}); + std::push_heap(heap.begin(), heap.end(), cmp_min); + } else if (l > heap.front().logit) { + std::pop_heap(heap.begin(), heap.end(), cmp_min); + heap.back() = {l, (int32_t)j}; + std::push_heap(heap.begin(), heap.end(), cmp_min); + } + } + + // sort_heap with a greater-than comparator (cmp_min) produces descending + // order — same as std::sort with std::greater — so no reversal needed. + std::sort_heap(heap.begin(), heap.end(), cmp_min); + + const float row_best = heap.empty() ? 0.0f : heap[0].logit; + float sum_exp_top = 0.0f; + for (int k = 0; k < K; ++k) { + sum_exp_top += std::exp(heap[k].logit - row_best); + } + const float log_z_approx = row_best + std::log(sum_exp_top); + for (int k = 0; k < K; k++) { + out_log_probs[(size_t)i * K + k] = heap[k].logit - log_z_approx; + out_token_ids[(size_t)i * K + k] = heap[k].id; + } + } +} diff --git a/common/speculative-tree.h b/common/speculative-tree.h new file mode 100644 index 000000000000..66c8e9302fc0 --- /dev/null +++ b/common/speculative-tree.h @@ -0,0 +1,110 @@ +#pragma once + +#include "llama.h" + +#include +#include + +struct llama_ddtree_params { + int budget = 22; // total tree node count cap (including root) + float temp = 1.0f; // temperature for log-prob computation + bool chain_seed = true; // seed heap with greedy chain (recommended) + int block_size = 16; // matches dflash draft block_size + int top_k = 0; // 0 = auto (8 if budget > L else 1) +}; + +struct llama_ddtree_node { + llama_token token_id; + int32_t parent_idx; // -1 for root (index 0) + int32_t depth; // root = 0, root's children = 1, etc. +}; + +struct llama_ddtree { + // nodes[0] is always the root (last committed token). + // nodes[1..N-1] are DFS-ordered tree branches. + std::vector nodes; + + // visibility[i*N + j] = 1 iff node j is an ancestor of node i (inclusive). + // Row i of this matrix is the attention mask row for tree position i. + std::vector visibility; +}; + +// Build a DDTree from per-position top-K log-probabilities. +// +// top_log_probs [L, K] draft top-K log-probabilities, descending per row +// top_token_ids [L, K] matching token ids +// L number of draft positions (depth extent of the tree) +// K top-K width per position +// root_token the root token (last committed token) +// p build parameters +llama_ddtree build_ddtree( + const float * top_log_probs, + const int32_t * top_token_ids, + int L, + int K, + llama_token root_token, + const llama_ddtree_params & p); + +// Walk the tree greedily following the target's per-node argmax (posterior). +// +// Starting at the root (index 0), at each step the walk looks for a child +// whose token_id matches posterior[current_index]. The walk stops when no +// matching child exists. The root is always in the accepted list. +// +// posterior [N] target argmax token at each tree node position +// accepted output: flat node indices of the accepted path (starts with 0) +// next_token output: target argmax at the deepest accepted node (bonus token) +void follow_verified_tree( + const llama_ddtree & tree, + const int32_t * posterior, + std::vector & accepted, + llama_token & next_token); + +// Variant of follow_verified_tree that pulls the picked token at each chain +// step from caller-provided callbacks instead of a precomputed posterior[]. +// Lets callers (server) plug in grammar-aware sampling so the chain only +// accepts tokens the sampler+grammar would have produced. +// +// sample_cb (ud, logits_row_idx) -> picked token at this row (no state advance) +// advance_cb(ud, accepted_token) -> caller must advance its sampler/grammar +// +// advance_cb is invoked every time the chain accepts a child (= the picked +// token matched a child of `current`). It is NOT invoked for the bonus token. +typedef int32_t (*llama_speculative_pick_cb) (void * user_data, int32_t logits_row_idx); +typedef void (*llama_speculative_advance_cb)(void * user_data, llama_token accepted_token); + +void follow_verified_tree_cb( + const llama_ddtree & tree, + llama_speculative_pick_cb sample_cb, + llama_speculative_advance_cb advance_cb, + void * user_data, + std::vector & accepted, + llama_token & next_token); + +// Compute the [N, N] ancestor visibility mask from nodes[].parent_idx. +// dst must point to an N*N uint8 buffer (caller-allocated). +// Row i: dst[i*N + j] = 1 iff node j is an ancestor of i (inclusive). +void build_tree_visibility( + const std::vector & nodes, + uint8_t * dst); + +// Extract per-position top-K log-probabilities from a [L, V] logits matrix. +// +// Uses online logsumexp + a size-K min-heap for a single-pass O(L*V) scan. +// Output rows are sorted descending by log-probability (rank 0 = argmax). +// +// logits [L, V] row-major F32 +// L number of rows (draft positions) +// V vocabulary size +// K top-K width +// temp temperature: logits are divided by temp before softmax +// out_log_probs [L, K] caller-allocated output +// out_token_ids [L, K] caller-allocated output +void extract_top_k_logprobs( + const float * logits, + int L, + int V, + int K, + float temp, + float * out_log_probs, + int32_t * out_token_ids); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a29dc707c3dc..0283ca3a8bec 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -32,6 +32,9 @@ else() add_subdirectory(simple-chat) add_subdirectory(speculative) add_subdirectory(speculative-simple) + # DDTree spec-decode driver; off by default to keep CI fast. + option(LLAMA_BUILD_EXAMPLES_SPECULATIVE_TREE "Build the llama-speculative-tree DDTree example" OFF) + add_subdirectory(speculative-tree) add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) diff --git a/examples/speculative-tree/CMakeLists.txt b/examples/speculative-tree/CMakeLists.txt new file mode 100644 index 000000000000..0599b9bf3aed --- /dev/null +++ b/examples/speculative-tree/CMakeLists.txt @@ -0,0 +1,7 @@ +if (LLAMA_BUILD_EXAMPLES_SPECULATIVE_TREE) + set(TARGET llama-speculative-tree) + add_executable(${TARGET} main.cpp) + install(TARGETS ${TARGET} RUNTIME) + target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) + target_compile_features(${TARGET} PRIVATE cxx_std_17) +endif() diff --git a/examples/speculative-tree/main.cpp b/examples/speculative-tree/main.cpp new file mode 100644 index 000000000000..a80d48493487 --- /dev/null +++ b/examples/speculative-tree/main.cpp @@ -0,0 +1,318 @@ +// examples/speculative-tree/main.cpp — DDTree spec-decode CLI driver. +// +// End-to-end command-line tool that loads a target Qwen3.5-27B model and a +// dflash-draft companion model, tokenizes a prompt, runs DDTree speculative +// decoding, and prints the generated text (and optionally timing statistics). +// +// Usage: +// llama-speculative-tree \ +// -m -md \ +// -p [--gen N] [--ddtree-budget N] [--temp F] \ +// [--n-gpu-layers N] [--n-ctx N] [--bench] [--out-tokens PATH] + +#include "speculative-tree-driver.h" +#include "llama.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Qwen3.5 EOS token id. +static constexpr llama_token QWEN35_EOS = 248045; + +struct cli_params { + std::string model_target; + std::string model_draft; + std::string prompt; + std::string prompt_tokens_path; + std::string out_tokens_path; + int gen = 64; + int ddtree_budget = 22; + bool ddtree_chain = true; + float temp = 1.0f; + int n_gpu_layers = 99; + int n_ctx = 4096; + bool bench = false; +}; + +static void print_usage(const char * prog) { + fprintf(stderr, + "Usage: %s -m PATH -md PATH [-p TEXT | --prompt-tokens PATH]\n" + " [--gen N] [--ddtree-budget N] [--ddtree-no-chain-seed]\n" + " [--temp F] [--n-gpu-layers N] [--n-ctx N]\n" + " [--out-tokens PATH] [--bench]\n", prog); +} + +static cli_params parse_args(int argc, char ** argv) { + cli_params p; + for (int i = 1; i < argc; ++i) { + std::string a = argv[i]; + auto next = [&]() -> std::string { + if (++i >= argc) { fprintf(stderr, "missing value for %s\n", a.c_str()); exit(1); } + return argv[i]; + }; + if (a == "-m") p.model_target = next(); + else if (a == "-md") p.model_draft = next(); + else if (a == "-p") p.prompt = next(); + else if (a == "--prompt-tokens") p.prompt_tokens_path = next(); + else if (a == "--gen") p.gen = std::stoi(next()); + else if (a == "--ddtree-budget") p.ddtree_budget = std::stoi(next()); + else if (a == "--ddtree-no-chain-seed") p.ddtree_chain = false; + else if (a == "--temp") p.temp = std::stof(next()); + else if (a == "--n-gpu-layers") p.n_gpu_layers = std::stoi(next()); + else if (a == "--n-ctx") p.n_ctx = std::stoi(next()); + else if (a == "--out-tokens") p.out_tokens_path = next(); + else if (a == "--bench") p.bench = true; + else { fprintf(stderr, "unknown option: %s\n", a.c_str()); print_usage(argv[0]); exit(1); } + } + if (p.model_target.empty() || p.model_draft.empty()) { + fprintf(stderr, "error: -m and -md are required\n"); + print_usage(argv[0]); + exit(1); + } + if (p.prompt.empty() && p.prompt_tokens_path.empty()) { + fprintf(stderr, "error: -p or --prompt-tokens is required\n"); + print_usage(argv[0]); + exit(1); + } + return p; +} + +// Write int32 LE binary. +static void write_tokens_bin(const std::string & path, const std::vector & toks) { + std::ofstream f(path, std::ios::binary); + for (llama_token t : toks) { + int32_t v = (int32_t)t; + f.write(reinterpret_cast(&v), 4); + } +} + +// Read int32 LE binary. +static std::vector read_tokens_bin(const std::string & path) { + std::ifstream f(path, std::ios::binary); + std::vector toks; + int32_t v; + while (f.read(reinterpret_cast(&v), 4)) { + toks.push_back((llama_token)v); + } + return toks; +} + +int main(int argc, char ** argv) { + cli_params cli = parse_args(argc, argv); + + llama_backend_init(); + + // Load target model. + llama_model_params mparams_tgt = llama_model_default_params(); + mparams_tgt.n_gpu_layers = cli.n_gpu_layers; + llama_model * model_tgt = llama_model_load_from_file(cli.model_target.c_str(), mparams_tgt); + if (!model_tgt) { + fprintf(stderr, "error: failed to load target model: %s\n", cli.model_target.c_str()); + return 1; + } + + // Load draft model. + llama_model_params mparams_dft = llama_model_default_params(); + mparams_dft.n_gpu_layers = cli.n_gpu_layers; + llama_model * model_dft = llama_model_load_from_file(cli.model_draft.c_str(), mparams_dft); + if (!model_dft) { + fprintf(stderr, "error: failed to load draft model: %s\n", cli.model_draft.c_str()); + llama_model_free(model_tgt); + return 1; + } + + // Create target context. + llama_context_params cparams_tgt = llama_context_default_params(); + cparams_tgt.n_ctx = (uint32_t)cli.n_ctx; + cparams_tgt.n_batch = 512; + llama_context * ctx_tgt = llama_init_from_model(model_tgt, cparams_tgt); + if (!ctx_tgt) { + fprintf(stderr, "error: failed to create target context\n"); + return 1; + } + + // Enable hidden capture on target so the draft can read its features. + llama_set_capture_hidden(ctx_tgt, true); + + // Create draft context. + // Draft uses small n_ctx (= DRAFT_CTX_MAX + block_size) since the dflash-draft + // model doesn't have a KV cache (it reuses target features directly). + llama_context_params cparams_dft = llama_context_default_params(); + cparams_dft.n_ctx = 2048 + 16; // DRAFT_CTX_MAX + block_size + cparams_dft.n_batch = 16; // one block per decode + llama_context * ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (!ctx_dft) { + fprintf(stderr, "error: failed to create draft context\n"); + return 1; + } + + // Tokenize prompt. + std::vector prompt_tokens; + if (!cli.prompt_tokens_path.empty()) { + prompt_tokens = read_tokens_bin(cli.prompt_tokens_path); + } else { + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); + const int n_prompt = llama_tokenize(vocab, cli.prompt.c_str(), + (int32_t)cli.prompt.size(), + nullptr, 0, /*add_special=*/true, /*parse_special=*/false); + if (n_prompt < 0) { + fprintf(stderr, "error: tokenize failed\n"); + return 1; + } + prompt_tokens.resize(n_prompt); + llama_tokenize(vocab, cli.prompt.c_str(), (int32_t)cli.prompt.size(), + prompt_tokens.data(), n_prompt, true, false); + } + + if (prompt_tokens.empty()) { + fprintf(stderr, "error: empty prompt\n"); + return 1; + } + + // ── Prompt prefill (chain decode on target) ──────────────────────────────── + // Decode the prompt in one batch to fill the target KV cache and + // populate hidden_capture with the last token's layer features. + { + const int n_prompt = (int)prompt_tokens.size(); + llama_batch batch = llama_batch_init(n_prompt, 0, 1); + batch.n_tokens = n_prompt; + for (int i = 0; i < n_prompt; ++i) { + batch.token[i] = prompt_tokens[i]; + batch.pos[i] = (llama_pos)i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == n_prompt - 1) ? 1 : 0; + } + int ret = llama_decode(ctx_tgt, batch); + llama_batch_free(batch); + if (ret != 0) { + fprintf(stderr, "error: prompt prefill decode failed: %d\n", ret); + return 1; + } + } + + // Greedy sample from last prompt token to get the first generated token. + // This becomes the root token for spec-decode step 0. + llama_token root_token; + { + const float * logits = llama_get_logits_ith(ctx_tgt, 0); + if (!logits) { fprintf(stderr, "error: no logits after prefill\n"); return 1; } + const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_tgt)); + root_token = 0; + float best = logits[0]; + for (int v = 1; v < n_vocab; ++v) { + if (logits[v] > best) { best = logits[v]; root_token = (llama_token)v; } + } + } + + llama_pos committed_pos = (llama_pos)prompt_tokens.size(); + + // ── Build DDTree driver ─────────────────────────────────────────────────── + // Must be done before calling ingest_prompt_capture so the driver is initialized. + llama_ddtree_params ddparams; + ddparams.budget = cli.ddtree_budget; + ddparams.temp = cli.temp; + ddparams.chain_seed = cli.ddtree_chain; + ddparams.block_size = 16; // dflash-draft block size + + llama_speculative_tree_driver * driver = + llama_speculative_tree_driver_init(ctx_tgt, ctx_dft, ddparams); + if (!driver) { + fprintf(stderr, "error: failed to init speculative tree driver\n"); + return 1; + } + + // Ingest the prompt prefill capture into the driver's ring buffer. + llama_speculative_tree_driver_ingest_prompt_capture(driver, (int32_t)prompt_tokens.size()); + + // ── Generation loop ─────────────────────────────────────────────────────── + std::vector generated; + generated.reserve((size_t)cli.gen + 16); + + int64_t total_steps = 0; + int64_t total_accept = 0; // sum of commit_n per step (for accept rate) + + auto t_start = std::chrono::high_resolution_clock::now(); + + const llama_vocab * target_vocab = llama_model_get_vocab(model_tgt); + + while ((int)generated.size() < cli.gen && root_token != QWEN35_EOS) { + std::vector accepted = + llama_speculative_tree_driver_step(driver, root_token, committed_pos); + + if (accepted.empty()) { + fprintf(stderr, "error: driver step returned empty result\n"); + break; + } + + // accepted[0] = root_token (echoed) + // accepted[1..n-2] = newly accepted draft tokens + // accepted[n-1] = bonus token (next root) + + // Commit all tokens (root + draft accepted); hold bonus as new root. + // accepted: [root, draft_1, ..., draft_k, bonus] + // n_new = accept_depth = number of KV positions consumed this step. + const int n_new = (int)accepted.size() - 1; // excludes bonus + root_token = accepted.back(); // bonus = new root for next step + committed_pos += (llama_pos)n_new; // advance past all committed slots + + // The first accepted token == root_token == the one we greedy-sampled from prompt + // OR was returned as bonus from the prior step. Either way, we count it. + for (int i = 0; i < n_new && (int)generated.size() < cli.gen; ++i) { + generated.push_back(accepted[i]); + if (accepted[i] == QWEN35_EOS) { + root_token = QWEN35_EOS; + break; + } + } + + total_steps++; + total_accept += n_new; + + // Stream token text to stdout. + if (!cli.bench) { + for (int i = 0; i < n_new; ++i) { + char buf[256] = {0}; + int len = llama_token_to_piece(target_vocab, accepted[i], buf, sizeof(buf)-1, + /*lstrip=*/0, /*special=*/false); + if (len > 0) { buf[len] = '\0'; fputs(buf, stdout); fflush(stdout); } + } + } + } + + auto t_end = std::chrono::high_resolution_clock::now(); + double elapsed_s = std::chrono::duration(t_end - t_start).count(); + + printf("\n"); + + if (cli.bench) { + const double tps = (double)generated.size() / elapsed_s; + const double accept_rate = total_steps > 0 + ? (double)total_accept / (double)total_steps : 0.0; + printf("[bench] generated=%d tokens, elapsed=%.2fs, tokens/s=%.1f, " + "accept_rate=%.2f tokens/step, steps=%lld\n", + (int)generated.size(), elapsed_s, tps, accept_rate, + (long long)total_steps); + } + + if (!cli.out_tokens_path.empty()) { + write_tokens_bin(cli.out_tokens_path, generated); + } + + llama_speculative_tree_driver_free(driver); + llama_free(ctx_dft); + llama_free(ctx_tgt); + llama_model_free(model_dft); + llama_model_free(model_tgt); + llama_backend_free(); + + return 0; +} diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 11d3e8a81671..7b4c647d9753 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -429,7 +429,8 @@ extern "C" { GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_Q1_0 = 41, - GGML_TYPE_COUNT = 42, + GGML_TYPE_TQ3_0 = 42, // TurboQuant 3.5 bpv (3-bit Lloyd-Max + FWHT rotation) + GGML_TYPE_COUNT = 43, }; // precision @@ -577,6 +578,8 @@ extern "C" { GGML_OP_GLU, + GGML_OP_TURBO_WHT, // FWHT rotation for TurboQuant KV cache + GGML_OP_COUNT, }; @@ -2351,6 +2354,14 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * sinks); + // TurboQuant FWHT rotation. direction: 0 = forward, 1 = inverse. + // Applies signs1 -> FWHT -> signs2 (forward) or signs2 -> FWHT -> signs1 (inverse). + // Used for KV cache rotation in TurboQuant quantization types (TQ3_0). + GGML_API struct ggml_tensor * ggml_turbo_wht( + struct ggml_context * ctx, + struct ggml_tensor * a, + int direction); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, @@ -2365,6 +2376,31 @@ extern "C" { struct ggml_tensor * sx, struct ggml_tensor * c); + // dflash extension: tree-mode ssm_conv for DDTree-style spec + // decoding. parent_ids is an int32 tensor of shape [n_tokens, n_seqs] + // where entry [t, s] is the new-token index of t's tree parent, or -1 if + // t's parent is "before the block" (walks into the old conv state region). + // The CUDA kernel walks the parent chain K-1 times per token to build the + // correct conv window, so siblings get their own parent lineage instead + // of sharing the DFS-neighbour window. Mirrors SGLang's + // causal_conv1d_triton HAS_EAGLE_TREE_CUSTOM_ATTN_MASK path. + GGML_API struct ggml_tensor * ggml_ssm_conv_tree( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids); + + // dflash extension: tree-mode ssm_conv that also writes each token's + // (K-1)-element post-state to persist_inter so the driver can roll the + // live conv state back to the accepted DFS node. persist_inter must be + // contiguous F32 with shape [K-1, d_inner, n_tokens, n_seqs] (K-1 fastest). + GGML_API struct ggml_tensor * ggml_ssm_conv_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter); + GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, struct ggml_tensor * s, @@ -2485,6 +2521,43 @@ extern "C" { struct ggml_tensor * beta, struct ggml_tensor * state); + // dflash extension: tree-mode gated delta net for DDTree-style + // speculative decoding verify. `parent_ids` is an int32 tensor of shape + // [n_tokens, n_seqs] where entry [t, s] is the index within sequence s of + // the parent token in the DFS-flattened tree (or -1 for a root-level + // node). At each token step t > 0 in the recurrence, if parent_ids[t] is + // not (t - 1), the kernel reloads the recurrent state from the + // intermediate-states region at parent_ids[t] instead of continuing + // sequentially. This lets one verify forward pass correctly process + // multiple tree branches (siblings) without cross-contamination. + GGML_API struct ggml_tensor * ggml_gated_delta_net_tree( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * parent_ids); + + // dflash extension: tree-mode with direct intermediate-state writes to a + // persistent external buffer. Identical to ggml_gated_delta_net_tree but + // the per-token intermediate states are written to `persist_inter->data` + // (f32 or f16, [S_v, S_v, H, n_tokens, n_seqs], contiguous) instead of + // the default internal region of the result tensor. Eliminates a + // downstream ggml_cpy into the persistent cache buffer, saving ~5-10 ms + // per verify step on a 27B hybrid target. + GGML_API struct ggml_tensor * ggml_gated_delta_net_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f05683b44cd9..d47c976968f1 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -121,6 +121,9 @@ typedef sycl::half2 ggml_half2; #define QI8_1 (QK8_1 / (4 * QR8_1)) #define QR8_1 1 +#define QI_TQ3_0 (QK_TQ3_0 / (4 * QR_TQ3_0)) +#define QR_TQ3_0 2 + #define QI2_K (QK_K / (4*QR2_K)) #define QR2_K 4 @@ -277,6 +280,18 @@ typedef struct { } block_tq2_0; static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); +// TurboQuant 3.5 bpv — 3-bit Lloyd-Max codebook with FWHT rotation. +// 4 consecutive blocks (128 elements) share a group norm. +// Dequant: centroid[low2 | (hi1 << 2)] * norm +#define QK_TQ3_0 32 +#define QK_TQ3_0_GROUP 128 +typedef struct { + ggml_half norm; // 2 bytes + uint8_t qs[QK_TQ3_0 / 4]; // 8 bytes: 2-bit low indices, packed 4 per byte + uint8_t signs[QK_TQ3_0 / 8]; // 4 bytes: 1-bit high index per element +} block_tq3_0; // 14 bytes per 32 elements = 3.5 bpv +static_assert(sizeof(block_tq3_0) == 14, "wrong tq3_0 block size"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b3eb5b5ce65..fb5705f420fb 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2009,6 +2009,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_glu(params, tensor); } break; + case GGML_OP_TURBO_WHT: + { + // CUDA-only op; CPU fallback is identity (no rotation) + GGML_ASSERT(false && "TURBO_WHT is CUDA-only, use CUDA backend"); + } break; case GGML_OP_GET_REL_POS: { ggml_compute_forward_get_rel_pos(params, tensor); @@ -2273,9 +2278,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; default: - GGML_ABORT("fatal error"); + break; } break; + case GGML_OP_TURBO_WHT: + { + n_tasks = n_threads; + } break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: case GGML_OP_DIV: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0f..51a71b1a5e7c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11,6 +11,7 @@ #include #include #include +#include // ggml_compute_forward_dup @@ -9254,6 +9255,8 @@ static void ggml_compute_forward_ssm_conv_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // conv_x const ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const ggml_tensor * src2 = dst->src[2]; // parent_ids, optional tree mode + const ggml_tensor * src3 = dst->src[3]; // persist conv state, optional const int ith = params->ith; const int nth = params->nth; @@ -9269,6 +9272,17 @@ static void ggml_compute_forward_ssm_conv_f32( GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + if (src2 != nullptr) { + GGML_ASSERT(src2->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src2)); + GGML_ASSERT(ggml_nelements(src2) == n_t * n_s); + } + if (src3 != nullptr) { + GGML_ASSERT(src3->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src3)); + GGML_ASSERT(ggml_nelements(src3) >= (int64_t)(nc - 1) * nr * n_t * n_s); + } + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -9278,25 +9292,51 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir = ir1 - ir0; for (int i3 = 0; i3 < n_s; ++i3) { + const int32_t * parent_ids = src2 ? (const int32_t *) src2->data + (int64_t)i3 * n_t : nullptr; + float * persist_seq = src3 ? (float *) src3->data + (int64_t)i3 * n_t * nr * (nc - 1) : nullptr; + for (int i2 = 0; i2 < n_t; ++i2) { - // {d_conv - 1 + n_t, d_inner, n_seqs} - // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + int ancestors[GGML_MAX_DIMS] = {}; + if (parent_ids != nullptr) { + GGML_ASSERT(nc <= GGML_MAX_DIMS); + ancestors[nc - 1] = i2; + for (int k = nc - 2; k >= 0; --k) { + const int prev = ancestors[k + 1]; + ancestors[k] = prev >= 0 ? parent_ids[prev] : prev - 1; + } + } + // TODO: transpose the output for smaller strides for big batches? // d_inner for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision float sumf = 0.0f; + float window[GGML_MAX_DIMS] = {}; // d_conv for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + const int sx_slot = parent_ids ? (nc - 1 + ancestors[i0]) : (i2 + i0); + const float s = *(const float *) ((const char *) src0->data + + (int64_t)sx_slot * src0->nb[0] + + (int64_t)(ir0 + i1) * src0->nb[1] + + (int64_t)i3 * src0->nb[2]); + const float c = *(const float *) ((const char *) src1->data + + (int64_t)i0 * src1->nb[0] + + (int64_t)(ir0 + i1) * src1->nb[1]); + window[i0] = s; + sumf += s * c; } x[i1] = sumf; + + if (persist_seq != nullptr) { + float * persist_token = persist_seq + ((int64_t)i2 * nr + (ir0 + i1)) * (nc - 1); + for (int i0 = 0; i0 < nc - 1; ++i0) { + persist_token[i0] = window[i0 + 1]; + } + } } } } @@ -10439,6 +10479,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( ggml_tensor * src_g = dst->src[3]; ggml_tensor * src_beta = dst->src[4]; ggml_tensor * src_state = dst->src[5]; + ggml_tensor * src_parent = dst->src[6]; + ggml_tensor * src_persist = dst->src[7]; const int64_t S_v = src_v->ne[0]; const int64_t H = src_v->ne[1]; @@ -10454,6 +10496,16 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); GGML_ASSERT(src_beta->ne[0] == 1); + if (src_parent != nullptr) { + GGML_ASSERT(src_parent->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src_parent)); + GGML_ASSERT(ggml_nelements(src_parent) == n_tokens * n_seqs); + } + if (src_persist != nullptr) { + GGML_ASSERT(src_persist->type == GGML_TYPE_F32 || src_persist->type == GGML_TYPE_F16); + GGML_ASSERT(ggml_is_contiguous(src_persist)); + GGML_ASSERT(ggml_nelements(src_persist) >= S_v * S_v * H * n_tokens * n_seqs); + } GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); @@ -10477,8 +10529,10 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // attn_scores: S_v * H * n_tokens * n_seqs floats // new_states: S_v * S_v * H * n_seqs floats const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_elems = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + float * inter_out_base = state_out_base + state_elems; const float * state_in_base = (const float *)src_state->data; @@ -10505,10 +10559,45 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + const int32_t * parent_ids = src_parent ? (const int32_t *) src_parent->data + iv3 * n_tokens : nullptr; + auto load_inter = [&](int64_t token, int64_t elem) -> float { + if (src_persist != nullptr) { + const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem; + if (src_persist->type == GGML_TYPE_F32) { + return ((const float *) src_persist->data)[off]; + } + return GGML_FP16_TO_FP32(((const ggml_fp16_t *) src_persist->data)[off]); + } + return inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem]; + }; + auto store_inter = [&](int64_t token, int64_t elem, float value) { + if (src_persist != nullptr) { + const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem; + if (src_persist->type == GGML_TYPE_F32) { + ((float *) src_persist->data)[off] = value; + } else { + ((ggml_fp16_t *) src_persist->data)[off] = GGML_FP32_TO_FP16(value); + } + } else { + inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem] = value; + } + }; + // attn output pointer for first token of this (head, seq) float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; for (int64_t t = 0; t < n_tokens; t++) { + if (parent_ids != nullptr && t > 0) { + const int32_t parent_t = parent_ids[t]; + if (parent_t < 0) { + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + } else if (parent_t != t - 1) { + for (int64_t elem = 0; elem < S_v * S_v; ++elem) { + s_out[elem] = load_inter(parent_t, elem); + } + } + } + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); @@ -10552,6 +10641,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (parent_ids != nullptr || src_persist != nullptr) { + for (int64_t elem = 0; elem < S_v * S_v; ++elem) { + store_inter(t, elem, s_out[elem]); + } + } } } } diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b1079..6ed2f61fedb4 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -120,7 +120,8 @@ if (CUDAToolkit_FOUND) template-instances/fattn-vec-instance-f16-f16.cu template-instances/fattn-vec-instance-q4_0-q4_0.cu template-instances/fattn-vec-instance-q8_0-q8_0.cu - template-instances/fattn-vec-instance-bf16-bf16.cu) + template-instances/fattn-vec-instance-bf16-bf16.cu + template-instances/fattn-vec-instance-tq3_0-tq3_0.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index 7697c292dd6f..00e5c3bc8979 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -2,6 +2,7 @@ #include "ggml-common.h" #include "convert.cuh" +#include "tq3-quant.cuh" static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { if (x <= val[0]) return 0; @@ -211,6 +212,40 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); } +static __device__ void quantize_f32_tq3_0_group(const float * __restrict__ src, block_tq3_0 * __restrict__ dst) { + float x[128]; + float norm_sq = 0.0f; + for (int j = 0; j < 128; j++) { + x[j] = src[j]; + norm_sq += x[j] * x[j]; + } + + float grp_norm = sqrtf(norm_sq); + float inv_norm = grp_norm > 1e-10f ? 1.0f / grp_norm : 0.0f; + for (int j = 0; j < 128; j++) x[j] *= inv_norm; + + tq3_rotate_forward(x); + + float recon_norm_sq = 0.0f; + for (int b = 0; b < 4; b++) { + const int off = b * QK_TQ3_0; + for (int j = 0; j < QK_TQ3_0 / 4; j++) dst[b].qs[j] = 0; + for (int j = 0; j < QK_TQ3_0 / 8; j++) dst[b].signs[j] = 0; + for (int j = 0; j < QK_TQ3_0; j++) { + uint8_t idx = tq3_find_nearest(x[off + j]); + dst[b].qs[j/4] |= (idx & 0x3) << ((j%4) * 2); + if (idx & 0x4) dst[b].signs[j/8] |= (1 << (j%8)); + float c = d_tq3_centroids[idx]; + recon_norm_sq += c * c; + } + } + + float recon_norm = sqrtf(recon_norm_sq); + float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + half h_norm = __float2half(corrected_norm); + for (int b = 0; b < 4; b++) dst[b].norm = h_norm; +} + template static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) { *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d208acf2d5f0..0fd719f1249b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -372,6 +372,50 @@ static void ggml_cpy_f32_iq4_nl_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static __global__ void cpy_f32_tq3_0_kernel( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + + const int64_t group = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + const int64_t i = group * QK_TQ3_0_GROUP; + + if (i >= ne) return; + + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int64_t x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + const int64_t i13 = i / (ne10 * ne11 * ne12); + const int64_t i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int64_t i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int64_t dst_offset = (i10 / QK_TQ3_0) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + const float * src = (const float *)(cx + x_offset); + block_tq3_0 * dst_blocks = (block_tq3_0 *)(cdst + dst_offset); + + quantize_f32_tq3_0_group(src, dst_blocks); +} + +static void ggml_cpy_f32_tq3_0_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK_TQ3_0_GROUP == 0); + const int64_t num_groups = ne / QK_TQ3_0_GROUP; + GGML_ASSERT(num_groups < UINT_MAX); + cpy_f32_tq3_0_kernel<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -465,6 +509,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_TQ3_0) { + ggml_cpy_f32_tq3_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc0..1801bd0489bc 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,4 +1,5 @@ #include "common.cuh" +#include "tq3-quant.cuh" static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/fattn-chunked.cu b/ggml/src/ggml-cuda/fattn-chunked.cu new file mode 100644 index 000000000000..3240b5648cb0 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-chunked.cu @@ -0,0 +1,354 @@ +// Chunked-prefill FlashAttention driver. +// +// Ported from the non-TBQ path of ggml_cuda_tbq_chunked_prefill. Uses custom +// strided dequant kernels (see fattn-chunked.cuh) so K/V chunks are extracted +// per iteration at chunk_len × nh_kv × D granularity — not full-tensor up +// front. This is the critical perf fix: for a 64K prompt × 48 layers × 350 +// prefill steps, we were previously dequanting 64K × nh_kv × D each call. +// +// Structure: +// for each sequence: +// copy Q to contiguous fp32 +// init accumulators (O_acc, l_acc, m_acc) +// for each kv_chunk: +// strided dequant K chunk → k_tmp (chunk_len × nh_kv × D fp32) +// strided dequant V chunk → v_tmp (same shape) +// for each q_batch: +// if causal-skip: continue +// S = scale * Q @ K^T (cuBLAS SGEMM strided batched) +// online softmax update (kernel updates m/l/O) +// O_acc += P @ V (cuBLAS SGEMM strided batched) +// finalize (divide O by l, transpose to ggml layout) +// skip cudaStreamSynchronize during CUDA graph capture + +#include "fattn-chunked.cuh" + +#include +#include +#include + +// Q batch: default to the prefill's own nq (DFlash prefill uses ubatch=192). +// Override via env. 1024 is way too large — inflates S buffer by 5× with no +// perf upside, which in turn collapses the adaptive chunk size. +static int chunked_q_batch_env(int64_t nq) { + const char * e = getenv("DFLASH27B_CHUNKED_Q_BATCH"); + if (e) { + int v = atoi(e); + if (v >= 1) return v; + } + return (int)std::min(nq, (int64_t)256); +} + +// Chunk sizing: prefer a fixed 4096 by default. The original VRAM-adaptive +// formula runs AFTER we've already claimed persistent scratch, so cudaMemGetInfo +// under-reports and the formula collapses to CHUNKED_PF_MIN=256 → thousands +// of chunks per fattn call → pathological. Override via env if needed. +static int chunked_chunk_env(int fallback_from_vram) { + const char * e = getenv("DFLASH27B_CHUNKED_CHUNK"); + if (e) { + int v = atoi(e); + if (v >= 1) { + int p = 1; + while (p < v) p <<= 1; + return std::min(std::max(p, CHUNKED_PF_MIN), CHUNKED_PF_MAX); + } + } + return std::max(fallback_from_vram, 4096); +} + +struct chunked_scratch { + float * O_acc = nullptr; size_t O_bytes = 0; + float * l_acc = nullptr; size_t l_bytes = 0; + float * m_acc = nullptr; size_t m_bytes = 0; + float * S = nullptr; size_t S_bytes = 0; + float * k_tmp = nullptr; size_t k_bytes = 0; + float * v_tmp = nullptr; size_t v_bytes = 0; + float * Q_f32 = nullptr; size_t Q_bytes = 0; +}; +static chunked_scratch g_chunked_bufs[GGML_CUDA_MAX_DEVICES]; + +static bool try_ensure_buf(float ** p, size_t * cur_bytes, size_t need_bytes) { + if (need_bytes <= *cur_bytes && *p != nullptr) return true; + if (*p != nullptr) CUDA_CHECK(cudaFree(*p)); + *p = nullptr; + + const cudaError_t err = cudaMalloc(p, need_bytes); + if (err != cudaSuccess) { + // Clear the sticky CUDA error so the caller can retry with a smaller + // chunk instead of aborting the whole process. + (void) cudaGetLastError(); + *p = nullptr; + *cur_bytes = 0; + return false; + } + + *cur_bytes = need_bytes; + return true; +} + +static void free_buf(float ** p, size_t * cur_bytes) { + if (*p != nullptr) CUDA_CHECK(cudaFree(*p)); + *p = nullptr; + *cur_bytes = 0; +} + +void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + cudaStream_t stream = ctx.stream(); + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + const int64_t D = Q->ne[0]; + const int64_t nq = Q->ne[1]; + const int64_t nh_q = Q->ne[2]; + const int64_t nh_kv = K->ne[2]; + const int64_t n_seq = Q->ne[3]; + const int64_t nkv = K->ne[1]; + const int64_t gqa = nh_q / nh_kv; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(nh_q % nh_kv == 0); + + float scale = 1.0f; + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + + const int q_batch_size = (int)std::min((int64_t)chunked_q_batch_env(nq), nq); + + size_t free_bytes = 0, total_bytes = 0; + CUDA_CHECK(cudaMemGetInfo(&free_bytes, &total_bytes)); + const int vram_chunk = chunked_pf_compute_chunk_size(free_bytes, nh_q, nh_kv, q_batch_size, D); + int tbq_chunk = chunked_chunk_env(vram_chunk); + + const int device = ctx.device; + GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES); + chunked_scratch & sc = g_chunked_bufs[device]; + + const size_t O_bytes = (size_t)nh_q * nq * D * sizeof(float); + const size_t l_bytes = (size_t)nh_q * nq * sizeof(float); + const size_t m_bytes = (size_t)nh_q * nq * sizeof(float); + const size_t Q_f32_bytes = (size_t)nh_q * nq * D * sizeof(float); + + float * O_acc = nullptr; + float * l_acc = nullptr; + float * m_acc = nullptr; + float * S = nullptr; + float * k_tmp = nullptr; + float * v_tmp = nullptr; + float * Q_f32 = nullptr; + + const int requested_tbq_chunk = tbq_chunk; + for (;;) { + const size_t S_bytes = (size_t)nh_q * q_batch_size * tbq_chunk * sizeof(float); + // Per-chunk K/V dequant: [nh_kv, tbq_chunk, D] fp32. The final chunk may + // be shorter; we still size the buffer for the max and only write chunk_len. + const size_t kv_bytes = (size_t)nh_kv * tbq_chunk * D * sizeof(float); + + const bool ok = + try_ensure_buf(&sc.O_acc, &sc.O_bytes, O_bytes) && + try_ensure_buf(&sc.l_acc, &sc.l_bytes, l_bytes) && + try_ensure_buf(&sc.m_acc, &sc.m_bytes, m_bytes) && + try_ensure_buf(&sc.S, &sc.S_bytes, S_bytes) && + try_ensure_buf(&sc.k_tmp, &sc.k_bytes, kv_bytes) && + try_ensure_buf(&sc.v_tmp, &sc.v_bytes, kv_bytes) && + try_ensure_buf(&sc.Q_f32, &sc.Q_bytes, Q_f32_bytes); + + if (ok) { + O_acc = sc.O_acc; + l_acc = sc.l_acc; + m_acc = sc.m_acc; + S = sc.S; + k_tmp = sc.k_tmp; + v_tmp = sc.v_tmp; + Q_f32 = sc.Q_f32; + break; + } + + if (tbq_chunk <= CHUNKED_PF_MIN) { + GGML_ABORT("chunked prefill: failed to allocate scratch buffers"); + } + + tbq_chunk >>= 1; + // Release chunk-dependent scratch allocated for the failed, larger + // chunk. Otherwise retry can keep the old large buffers alive and fail + // again despite the smaller chunk size. + free_buf(&sc.S, &sc.S_bytes); + free_buf(&sc.k_tmp, &sc.k_bytes); + free_buf(&sc.v_tmp, &sc.v_bytes); + GGML_LOG_WARN("chunked prefill: scratch allocation failed, retrying with chunk=%d (requested=%d)\n", + tbq_chunk, requested_tbq_chunk); + } + + cublasHandle_t cublas_handle = ctx.cublas_handle(); + CUBLAS_CHECK(cublasSetStream(cublas_handle, stream)); + + const int64_t mask_stride = mask ? (int64_t)(mask->nb[1] / sizeof(half)) : 0; + + for (int64_t seq = 0; seq < n_seq; seq++) { + const char * Q_data_seq = (const char *)Q->data + seq * Q->nb[3]; + const char * K_data_seq = (const char *)K->data + seq * K->nb[3]; + const char * V_data_seq = (const char *)V->data + seq * V->nb[3]; + const half * mask_seq = mask ? (const half *)((const char *)mask->data + seq * mask->nb[3]) : nullptr; + float * dst_seq = (float *)((char *)dst->data + seq * dst->nb[3]); + + if (Q->nb[0] == sizeof(float) && + Q->nb[1] == (size_t)D * sizeof(float) && + Q->nb[2] == (size_t)D * nq * sizeof(float)) { + CUDA_CHECK(cudaMemcpyAsync(Q_f32, Q_data_seq, + (size_t)nh_q * nq * D * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } else { + GGML_ASSERT(Q->nb[0] == sizeof(float)); + for (int64_t h = 0; h < nh_q; h++) { + for (int64_t q = 0; q < nq; q++) { + const char * src_ptr = Q_data_seq + h * Q->nb[2] + q * Q->nb[1]; + float * dst_ptr = Q_f32 + h * nq * D + q * D; + CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, D * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } + + // Init accumulators. + { + const int64_t nq_heads = nh_q * nq; + const int threads_init = (int)std::min(D, (int64_t)1024); + k_chunked_attn_init<<<(int)nq_heads, threads_init, 0, stream>>>( + O_acc, l_acc, m_acc, nq_heads, D); + CUDA_CHECK(cudaGetLastError()); + } + + for (int64_t kv_start = 0; kv_start < nkv; kv_start += tbq_chunk) { + const int64_t chunk_len = (kv_start + tbq_chunk <= nkv) ? tbq_chunk : (nkv - kv_start); + + // Strided dequant K and V chunk to contiguous [nh_kv, chunk_len, D] fp32. + if (!chunked_dequant_launch(K->type, K_data_seq, k_tmp, + D, chunk_len, nh_kv, K->nb[1], K->nb[2], + kv_start, stream)) { + GGML_ABORT("chunked prefill: unsupported K type"); + } + if (!chunked_dequant_launch(V->type, V_data_seq, v_tmp, + D, chunk_len, nh_kv, V->nb[1], V->nb[2], + kv_start, stream)) { + GGML_ABORT("chunked prefill: unsupported V type"); + } + CUDA_CHECK(cudaGetLastError()); + + for (int64_t q_start = 0; q_start < nq; q_start += q_batch_size) { + const int64_t q_len = std::min((int64_t)q_batch_size, nq - q_start); + if ((nkv - nq) + q_start + q_len <= kv_start) continue; + + // S = scale * Q @ K^T. k_tmp layout is [nh_kv, chunk_len, D] + // with per-head stride chunk_len*D (compact — NOT nkv*D). + { + const float alpha_v = scale; + const float beta_v = 0.0f; + const long long stride_A = (long long)chunk_len * D; + const long long stride_B = (long long)nq * D; + const long long stride_C = (long long)q_len * tbq_chunk; + + if (gqa == 1) { + CUBLAS_CHECK(cublasSgemmStridedBatched( + cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + (int)chunk_len, (int)q_len, (int)D, + &alpha_v, + k_tmp, (int)D, stride_A, + Q_f32 + q_start * D, (int)D, stride_B, + &beta_v, + S, (int)tbq_chunk, stride_C, + (int)nh_q)); + } else { + for (int64_t kv_h = 0; kv_h < nh_kv; kv_h++) { + const float * k_head = k_tmp + kv_h * chunk_len * D; + const float * q_ptr = Q_f32 + kv_h * gqa * nq * D + q_start * D; + float * s_start = S + kv_h * gqa * (long long)q_len * tbq_chunk; + CUBLAS_CHECK(cublasSgemmStridedBatched( + cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + (int)chunk_len, (int)q_len, (int)D, + &alpha_v, + k_head, (int)D, 0LL, + q_ptr, (int)D, (long long)nq * D, + &beta_v, + s_start, (int)tbq_chunk, (long long)q_len * tbq_chunk, + (int)gqa)); + } + } + } + + // Online softmax. + { + const int64_t nq_batch_heads = nh_q * q_len; + const int chunk_len_int = (int)chunk_len; + const int chunk_pad = chunked_pf_next_pow2(chunk_len_int); + const int threads_sm = (chunk_pad < 1024) ? chunk_pad : 1024; + const size_t smem = ((size_t)chunk_pad + 2) * sizeof(float); + k_chunked_softmax_update<<<(int)nq_batch_heads, threads_sm, smem, stream>>>( + S, O_acc, l_acc, m_acc, + chunk_len_int, chunk_pad, + D, nq, q_len, q_start, + nh_q, + mask_seq, mask_stride, kv_start, tbq_chunk); + CUDA_CHECK(cudaGetLastError()); + } + + // O_acc += P @ V. v_tmp layout [nh_kv, chunk_len, D], per-head stride chunk_len*D. + { + const float alpha_v = 1.0f; + const float beta_v = 1.0f; + const long long stride_A = (long long)chunk_len * D; + const long long stride_B = (long long)q_len * tbq_chunk; + const long long stride_C = (long long)nq * D; + + if (gqa == 1) { + CUBLAS_CHECK(cublasSgemmStridedBatched( + cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + (int)D, (int)q_len, (int)chunk_len, + &alpha_v, + v_tmp, (int)D, stride_A, + S, (int)tbq_chunk, stride_B, + &beta_v, + O_acc + q_start * D, (int)D, stride_C, + (int)nh_q)); + } else { + for (int64_t kv_h = 0; kv_h < nh_kv; kv_h++) { + const float * v_head = v_tmp + kv_h * chunk_len * D; + const float * p_start = S + kv_h * gqa * (long long)q_len * tbq_chunk; + float * o_start = O_acc + kv_h * gqa * (long long)nq * D + q_start * D; + CUBLAS_CHECK(cublasSgemmStridedBatched( + cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + (int)D, (int)q_len, (int)chunk_len, + &alpha_v, + v_head, (int)D, 0LL, + p_start, (int)tbq_chunk, (long long)q_len * tbq_chunk, + &beta_v, + o_start, (int)D, (long long)nq * D, + (int)gqa)); + } + } + } + } + } + + // Finalize. + { + const int64_t nq_heads = nh_q * nq; + const int threads_fin = 128; + const dim3 grid_fin((int)nq_heads, (int)((D + threads_fin - 1) / threads_fin)); + k_chunked_attn_finalize<<>>( + O_acc, l_acc, dst_seq, nq, nh_q, D); + CUDA_CHECK(cudaGetLastError()); + } + } + + cudaStreamCaptureStatus status; + cudaError_t err = cudaStreamIsCapturing(stream, &status); + if (err == cudaSuccess && status == cudaStreamCaptureStatusNone) { + CUDA_CHECK(cudaStreamSynchronize(stream)); + } +} diff --git a/ggml/src/ggml-cuda/fattn-chunked.cuh b/ggml/src/ggml-cuda/fattn-chunked.cuh new file mode 100644 index 000000000000..6d220f15da1c --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-chunked.cuh @@ -0,0 +1,398 @@ +#pragma once + +// Chunked-prefill FlashAttention with online softmax and O(CHUNK) temp memory. +// +// Ported from Luce-Org-adjacent project dusterbloom-turboquant (fattn.cu:750-). +// Strips TBQ-specific dequant kernels and Walsh-Hadamard rotations; retains the +// KV-type-agnostic online softmax kernels and the generic chunked driver. +// +// Three kernels (identical to turboquant): +// k_chunked_attn_init — init O_acc/l_acc/m_acc +// k_chunked_softmax_update — online softmax over one KV chunk +// k_chunked_attn_finalize — normalize O by l, transpose to ggml layout +// +// The driver (ggml_cuda_flash_attn_ext_chunked) is provided in fattn-chunked.cu. +// +// Use when: Q->ne[1] > 1 (prefill) AND K->ne[1] > CHUNKED_PREFILL_THRESHOLD. +// KV must be dequantizable to f32 (f16, bf16, Q4_0, Q8_0). cuBLAS SGEMM is used +// for Q@K^T and P@V on FP32 temp buffers. + +#include "common.cuh" +#include "fattn-common.cuh" + +// Adaptive chunk-size configuration +static constexpr int CHUNKED_PF_MAX = 8192; +static constexpr int CHUNKED_PF_MIN = 256; + +// Compute largest power-of-2 chunk size that fits in available GPU memory. +// Accounts for S buffer (nh_q * nq * chunk * 4) and K/V temp buffers +// (2 * D * chunk * nh_kv * 4). Reserves 512 MB headroom. +static inline int chunked_pf_compute_chunk_size(size_t free_bytes, + int64_t nh_q, + int64_t nh_kv, + int64_t nq, + int64_t D) { + const size_t headroom = 512ULL * 1024 * 1024; + const size_t usable = (free_bytes > headroom) ? (free_bytes - headroom) : 0; + const size_t per_chunk_token = (size_t)(nh_q * nq + 2 * D * nh_kv) * sizeof(float); + if (per_chunk_token == 0) return CHUNKED_PF_MIN; + size_t max_chunk = usable / per_chunk_token; + int chunk = CHUNKED_PF_MAX; + while (chunk > CHUNKED_PF_MIN && (size_t)chunk > max_chunk) { + chunk >>= 1; + } + return chunk; +} + +// Next power of 2 >= n (n >= 1). +static inline int chunked_pf_next_pow2(int n) { + int p = 1; + while (p < n) p <<= 1; + return p; +} + +// Kernel 1: Initialize accumulators. +// O_acc = 0, l_acc = 0, m_acc = -inf +// Grid: (nq_heads, 1, 1), blockDim.x = min(D, 1024). +static __global__ void k_chunked_attn_init( + float * __restrict__ O_acc, + float * __restrict__ l_acc, + float * __restrict__ m_acc, + const int64_t nq_heads, + const int64_t D) { + const int64_t hq = (int64_t)blockIdx.x; + const int tid = (int)threadIdx.x; + const int bdx = (int)blockDim.x; + if (hq >= nq_heads) return; + if (tid == 0) { + l_acc[hq] = 0.0f; + m_acc[hq] = -INFINITY; + } + for (int64_t d = tid; d < D; d += bdx) { + O_acc[hq * D + d] = 0.0f; + } +} + +// Kernel 2: Online softmax update. +// Processes one (head, query) pair per thread block. +// Shared memory layout: sm[0..chunk_pad-1] scores/exp, sm[chunk_pad..chunk_pad+1] alpha/beta. +// blockDim.x = min(chunk_pad, 1024). Binary-tree max-then-sum reduction. +// Rescales O_acc by alpha, writes P = beta * exp(S-m_chunk) back into S for the P@V GEMM. +static __global__ void k_chunked_softmax_update( + float * __restrict__ S, // [nh_q, q_batch, s_stride] — scores → P after kernel + float * __restrict__ O_acc, // [nh_q, nq_total, D] + float * __restrict__ l_acc, // [nh_q, nq_total] + float * __restrict__ m_acc, // [nh_q, nq_total] + const int chunk_len, + const int chunk_pad, + const int64_t D, + const int64_t nq_total, + const int64_t q_batch, + const int64_t q_start, + const int64_t nh_q, + const half * __restrict__ mask, // [nq_total, nkv_total] f16 mask, or nullptr + const int64_t mask_stride, + const int64_t kv_start, + const int s_stride) { + const int64_t hq_idx = (int64_t)blockIdx.x; + const int64_t head = hq_idx / q_batch; + const int64_t q_local = hq_idx % q_batch; + const int64_t q_pos = q_start + q_local; + if (head >= nh_q) return; + + const int tid = (int)threadIdx.x; + const int bdx = (int)blockDim.x; + extern __shared__ float sm[]; // (chunk_pad + 2) floats + + const int64_t s_base = head * q_batch * s_stride + q_local * s_stride; + const int64_t acc_idx = head * nq_total + q_pos; + + // Load scores + mask into sm[], pad with -inf. + for (int c = tid; c < chunk_pad; c += bdx) { + if (c < chunk_len) { + float val = S[s_base + c]; + if (mask != nullptr) { + val += __half2float(mask[q_pos * mask_stride + kv_start + c]); + } + sm[c] = val; + } else { + sm[c] = -INFINITY; + } + } + __syncthreads(); + + // Binary tree max reduction. + for (int stride = chunk_pad >> 1; stride >= 1; stride >>= 1) { + for (int c = tid; c < stride; c += bdx) { + sm[c] = fmaxf(sm[c], sm[c + stride]); + } + __syncthreads(); + } + const float m_chunk = sm[0]; + __syncthreads(); + const bool chunk_empty = (m_chunk == -INFINITY); + + // Compute exp(score - m_chunk), pad with 0. + for (int c = tid; c < chunk_pad; c += bdx) { + if (c < chunk_len) { + float val = S[s_base + c]; + if (mask != nullptr) { + val += __half2float(mask[q_pos * mask_stride + kv_start + c]); + } + sm[c] = __expf(val - m_chunk); + } else { + sm[c] = 0.0f; + } + } + __syncthreads(); + + // Binary tree sum reduction. + for (int stride = chunk_pad >> 1; stride >= 1; stride >>= 1) { + for (int c = tid; c < stride; c += bdx) { + sm[c] += sm[c + stride]; + } + __syncthreads(); + } + const float l_chunk = sm[0]; + __syncthreads(); + + // Update m_acc, l_acc and broadcast alpha/beta via sm[chunk_pad..+1]. + if (tid == 0) { + if (!chunk_empty) { + const float m_old = m_acc[acc_idx]; + const float m_new = fmaxf(m_old, m_chunk); + const float alpha = (m_old > -INFINITY) ? __expf(m_old - m_new) : 0.0f; + const float beta = __expf(m_chunk - m_new); + sm[chunk_pad] = alpha; + sm[chunk_pad + 1] = beta; + l_acc[acc_idx] = alpha * l_acc[acc_idx] + beta * l_chunk; + m_acc[acc_idx] = m_new; + } else { + sm[chunk_pad] = 1.0f; + sm[chunk_pad + 1] = 0.0f; + } + } + __syncthreads(); + + const float alpha = sm[chunk_pad]; + const float beta = sm[chunk_pad + 1]; + + // Rescale O_acc. + for (int64_t d = tid; d < D; d += bdx) { + O_acc[acc_idx * D + d] *= alpha; + } + __syncthreads(); + + // Write P = beta * exp(S - m_chunk) back to S; zero padding. + for (int c = tid; c < chunk_len; c += bdx) { + if (chunk_empty) { + S[s_base + c] = 0.0f; + } else { + float val = S[s_base + c]; + if (mask != nullptr) { + val += __half2float(mask[q_pos * mask_stride + kv_start + c]); + } + S[s_base + c] = beta * __expf(val - m_chunk); + } + } + for (int c = tid + chunk_len; c < s_stride; c += bdx) { + S[s_base + c] = 0.0f; + } +} + +// Kernel 3: Finalize. Reads head-major O_acc, writes token-major dst / l. +static __global__ void k_chunked_attn_finalize( + const float * __restrict__ O_acc, // [nh_q, nq, D] head-major + const float * __restrict__ l_acc, // [nh_q * nq] + float * __restrict__ dst, // ggml layout [D, nh_q, nq] token-major + const int64_t nq, + const int64_t nh_q, + const int64_t D) { + const int64_t hq_idx = (int64_t)blockIdx.x; + const int64_t head = hq_idx / nq; + const int64_t q_pos = hq_idx % nq; + const int64_t d = (int64_t)blockIdx.y * blockDim.x + threadIdx.x; + if (d >= D) return; + + const float l = fmaxf(l_acc[hq_idx], 1e-30f); + dst[q_pos * nh_q * D + head * D + d] = O_acc[hq_idx * D + d] / l; +} + +// ─── Strided chunk dequant kernels ───────────────────────────────────────── +// +// Extract a [chunk_len, nh_kv] window from a ggml KV tensor (layout +// [D, kv_len, nh_kv] with byte strides nb1, nb2) into a contiguous +// [nh_kv, chunk_len, D] FP32 destination. Each thread block handles one +// (kv_local, head) cell; threadIdx.x cooperates across the head_dim. +// +// Grid: (chunk_len, nh_kv, 1). blockDim.x = up to D (capped to 256). +// +// This mirrors turboquant's k_tbq*_dequant_f32 pattern but for the native +// ggml block types (Q4_0, Q8_0) and plain half/bfloat16 layouts — nothing +// compressed-domain, no SRHT. + +static __global__ void k_chunked_dequant_f16_f32( + const char * __restrict__ src_base, + float * __restrict__ dst, + const int64_t D, + const int64_t chunk_len, + const size_t nb1, + const size_t nb2, + const int64_t kv_start) { + const int64_t kv_local = blockIdx.x; + const int64_t h = blockIdx.y; + if (kv_local >= chunk_len) return; + const half * src = (const half *)(src_base + h * nb2 + (kv_start + kv_local) * nb1); + float * out = dst + (h * chunk_len + kv_local) * D; + for (int64_t d = threadIdx.x; d < D; d += blockDim.x) { + out[d] = __half2float(src[d]); + } +} + +static __global__ void k_chunked_dequant_bf16_f32( + const char * __restrict__ src_base, + float * __restrict__ dst, + const int64_t D, + const int64_t chunk_len, + const size_t nb1, + const size_t nb2, + const int64_t kv_start) { + const int64_t kv_local = blockIdx.x; + const int64_t h = blockIdx.y; + if (kv_local >= chunk_len) return; + const nv_bfloat16 * src = (const nv_bfloat16 *)(src_base + h * nb2 + (kv_start + kv_local) * nb1); + float * out = dst + (h * chunk_len + kv_local) * D; + for (int64_t d = threadIdx.x; d < D; d += blockDim.x) { + out[d] = __bfloat162float(src[d]); + } +} + +// Q4_0: one token has D/QK4_0 blocks; each block has a half scale + 16 bytes +// of packed 4-bit quants (32 elements / block). Per-thread we dequant one +// whole block to amortize the scale load. +static __global__ void k_chunked_dequant_q4_0_f32( + const char * __restrict__ src_base, + float * __restrict__ dst, + const int64_t D, + const int64_t chunk_len, + const size_t nb1, + const size_t nb2, + const int64_t kv_start) { + const int64_t kv_local = blockIdx.x; + const int64_t h = blockIdx.y; + if (kv_local >= chunk_len) return; + const char * src_token = src_base + h * nb2 + (kv_start + kv_local) * nb1; + float * out = dst + (h * chunk_len + kv_local) * D; + const int n_blocks = D / QK4_0; + for (int b = threadIdx.x; b < n_blocks; b += blockDim.x) { + const block_q4_0 * blk = (const block_q4_0 *)(src_token + b * sizeof(block_q4_0)); + const float scale = __half2float(blk->d); + float * out_blk = out + b * QK4_0; + #pragma unroll + for (int i = 0; i < QK4_0 / 2; i++) { + const uint8_t packed = blk->qs[i]; + const int lo = (int)(packed & 0x0F) - 8; + const int hi = (int)(packed >> 4) - 8; + out_blk[i] = lo * scale; + out_blk[i + QK4_0 / 2] = hi * scale; + } + } +} + +// Q8_0: one token has D/QK8_0 blocks; each block has a half scale + 32 int8 +// quants (32 elements / block). +static __global__ void k_chunked_dequant_q8_0_f32( + const char * __restrict__ src_base, + float * __restrict__ dst, + const int64_t D, + const int64_t chunk_len, + const size_t nb1, + const size_t nb2, + const int64_t kv_start) { + const int64_t kv_local = blockIdx.x; + const int64_t h = blockIdx.y; + if (kv_local >= chunk_len) return; + const char * src_token = src_base + h * nb2 + (kv_start + kv_local) * nb1; + float * out = dst + (h * chunk_len + kv_local) * D; + const int n_blocks = D / QK8_0; + for (int b = threadIdx.x; b < n_blocks; b += blockDim.x) { + const block_q8_0 * blk = (const block_q8_0 *)(src_token + b * sizeof(block_q8_0)); + const float scale = __half2float(blk->d); + float * out_blk = out + b * QK8_0; + #pragma unroll + for (int i = 0; i < QK8_0; i++) { + out_blk[i] = ((int)blk->qs[i]) * scale; + } + } +} + +static __global__ void k_chunked_dequant_tq3_0_f32( + const char * __restrict__ src_base, + float * __restrict__ dst, + const int64_t D, + const int64_t chunk_len, + const size_t nb1, + const size_t nb2, + const int64_t kv_start) { + const int64_t kv_local = blockIdx.x; + const int64_t h = blockIdx.y; + if (kv_local >= chunk_len) return; + const char * src_token = src_base + h * nb2 + (kv_start + kv_local) * nb1; + float * out = dst + (h * chunk_len + kv_local) * D; + const int n_blocks = D / QK_TQ3_0; + for (int b = threadIdx.x; b < n_blocks; b += blockDim.x) { + const block_tq3_0 * blk = (const block_tq3_0 *)(src_token + b * sizeof(block_tq3_0)); + const float norm = __half2float(blk->norm); + float * out_blk = out + b * QK_TQ3_0; + #pragma unroll + for (int i = 0; i < QK_TQ3_0; i++) { + const uint8_t low2 = (blk->qs[i/4] >> ((i%4)*2)) & 0x3; + const uint8_t hi1 = (blk->signs[i/8] >> (i%8)) & 0x1; + out_blk[i] = d_tq3_centroids[low2 | (hi1 << 2)] * norm; + } + } +} + +// Launch helper: dispatch by ggml type. Returns false if unsupported. +static inline bool chunked_dequant_launch( + ggml_type type, + const char * src_base, + float * dst, + int64_t D, + int64_t chunk_len, + int64_t nh_kv, + size_t nb1, + size_t nb2, + int64_t kv_start, + cudaStream_t stream) { + const dim3 grid((int)chunk_len, (int)nh_kv, 1); + const int threads = (int)std::min(D, (int64_t)256); + switch (type) { + case GGML_TYPE_F16: + k_chunked_dequant_f16_f32<<>>( + src_base, dst, D, chunk_len, nb1, nb2, kv_start); + return true; + case GGML_TYPE_BF16: + k_chunked_dequant_bf16_f32<<>>( + src_base, dst, D, chunk_len, nb1, nb2, kv_start); + return true; + case GGML_TYPE_Q4_0: + k_chunked_dequant_q4_0_f32<<>>( + src_base, dst, D, chunk_len, nb1, nb2, kv_start); + return true; + case GGML_TYPE_Q8_0: + k_chunked_dequant_q8_0_f32<<>>( + src_base, dst, D, chunk_len, nb1, nb2, kv_start); + return true; + case GGML_TYPE_TQ3_0: + k_chunked_dequant_tq3_0_f32<<>>( + src_base, dst, D, chunk_len, nb1, nb2, kv_start); + return true; + default: + return false; + } +} + +// Entry point (implemented in fattn-chunked.cu). Enabled via +// BEST_FATTN_KERNEL_CHUNKED from the main dispatcher. +void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb52389464..b4bc5d0aa72f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -3,6 +3,7 @@ #include "common.cuh" #include "convert.cuh" #include "vecdotq.cuh" +#include "tq3-quant.cuh" #include @@ -577,6 +578,37 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict } } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tq3_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_tq3_0 * K_tq3 = (const block_tq3_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QK_TQ3_0; + const int j = k_KQ % QK_TQ3_0; + const float K_norm = __half2float(K_tq3[ib].norm); + const uint8_t low2 = (K_tq3[ib].qs[j/4] >> ((j%4)*2)) & 0x3; + const uint8_t hi1 = (K_tq3[ib].signs[j/8] >> (j%8)) & 0x1; + const float Kv = d_tq3_centroids[low2 | (hi1 << 2)] * K_norm; + + const int qi = k_KQ / 4; + const int shift = (k_KQ % 4) * 8; + const int8_t Qv = (int8_t)((Q_q8[qi] >> shift) & 0xFF); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ / QK8_1]; + + sum += Kv * Qv * Q_ds.x; + } + + return sum; +} + template constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { if constexpr (type_K == GGML_TYPE_F16) { @@ -591,6 +623,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_TQ3_0) { + return vec_dot_fattn_vec_KQ_tq3_0; } else if constexpr (type_K == GGML_TYPE_BF16) { return vec_dot_fattn_vec_KQ_bf16; } else { @@ -599,6 +633,43 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { } } +template +static __device__ __forceinline__ void dequantize_V_tq3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_tq3_0 * x = (const block_tq3_0 *) vx; + + const int64_t ib = i0 / QK_TQ3_0; + const int j = i0 % QK_TQ3_0; + const float norm = __half2float(x[ib].norm); + + static_assert(ne == 2 || ne == 4, "bad ne"); + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same::value) { + const half2 h_norm = __half2half2(__float2half(norm)); + +#pragma unroll + for (int l = 0; l < ne; l++) { + const int jl = j + l; + const uint8_t low2 = (x[ib].qs[jl/4] >> ((jl%4)*2)) & 0x3; + const uint8_t hi1 = (x[ib].signs[jl/8] >> (jl%8)) & 0x1; + const float cv = d_tq3_centroids[low2 | (hi1 << 2)]; + ((half *) dst)[l] = __float2half(cv) * __float2half(norm); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same::value) { +#pragma unroll + for (int l = 0; l < ne; l++) { + const int jl = j + l; + const uint8_t low2 = (x[ib].qs[jl/4] >> ((jl%4)*2)) & 0x3; + const uint8_t hi1 = (x[ib].signs[jl/8] >> (jl%8)) & 0x1; + ((float *) dst)[l] = d_tq3_centroids[low2 | (hi1 << 2)] * norm; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + template constexpr __device__ dequantize_V_t get_dequantize_V() { if constexpr (type_V == GGML_TYPE_F16) { @@ -613,6 +684,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_TQ3_0) { + return dequantize_V_tq3_0; } else if constexpr (type_V == GGML_TYPE_BF16) { return dequantize_V_bf16; } else { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a57612..aa6a691c306b 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -566,14 +566,15 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten template void ggml_cuda_flash_attn_ext_vec_case \ (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ - extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_TQ3_0); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -582,6 +583,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_TQ3_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -590,6 +592,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_TQ3_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -598,3 +601,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_TQ3_0) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337e..8663c887cd13 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -4,6 +4,7 @@ #include "fattn-tile.cuh" #include "fattn-vec.cuh" #include "fattn-wmma-f16.cuh" +#include "fattn-chunked.cuh" #include "fattn.cuh" template @@ -285,10 +286,22 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_TQ3_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS @@ -302,6 +315,7 @@ enum best_fattn_kernel { BEST_FATTN_KERNEL_VEC = 100, BEST_FATTN_KERNEL_WMMA_F16 = 300, BEST_FATTN_KERNEL_MMA_F16 = 400, + BEST_FATTN_KERNEL_CHUNKED = 500, // chunked long-context prefill (fattn-chunked.cu) }; static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { @@ -390,6 +404,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_BF16: break; default: @@ -400,9 +415,37 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } + // Chunked long-context prefill. Routes to fattn-chunked.cu which uses + // cuBLAS SGEMM + online softmax with adaptive chunk sizing for O(CHUNK) + // temp memory. Intended for prefill (Q->ne[1] > 1) at contexts where the + // MMA kernel's O(nq_chunk * kv_len * D) memory pressure dominates. + // + // Threshold: DFLASH27B_CHUNKED_THRESHOLD (default 8192 KV tokens). + // Disable entirely: DFLASH27B_CHUNKED_THRESHOLD=0 or negative. + { + static const int64_t chunked_threshold = [] { + const char * e = getenv("DFLASH27B_CHUNKED_THRESHOLD"); + if (e) return (int64_t)atoll(e); + return (int64_t)8192; + }(); + const bool kv_supported = + (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_BF16 || + K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || + K->type == GGML_TYPE_TQ3_0) && + (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_BF16 || + K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || + K->type == GGML_TYPE_TQ3_0); + // TQ3_0 has no MMA kernel support, so force chunked for all prefills. + const bool tq3_prefill = (K->type == GGML_TYPE_TQ3_0 || V->type == GGML_TYPE_TQ3_0); + if ((chunked_threshold > 0 && K->ne[1] > chunked_threshold) || tq3_prefill) { + if (Q->type == GGML_TYPE_F32 && Q->ne[1] > 1 && kv_supported && mask != nullptr) { + return BEST_FATTN_KERNEL_CHUNKED; + } + } + } + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { @@ -522,6 +565,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst case BEST_FATTN_KERNEL_MMA_F16: ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); break; + case BEST_FATTN_KERNEL_CHUNKED: + ggml_cuda_flash_attn_ext_chunked(ctx, dst); + break; } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec73174..ac8c8c205a43 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,30 @@ #include "gated_delta_net.cuh" +#include +#include + +// Tree-mode parent index sentinel: a node whose parent is the pre-block state +// (i.e. a "root" node in the DFS-flattened tree) uses this value in +// parent_ids[]. Any value < 0 triggers a reload from curr_state. +#define GGML_GDN_TREE_ROOT_PARENT (-1) + +// Intermediate-state load/store helpers. Allow the persistent intermediate +// buffer (for dflash27b_ggml tree rollback) to live in fp16 instead of fp32, +// halving its memory footprint and letting us fit larger DDTree budgets in +// the hybrid (GatedDeltaNet) target's state cache. +static __device__ __forceinline__ float load_inter_state(const float * p, int idx) { + return p[idx]; +} +static __device__ __forceinline__ float load_inter_state(const __half * p, int idx) { + return __half2float(p[idx]); +} +static __device__ __forceinline__ void store_inter_state(float * p, int idx, float v) { + p[idx] = v; +} +static __device__ __forceinline__ void store_inter_state(__half * p, int idx, float v) { + p[idx] = __float2half(v); +} -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -9,6 +33,8 @@ gated_delta_net_cuda(const float * q, const float * beta, const float * curr_state, float * dst, + const int * parent_ids, // TREE_MODE only; else ignored + InterT * persist_inter, // optional external buffer for per-token intermediates int64_t H, int64_t n_tokens, int64_t n_seqs, @@ -34,13 +60,31 @@ gated_delta_net_cuda(const float * q, const uint32_t iq3 = fastdiv(sequence, rq3_magic); const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t final_state_elems = S_v * S_v * H * n_seqs; float * attn_data = dst; float * state = dst + attn_score_elems; + // intermediate_states region: one S_v*S_v*H*n_seqs state per token. Written + // inside the token loop below (one state per `t`) to enable spec-decode + // rollback without a replay forward pass. See ggml.c::ggml_gated_delta_net. + // + // dflash27b_ggml: if persist_inter != nullptr, the kernel writes the + // intermediate states DIRECTLY to that external buffer instead of the + // embedded region inside dst. InterT selects the storage precision (float + // or __half). f16 halves the memory footprint — enough to fit larger + // DDtree budgets on the 24 GB 3090. + // When persist_inter is null, InterT MUST be float (the embedded region + // inside dst is f32). + InterT * inter_states = persist_inter + ? persist_inter + : (InterT *)(dst + attn_score_elems + final_state_elems); const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; state += state_offset; curr_state += state_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; + // Per-sequence per-head base for this block's intermediates, token t=0. + // Advance by (H * S_v * S_v) each iteration. + InterT * inter_base = inter_states + (sequence * n_tokens * H + h_idx) * S_v * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); @@ -54,7 +98,51 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // TREE_MODE: pointer base for parent lookups. Each sequence has its own + // parent_ids[n_tokens] slice. At branch points (parent_t != t - 1), we + // reload s_shard from the intermediate-state region instead of continuing + // the recurrence sequentially. Ports sglang's + // fused_sigmoid_gating_recurrent.py HAS_EAGLE_TREE_CUSTOM_ATTN_MASK logic + // to CUDA. + const int * parent_ids_seq = nullptr; + if constexpr (TREE_MODE) { + parent_ids_seq = parent_ids + sequence * n_tokens; + } + for (int t = 0; t < n_tokens; t++) { + // Tree branch-point reload: if this token's parent in the DFS-flattened + // tree isn't the previous token in processing order, pull its state + // back from the intermediate-state region. Same-thread read-after-write + // on global memory — no __syncthreads() needed because each lane writes + // and reads its own (col, row) slots. + if constexpr (TREE_MODE) { + if (t > 0) { + const int parent_t = parent_ids_seq[t]; + if (parent_t == GGML_GDN_TREE_ROOT_PARENT) { + // Root-level sibling: reset to the pre-block state. +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i]; + } + } else if (parent_t != t - 1) { + // Branch: this token's parent is somewhere earlier in the + // DFS traversal. Pull that state from the intermediate + // region. inter_states base is per-sequence, per-head; + // parent_t picks the slot, col/i picks the element. The + // load helper converts from InterT (f32 or f16) → float. + const InterT * parent_base = inter_states + + ((sequence * n_tokens + parent_t) * H + h_idx) * S_v * S_v; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = load_inter_state(parent_base, col * S_v + i); + } + } + // parent_t == t - 1: sequential, keep s_shard in registers. + } + } + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; @@ -134,6 +222,16 @@ gated_delta_net_cuda(const float * q, } } + // Write the intermediate state for token t (same transposed layout as the + // final-state write below). Used by dflash27b_ggml spec-decode rollback. + // store_inter_state converts float → InterT (f32 passthrough or __float2half). +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + store_inter_state(inter_base, col * S_v + i, s_shard[r]); + } + inter_base += S_v * S_v * H; + attn_data += S_v * H; } @@ -145,11 +243,13 @@ gated_delta_net_cuda(const float * q, } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, float * dst_d, + const int * parent_ids_d, + InterT * persist_inter_d, int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, int64_t sq1, int64_t sq2, int64_t sq3, int64_t sv1, int64_t sv2, int64_t sv3, @@ -169,27 +269,27 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + gated_delta_net_cuda<16, KDA, TREE_MODE, InterT><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_inter_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + gated_delta_net_cuda<32, KDA, TREE_MODE, InterT><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_inter_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + gated_delta_net_cuda<64, KDA, TREE_MODE, InterT><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_inter_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + gated_delta_net_cuda<128, KDA, TREE_MODE, InterT><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_inter_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; @@ -207,6 +307,14 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor * src_g = dst->src[3]; ggml_tensor * src_beta = dst->src[4]; ggml_tensor * src_state = dst->src[5]; + // Optional 7th source = parent_ids[n_seqs, n_tokens] int32, enabling + // tree-mode recurrence (dflash27b_ggml extension). nullptr means chain mode. + ggml_tensor * src_parent = dst->src[6]; + // Optional 8th source = persistent external intermediate-state buffer + // (dflash27b_ggml extension). When non-null, the kernel writes per-token + // intermediate states directly to persist_inter->data instead of the + // embedded region inside dst, saving a downstream ggml_cpy. + ggml_tensor * src_persist_inter = dst->src[7]; GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); @@ -236,6 +344,19 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const float * s_d = (const float *) src_state->data; float * dst_d = (float *) dst->data; + const int * parent_ids_d = src_parent + ? (const int *) src_parent->data + : nullptr; + void * persist_inter_d = src_persist_inter + ? src_persist_inter->data + : nullptr; + const bool persist_is_f16 = + src_persist_inter && src_persist_inter->type == GGML_TYPE_F16; + if (src_persist_inter) { + GGML_ASSERT(src_persist_inter->type == GGML_TYPE_F32 || + src_persist_inter->type == GGML_TYPE_F16); + GGML_ASSERT(ggml_is_contiguous(src_persist_inter)); + } GGML_ASSERT(ggml_is_contiguous_rows(src_q)); GGML_ASSERT(ggml_is_contiguous_rows(src_k)); @@ -245,6 +366,11 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_is_contiguous(src_g)); GGML_ASSERT(ggml_is_contiguous(src_beta)); GGML_ASSERT(ggml_is_contiguous(src_state)); + if (src_parent) { + GGML_ASSERT(src_parent->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src_parent)); + GGML_ASSERT(ggml_nelements(src_parent) == n_tokens * n_seqs); + } // strides in floats (beta strides used for both g and beta offset computation) const int64_t sq1 = nbq1 / sizeof(float); @@ -261,13 +387,44 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); - if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + const bool tree_mode = (parent_ids_d != nullptr); + + // Macro to expand the 4 (KDA × TREE_MODE) cases for a given InterT. + // The persist_is_f16 branch picks between __half and float instantiations. + #define GDN_LAUNCH(INTER_T) \ + do { \ + INTER_T * persist_typed = (INTER_T *)persist_inter_d; \ + if (kda) { \ + if (tree_mode) { \ + launch_gated_delta_net( \ + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_typed, \ + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ + sb1, sb2, sb3, neqk1, rq3, scale, stream); \ + } else { \ + launch_gated_delta_net( \ + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, nullptr, persist_typed, \ + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ + sb1, sb2, sb3, neqk1, rq3, scale, stream); \ + } \ + } else { \ + if (tree_mode) { \ + launch_gated_delta_net( \ + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, parent_ids_d, persist_typed, \ + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ + sb1, sb2, sb3, neqk1, rq3, scale, stream); \ + } else { \ + launch_gated_delta_net( \ + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, nullptr, persist_typed, \ + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ + sb1, sb2, sb3, neqk1, rq3, scale, stream); \ + } \ + } \ + } while (0) + + if (persist_is_f16) { + GDN_LAUNCH(__half); } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + GDN_LAUNCH(float); } + #undef GDN_LAUNCH } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dda..1d7c6d17a525 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -199,6 +199,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_q(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_TQ3_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; default: // TODO: k-quants GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3113de017f09..93bf76c51bcb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -24,6 +24,7 @@ #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" +#include "ggml-cuda/turbo-wht.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmf.cuh" #include "ggml-cuda/mmq.cuh" @@ -2700,6 +2701,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_NORM: ggml_cuda_op_norm(ctx, dst); break; + case GGML_OP_TURBO_WHT: + ggml_cuda_op_turbo_wht(ctx, dst); + break; case GGML_OP_GROUP_NORM: ggml_cuda_op_group_norm(ctx, dst); break; @@ -4745,6 +4749,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; + case GGML_OP_TURBO_WHT: + return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -4825,6 +4831,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ3_0: return true; default: return false; @@ -4838,7 +4845,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || - op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL || + op->type == GGML_TYPE_TQ3_0) && op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; @@ -4891,6 +4899,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_TQ3_0) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { return true; } diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa51..db61b3995a8b 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -1,5 +1,65 @@ #include "set-rows.cuh" #include "cpy-utils.cuh" +#include "tq3-quant.cuh" + +template +static __global__ void k_set_rows_tq3_0( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_tq3_0 * __restrict__ dst, + const int64_t ne_total, + const int64_t num_groups_per_row, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const uint3 ne00g_fd, + const uint3 ne01_fd, + const uint3 ne02_fd, + const uint3 ne11_fd, + const uint3 ne12_fd) { + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + if (i >= ne_total) return; + + uint32_t tmp = (uint32_t) i; + uint2 div_mod; + + div_mod = fast_div_modulo(tmp, ne00g_fd); + const int64_t ig = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne01_fd); + const int64_t i01 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne02_fd); + const int64_t i02 = div_mod.y; + const int64_t i03 = div_mod.x; + + const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd); + const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_tq3_0 * dst_row_ptr = (block_tq3_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + + const float * src_grp = src0_row + ig * QK_TQ3_0_GROUP; + block_tq3_0 * dst_grp = dst_row_ptr + ig * (QK_TQ3_0_GROUP / QK_TQ3_0); + + quantize_f32_tq3_0_group(src_grp, dst_grp); + + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); +} typedef void (*set_rows_kernel_t)(const char * src, char * dst); @@ -309,6 +369,36 @@ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * s nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_TQ3_0) { + GGML_ASSERT(ne00 % QK_TQ3_0_GROUP == 0); + const int64_t num_groups_per_row = ne00 / QK_TQ3_0_GROUP; + const int64_t ne_total = num_groups_per_row * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + + const int64_t s01 = nb01 / sizeof(float); + const int64_t s02 = nb02 / sizeof(float); + const int64_t s03 = nb03 / sizeof(float); + const int64_t s10 = nb10 / sizeof(idx_t); + const int64_t s11 = nb11 / sizeof(idx_t); + const int64_t s12 = nb12 / sizeof(idx_t); + const int64_t s1 = nb1; + const int64_t s2 = nb2; + const int64_t s3 = nb3; + + if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) { + const uint3 ne00g_fd = init_fastdiv_values((uint32_t) num_groups_per_row); + const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01); + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); + const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); + + k_set_rows_tq3_0<<>>( + src0_d, src1_d, (block_tq3_0*)dst->data, + ne_total, num_groups_per_row, + ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, s1, s2, s3, + ne00g_fd, ne01_fd, ne02_fd, ne11_fd, ne12_fd); + } } else { GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index b77cdc1c1376..e5cefbb04964 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -108,6 +108,144 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, } } +// dflash27b_ggml: tree-mode ssm_conv kernel. For each new-token t, walks up +// the parent chain K-1 times via parent_ids[] to find the (K-1) ancestor slots +// in the conv input, then convolves with the kernel weights. Virtual-slot +// encoding: a non-negative parent index `p` maps to sx slot (K-1 + p). A +// parent index of -1 means "before the block" — i.e., the old conv state. +// Each successive walk beyond -1 decrements by 1, so virtual slot -k maps to +// sx slot (K-1 - k), which indexes into the old state region [0, K-1). This +// matches SGLang's causal_conv1d_triton HAS_EAGLE_TREE_CUSTOM_ATTN_MASK path. +// dflash27b_ggml: tree-mode + per-token persistent conv state. When +// WITH_PERSIST is true, every token writes its (K-1)-element conv "post-state" +// (the last K-1 cols of its parent-chain window) into persist_inter so the +// driver can roll the live conv state back to the accepted DFS node. +template +static __global__ void ssm_conv_tree_f32( + const float * __restrict__ src0, // sx: [K-1+n_t, d_inner, n_s] + const float * __restrict__ src1, // c: [K, d_inner] + const int * __restrict__ parent_ids, // [n_t, n_s] + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, + float * __restrict__ dst, // [d_inner, n_t, n_s] + const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * __restrict__ persist_inter, // [K-1, d_inner, n_t, n_s] when WITH_PERSIST, else nullptr + const int64_t d_inner_total, // full d_inner for persist row stride + const int64_t n_t) { + GGML_UNUSED(src0_nb0); + const int tid = threadIdx.x; + const int bidx = blockIdx.x; // sequence + const int bidy = blockIdx.y; // d_inner / split_d_inner + + const float * x_block = (const float *) ((const char *) src0 + + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); + const float * w_block = (const float *) ((const char *) src1 + + bidy * split_d_inner * src1_nb1); + float * y_block = (float *) ((char *) dst + + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); + + const int stride_x = src0_nb1 / sizeof(float); + const int stride_w = src1_nb1 / sizeof(float); + const int stride_y = dst_nb1 / sizeof(float); + + // Load kernel weights into registers. + float w[d_conv] = { 0.0f }; +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + + const int * parent_ids_seq = parent_ids + bidx * n_t; + + // Channel index this thread owns within the full d_inner dimension. + // Used both for indexing persist_inter (when enabled) and as bookkeeping. + const int channel = (int)(bidy * split_d_inner) + tid; + + for (int64_t i = 0; i < n_t; i++) { + // Walk the parent chain K-1 times to fill the conv window. + // ancestor_virtual[k] gives the "virtual slot" for kernel position k, + // where the most recent slot is at k=K-1 (= token i itself) and older + // slots are at k=K-2, K-3, ..., 0. + // + // ancestor_virtual[K-1] = i + // ancestor_virtual[K-2] = parent_of(i) (or i-1 decay for negative) + // ancestor_virtual[k ] = parent_of(ancestor_virtual[k+1]) + int ancestors[d_conv]; + ancestors[d_conv - 1] = (int)i; +#pragma unroll + for (int k = (int)d_conv - 2; k >= 0; k--) { + int prev = ancestors[k + 1]; + int next; + if (prev >= 0) { + next = parent_ids_seq[prev]; // -1 if parent is before block + } else { + next = prev - 1; // keep decaying through old state slots + } + ancestors[k] = next; + } + + float sumf = 0.0f; + // Cache window values so we can both convolve and (optionally) persist + // them without re-reading from global memory. + float window[d_conv]; +#pragma unroll + for (size_t k = 0; k < d_conv; k++) { + // Map virtual slot → sx slot: sx_slot = (K-1) + ancestors[k]. + const int sx_slot = (int)(d_conv - 1) + ancestors[k]; + window[k] = x_block[tid * stride_x + sx_slot]; + sumf += window[k] * w[k]; + } + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; + + if constexpr (WITH_PERSIST) { + // Per-token "post-state": the (K-1) most recent cols of this token's + // window — i.e. ancestors[1..K-1]. Layout matches the live conv state + // tensor (r_l): [K-1, d_inner, ...] with K-1 fastest. Persist memory + // layout: persist_inter[s][t][channel][k] flat = ((s*n_t + t)*d_inner + channel) * (K-1) + k. + float * persist_token = persist_inter + + ((bidx * n_t + i) * d_inner_total + channel) * (int64_t)(d_conv - 1); +#pragma unroll + for (size_t k = 0; k < d_conv - 1; k++) { + // ancestors[1] is the oldest col we keep; ancestors[K-1] = self. + persist_token[k] = window[k + 1]; + } + } + } +} + +template +static void ssm_conv_tree_f32_cuda(const float * src0, const float * src1, const int * parent_ids, + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, + const int dst_nb2, float * persist_inter, + const int64_t nc, const int64_t nr, + const int64_t n_t, const int64_t n_s, cudaStream_t stream) { + const int threads = 128; + GGML_ASSERT(nr % threads == 0); + + const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); + auto launch_kernel = [&](auto NC) { + constexpr int kNC = decltype(NC)::value; + if (persist_inter != nullptr) { + ssm_conv_tree_f32<<>>( + src0, src1, parent_ids, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, persist_inter, nr, n_t); + } else { + ssm_conv_tree_f32<<>>( + src0, src1, parent_ids, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, nullptr, nr, n_t); + } + }; + + switch (nc) { + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Tree ssm_conv only supports kernel sizes 3, 4, 5, 9."); + } +} + template static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, @@ -143,6 +281,8 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + // dflash27b_ggml: optional src[2] = parent_ids (i32) enables tree mode + const struct ggml_tensor * parent_ids = dst->src[2]; const bool fuse_silu = silu_dst != nullptr; // When fusing, write to silu_dst (the node downstream references). @@ -165,6 +305,34 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(out->type == GGML_TYPE_F32); + + if (parent_ids != nullptr) { + GGML_ASSERT(parent_ids->type == GGML_TYPE_I32); + const int * parent_ids_d = (const int *) parent_ids->data; + // dflash27b_ggml: optional src[3] = persist_inter (F32) buffer where + // each token's [K-1, d_inner] post-state is written for SSM rollback. + const struct ggml_tensor * persist_inter = dst->src[3]; + float * persist_d = nullptr; + if (persist_inter != nullptr) { + GGML_ASSERT(persist_inter->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(persist_inter)); + GGML_ASSERT(ggml_nelements(persist_inter) >= (int64_t)(nc - 1) * nr * n_t * n_s); + persist_d = (float *) persist_inter->data; + } + if (fuse_silu) { + ssm_conv_tree_f32_cuda(src0_d, src1_d, parent_ids_d, + src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], + dst_d, out->nb[0], out->nb[1], out->nb[2], + persist_d, nc, nr, n_t, n_s, stream); + } else { + ssm_conv_tree_f32_cuda(src0_d, src1_d, parent_ids_d, + src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], + dst_d, out->nb[0], out->nb[1], out->nb[2], + persist_d, nc, nr, n_t, n_s, stream); + } + return; + } + if (fuse_silu) { ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-tq3_0.cu new file mode 100644 index 000000000000..1c0be87549d1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-tq3_0.cu new file mode 100644 index 000000000000..751cd7f027d3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-tq3_0.cu new file mode 100644 index 000000000000..c22fd642e31a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-tq3_0.cu new file mode 100644 index 000000000000..0b82828c90b3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-tq3_0.cu new file mode 100644 index 000000000000..b07cdbf31ee7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-tq3_0.cu new file mode 100644 index 000000000000..bb9be30fbc1f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-tq3_0.cu new file mode 100644 index 000000000000..3f45d8a5bbc6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-bf16.cu new file mode 100644 index 000000000000..86dbf4d37b47 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-f16.cu new file mode 100644 index 000000000000..b6408fed2f54 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_0.cu new file mode 100644 index 000000000000..fa4eb3af2b5c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_1.cu new file mode 100644 index 000000000000..5aa214740a16 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_0.cu new file mode 100644 index 000000000000..0186d4863e18 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_1.cu new file mode 100644 index 000000000000..e83b35656846 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q8_0.cu new file mode 100644 index 000000000000..a29f9f0e1863 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-tq3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-tq3_0.cu new file mode 100644 index 000000000000..8e1353933f53 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tq3_0-tq3_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 40d51f93fa4d..17cee853e4f3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16", "GGML_TYPE_TQ3_0"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/tq3-quant.cuh b/ggml/src/ggml-cuda/tq3-quant.cuh new file mode 100644 index 000000000000..11a00aba3718 --- /dev/null +++ b/ggml/src/ggml-cuda/tq3-quant.cuh @@ -0,0 +1,80 @@ +#pragma once + +#include "ggml-common.h" + +// TQ3_0: 3-bit Lloyd-Max codebook with FWHT rotation. +// Block size = 32 elements (14 bytes). Group size = 128 elements (4 blocks share norm + rotation). + +static __constant__ float d_tq3_centroids[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f +}; + +static __constant__ float d_tq3_mids[7] = { + -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f +}; + +static __constant__ float d_tq3_wht_signs1[128] = { + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f}; + +static __constant__ float d_tq3_wht_signs2[128] = { + 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f}; + +static __device__ __forceinline__ +void tq3_fwht_128(float * x) { + for (int h = 1; h < 128; h *= 2) { + for (int i = 0; i < 128; i += h * 2) { + for (int j = i; j < i + h; j++) { + float a = x[j], b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + } + } + const float inv_sqrt_128 = 0.08838834764831845f; + for (int i = 0; i < 128; i++) x[i] *= inv_sqrt_128; +} + +static __device__ __forceinline__ +void tq3_rotate_forward(float * x) { + for (int i = 0; i < 128; i++) x[i] *= d_tq3_wht_signs1[i]; + tq3_fwht_128(x); + for (int i = 0; i < 128; i++) x[i] *= d_tq3_wht_signs2[i]; +} + +static __device__ __forceinline__ +void tq3_rotate_inverse(float * x) { + for (int i = 0; i < 128; i++) x[i] *= d_tq3_wht_signs2[i]; + tq3_fwht_128(x); + for (int i = 0; i < 128; i++) x[i] *= d_tq3_wht_signs1[i]; +} + +static __device__ __forceinline__ +uint8_t tq3_find_nearest(float val) { + if (val < d_tq3_mids[0]) return 0; + else if (val < d_tq3_mids[1]) return 1; + else if (val < d_tq3_mids[2]) return 2; + else if (val < d_tq3_mids[3]) return 3; + else if (val < d_tq3_mids[4]) return 4; + else if (val < d_tq3_mids[5]) return 5; + else if (val < d_tq3_mids[6]) return 6; + else return 7; +} + +static __device__ __forceinline__ +void dequantize_tq3_0(const void * vx, const int64_t ib, const int iqs, float2 & v) { + const block_tq3_0 * x = (const block_tq3_0 *)vx; + const float norm = __half2float(x[ib].norm); + { + const int j = iqs; + const uint8_t low2 = (x[ib].qs[j/4] >> ((j%4)*2)) & 0x3; + const uint8_t hi1 = (x[ib].signs[j/8] >> (j%8)) & 0x1; + v.x = d_tq3_centroids[low2 | (hi1 << 2)] * norm; + } + { + const int j = iqs + QK_TQ3_0/2; + const uint8_t low2 = (x[ib].qs[j/4] >> ((j%4)*2)) & 0x3; + const uint8_t hi1 = (x[ib].signs[j/8] >> (j%8)) & 0x1; + v.y = d_tq3_centroids[low2 | (hi1 << 2)] * norm; + } +} diff --git a/ggml/src/ggml-cuda/turbo-wht.cu b/ggml/src/ggml-cuda/turbo-wht.cu new file mode 100644 index 000000000000..611d23a164dc --- /dev/null +++ b/ggml/src/ggml-cuda/turbo-wht.cu @@ -0,0 +1,58 @@ +#include "turbo-wht.cuh" +#include "tq3-quant.cuh" + +static __global__ void k_turbo_wht( + const char * __restrict__ src_base, + char * __restrict__ dst_base, + const int64_t ne00, + const int64_t ne01, + const int64_t ne02, + const int64_t nb00, + const int64_t nb01, + const int64_t nb02, + const int64_t nb03, + int direction) { + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z; + const int64_t g = blockIdx.x; + if (i01 >= ne01 || i02 >= ne02 || g * QK_TQ3_0_GROUP >= ne00) return; + + const float * row = (const float *)(src_base + i01 * nb01 + i02 * nb02) + g * QK_TQ3_0_GROUP; + float * out_row = (float *)(dst_base + i01 * nb01 + i02 * nb02) + g * QK_TQ3_0_GROUP; + + float x[128]; + for (int i = 0; i < 128; i++) x[i] = row[i]; + + if (direction == 0) { + tq3_rotate_forward(x); + } else { + tq3_rotate_inverse(x); + } + + for (int i = 0; i < 128; i++) out_row[i] = x[i]; +} + +void ggml_cuda_op_turbo_wht(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + int direction; + memcpy(&direction, dst->op_params, sizeof(int)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + GGML_ASSERT(ne00 % QK_TQ3_0_GROUP == 0); + + const int64_t n_groups = ne00 / QK_TQ3_0_GROUP; + + dim3 grid(n_groups, ne01, ne02); + dim3 threads(1, 1, 1); + + k_turbo_wht<<>>( + (const char *)src0->data, (char *)dst->data, + ne00, ne01, ne02, + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + direction); +} diff --git a/ggml/src/ggml-cuda/turbo-wht.cuh b/ggml/src/ggml-cuda/turbo-wht.cuh new file mode 100644 index 000000000000..03389af96ef6 --- /dev/null +++ b/ggml/src/ggml-cuda/turbo-wht.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_op_turbo_wht(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a4..0f1a679f747d 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -394,6 +394,50 @@ void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRI } } +// TQ3_0 CPU stubs — dflash uses CUDA path exclusively, these exist for +// type_traits registration only (ggml_type_traits.to_float / .from_float_ref). +static const float TQ3_CENTROIDS[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f +}; + +void dequantize_row_tq3_0(const block_tq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_TQ3_0; + assert(k % qk == 0); + const int nb = k / qk; + for (int i = 0; i < nb; i++) { + const float norm = GGML_FP16_TO_FP32(x[i].norm); + for (int j = 0; j < qk; j++) { + const uint8_t qs_byte = x[i].qs[j / 4]; + const int low2 = (qs_byte >> ((j % 4) * 2)) & 0x3; + const uint8_t sign_byte = x[i].signs[j / 8]; + const int hi1 = (sign_byte >> (j % 8)) & 0x1; + const int idx = low2 | (hi1 << 2); + y[i * qk + j] = TQ3_CENTROIDS[idx] * norm; + } + } +} + +void quantize_row_tq3_0_ref(const float * GGML_RESTRICT x, block_tq3_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TQ3_0 == 0); + const int nb = k / QK_TQ3_0; + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < QK_TQ3_0; j++) amax = MAX(amax, fabsf(x[i * QK_TQ3_0 + j])); + const float norm = amax / 0.190685f; + y[i].norm = GGML_FP32_TO_FP16(norm); + for (int j = 0; j < QK_TQ3_0 / 4; j++) y[i].qs[j] = 0; + for (int j = 0; j < QK_TQ3_0 / 8; j++) y[i].signs[j] = 0; + for (int j = 0; j < QK_TQ3_0; j++) { + const float v = x[i * QK_TQ3_0 + j] / (norm > 1e-10f ? norm : 1.0f); + int best = 0; float bd = fabsf(v - TQ3_CENTROIDS[0]); + for (int c = 1; c < 8; c++) { float d = fabsf(v - TQ3_CENTROIDS[c]); if (d < bd) { bd = d; best = c; } } + y[i].qs[j / 4] |= (best & 0x3) << ((j % 4) * 2); + y[i].signs[j / 8] |= ((best >> 2) & 0x1) << (j % 8); + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index d56c86da8909..3e57d2fd8b34 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -35,6 +35,8 @@ GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tq3_0_ref(const float * GGML_RESTRICT x, block_tq3_0 * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); @@ -63,6 +65,8 @@ GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tq3_0(const block_tq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0142498d967a..76aa235c1119 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -659,6 +659,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_q1_0, .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, }, + [GGML_TYPE_TQ3_0] = { + .type_name = "tq3_0", + .blck_size = QK_TQ3_0, + .type_size = sizeof(block_tq3_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq3_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq3_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1063,9 +1071,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_SGD", "GLU", + + "TURBO_WHT", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1173,9 +1183,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "sgd(x)", "glu(x)", + + "turbo_wht(a)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5479,6 +5491,57 @@ struct ggml_tensor * ggml_ssm_conv( return result; } +// dflash: tree-mode variant. Same op, with parent_ids plumbed into +// src[2] so the CUDA kernel gathers each token's window along its tree +// parent chain instead of the DFS-neighbour window. +struct ggml_tensor * ggml_ssm_conv_tree( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids) { + struct ggml_tensor * result = ggml_ssm_conv(ctx, sx, c); + + GGML_ASSERT(parent_ids != NULL); + GGML_ASSERT(parent_ids->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(parent_ids)); + + const int64_t n_t = sx->ne[0] - c->ne[0] + 1; + const int64_t n_s = sx->ne[2]; + GGML_ASSERT(ggml_nelements(parent_ids) == n_t * n_s); + + result->src[2] = parent_ids; + + return result; +} + +// dflash: tree-mode + external persistent conv post-state buffer. Same op as +// ggml_ssm_conv_tree but the kernel ALSO writes each token's (K-1)-element +// "post-state" (last K-1 cols of its parent-chain window) into persist_inter, +// matching the [K-1, d_inner, n_tokens, n_seqs] layout used by the live conv +// state in the recurrent memory. +struct ggml_tensor * ggml_ssm_conv_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter) { + struct ggml_tensor * result = ggml_ssm_conv_tree(ctx, sx, c, parent_ids); + + GGML_ASSERT(persist_inter != NULL); + GGML_ASSERT(persist_inter->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(persist_inter)); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; + const int64_t n_s = sx->ne[2]; + GGML_ASSERT(ggml_nelements(persist_inter) >= (d_conv - 1) * d_inner * n_t * n_s); + + result->src[3] = persist_inter; + + return result; +} + // ggml_ssm_scan struct ggml_tensor * ggml_ssm_scan( @@ -6187,9 +6250,16 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // Pack output, final new_state, and per-step intermediate states into one tensor. + // Layout (in units of `S_v * H`-wide rows): + // [ attn_output: n_tokens*n_seqs | final_state: S_v*n_seqs | intermediate_states: S_v*n_tokens*n_seqs ] + // + // The final_state slot is kept for backward compatibility with stock llama.cpp + // callers that read state at offset S_v*H*n_tokens*n_seqs. The intermediate_states + // region is a dflash extension: for each token t in [0, n_tokens), it holds the + // recurrent state after processing token t. Used by the spec decoding loop to + // roll back SSM state to the accepted prefix without a full replay forward pass. + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs + S_v * n_tokens * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; @@ -6203,6 +6273,65 @@ struct ggml_tensor * ggml_gated_delta_net( return result; } +// dflash: tree-mode variant. Same op, with parent_ids plumbed into +// src[6] so the CUDA kernel can branch-reload state at DFS transitions. +struct ggml_tensor * ggml_gated_delta_net_tree( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * parent_ids) { + struct ggml_tensor * result = ggml_gated_delta_net(ctx, q, k, v, g, beta, state); + + GGML_ASSERT(parent_ids != NULL); + GGML_ASSERT(parent_ids->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(parent_ids)); + + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + GGML_ASSERT(ggml_nelements(parent_ids) == n_tokens * n_seqs); + + result->src[6] = parent_ids; + + return result; +} + +// dflash: tree-mode + external persistent intermediate buffer. The +// kernel writes per-token intermediate states DIRECTLY into persist_inter's +// memory, skipping the cost of copying them out of the result tensor's +// internal region after graph_compute. +struct ggml_tensor * ggml_gated_delta_net_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter) { + struct ggml_tensor * result = ggml_gated_delta_net_tree( + ctx, q, k, v, g, beta, state, parent_ids); + + GGML_ASSERT(persist_inter != NULL); + GGML_ASSERT(persist_inter->type == GGML_TYPE_F32 || + persist_inter->type == GGML_TYPE_F16); + GGML_ASSERT(ggml_is_contiguous(persist_inter)); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + GGML_ASSERT(ggml_nelements(persist_inter) >= S_v * S_v * H * n_tokens * n_seqs); + + result->src[7] = persist_inter; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { @@ -7746,3 +7875,15 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } + +struct ggml_tensor * ggml_turbo_wht( + struct ggml_context * ctx, + struct ggml_tensor * a, + int direction) { + GGML_ASSERT(ggml_n_dims(a) >= 2); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne); + result->op = GGML_OP_TURBO_WHT; + result->src[0] = a; + ggml_set_op_params_i32(result, 0, direction); + return result; +} diff --git a/include/llama.h b/include/llama.h index ac267b5089ab..3b6ef348e187 100644 --- a/include/llama.h +++ b/include/llama.h @@ -48,6 +48,8 @@ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_MEM_SNAPSHOT_INVALID -1 + #ifdef __cplusplus extern "C" { #endif @@ -69,6 +71,9 @@ extern "C" { typedef int32_t llama_token; typedef int32_t llama_seq_id; + // opaque handle returned by llama_seq_snapshot / used by llama_seq_restore and llama_seq_release + typedef int32_t llama_mem_snapshot_id; + enum llama_vocab_type { LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback @@ -241,6 +246,10 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + + // tree-mode parent indices: parent_id[i] is the index of token i's parent in the batch + // -1 means root (no parent). NULL means chain mode (default behavior unchanged). + int32_t * parent_id; } llama_batch; enum llama_model_kv_override_type { @@ -290,6 +299,9 @@ extern "C" { // NULL-terminated list of buffer types to use for tensors that match a pattern const struct llama_model_tensor_buft_override * tensor_buft_overrides; + // optional target model for auxiliary models that share target tensors + const struct llama_model * target_model; + int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -642,6 +654,16 @@ extern "C" { // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); + // Copy one token-embedding row from model->tok_embd into caller-supplied buffer. + // out_n must be >= model->hparams.n_embd. The embedding is returned as F32 + // regardless of the on-disk storage type (conversion happens on the backend). + // Returns 0 on success, -1 if token is out of range or tok_embd is unavailable. + LLAMA_API int llama_model_token_embd_lookup( + const struct llama_model * model, + llama_token token, + float * out, + int64_t out_n); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -781,6 +803,28 @@ extern "C" { // Check if the memory supports shifting LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); + // Snapshot/restore the recurrent state (SSM + conv) for seq_id. + // snapshot() allocates per-layer backup buffers and copies the current state into them. + // restore() copies the backed-up state back; release() frees the backup buffers. + // These are no-ops on non-recurrent memory types (returns LLAMA_MEM_SNAPSHOT_INVALID). + // The caller is responsible for calling release() after each snapshot. + LLAMA_API llama_mem_snapshot_id llama_seq_snapshot(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API bool llama_seq_restore (struct llama_context * ctx, llama_mem_snapshot_id snap_id); + LLAMA_API void llama_seq_release (struct llama_context * ctx, llama_mem_snapshot_id snap_id); + + // Compact the KV cache after a tree-verify forward pass. + // The tree was placed at slots [spine_start, spine_start+N); after this call + // the accepted spine occupies slots [spine_start, spine_start+commit_n) in + // DFS order, and prompt cells (slots < spine_start) are untouched. + // No-op on non-KV memory types (e.g. pure SSM models). + LLAMA_API void llama_kv_cache_seq_compact_tree( + struct llama_context * ctx, + llama_seq_id seq_id, + const int32_t * accepted_dfs, + int32_t n_accepted, + int32_t commit_n, + int32_t spine_start); + // // State / sessions // @@ -932,7 +976,14 @@ extern "C" { int32_t embd, int32_t n_seq_max); - // Frees a batch of tokens allocated with llama_batch_init() + // Like llama_batch_init but also allocates parent_id[n_tokens], filled with -1 (tree roots). + // Callers must free with llama_batch_free(). + LLAMA_API struct llama_batch llama_batch_init_tree( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max); + + // Frees a batch of tokens allocated with llama_batch_init() or llama_batch_init_tree() LLAMA_API void llama_batch_free(struct llama_batch batch); // Process a batch of tokens. @@ -984,6 +1035,103 @@ extern "C" { // If true, all model tensors are activated during llama_decode() to load and cache their weights. LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // dflash hidden capture: when enabled, qwen35 forward writes per-layer hidden states + // for the 5 dflash target capture layers into an output tensor readable via + // llama_get_hidden_capture() after llama_decode(). Toggling this triggers a graph + // reserve. If disabled (default), behavior is byte-for-byte identical to baseline. + LLAMA_API void llama_set_capture_hidden(struct llama_context * ctx, bool enable); + LLAMA_API struct ggml_tensor * llama_get_hidden_capture(struct llama_context * ctx); + + // Host-side accessor: returns a pointer into a context-owned CPU buffer. + // The device-side capture tensor is synchronized lazily only when this is called. + // Returns NULL when capture is disabled or no decode has run yet. + // out_ne0 / out_ne1 receive the tensor dimensions. + LLAMA_API const float * llama_get_hidden_capture_data(struct llama_context * ctx, + int64_t * out_ne0, + int64_t * out_ne1); + + // dflash draft target_feat injection (Task 1 Phase 4 gap fix). + LLAMA_API void llama_set_dflash_draft_top_k(struct llama_context * ctx, int32_t k); + + // Must be called on the draft context before llama_decode() when running a + // dflash-draft (LLM_ARCH_DFLASH_DRAFT) model. The driver supplies a packed + // [5*n_embd, ctx_len] F32 host buffer with per-layer hidden captures from the + // target model. committed_pos is the number of tokens already committed in the + // target context; it drives the RoPE position indices for Q and K in the draft. + // The data pointer is non-owning; it must remain valid until llama_decode() returns. + LLAMA_API void llama_set_target_feat_raw(struct llama_context * ctx, + const float * data, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos); + + LLAMA_API int llama_dflash_draft_fuse_target_feat(struct llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused); + + LLAMA_API int llama_dflash_draft_encode_top_k(struct llama_context * ctx, + struct llama_batch batch, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + + LLAMA_API int llama_dflash_draft_encode_top_k_fused(struct llama_context * ctx, + struct llama_batch batch, + const float * target_feat_fused, + int64_t n_embd, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + + LLAMA_API int llama_dflash_draft_update_fused_cache(struct llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap); + + LLAMA_API int llama_dflash_draft_update_fused_cache_from_capture(struct llama_context * draft_ctx, + struct llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap); + + LLAMA_API int llama_dflash_draft_encode_top_k_cached(struct llama_context * ctx, + struct llama_batch batch, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k); + + // dflash Phase 2.4: persist-based SSM rollback after tree verify. + // After llama_kv_cache_seq_compact_tree(), call this to copy the SSM state + // captured at DFS node accepted_dfs_node from the persist buffer back into + // the live recurrent cache for seq_id, replacing the snapshot/restore/replay path. + // Must be called after the tree-mode llama_decode() and before the next decode. + // Returns true on success, false if persist buffers are unavailable. + // KNOWN LIMITATION: conv state is NOT rolled back (see Phase 2.4 Task 4 — option b). + // Conv-state divergence decays within ~K_conv tokens; the chain-vs-spec test may + // diverge by a few tokens at each tree boundary before reconverging. + LLAMA_API bool llama_dflash_rollback_ssm_to_dfs( + struct llama_context * ctx, + llama_seq_id seq_id, + int32_t accepted_dfs_node); + + // After persist-based rollback, adjust the recurrent cache bookkeeping so + // seq_pos_max() reflects the accepted chain position rather than the DFS + // tree position left by the tree-mode forward. + LLAMA_API bool llama_dflash_set_recurrent_tail_pos( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos pos); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @@ -1054,6 +1202,15 @@ extern "C" { LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // DFlash draft graph top-K tensors. Returns false when the last eval was not a + // dflash-draft graph with top-K output. Layout is row-major [n_rows, k]. + LLAMA_API bool llama_get_dflash_draft_top_k( + struct llama_context * ctx, + const float ** logits, + const llama_token ** token_ids, + int32_t * n_rows, + int32_t * k); + // // Vocab // diff --git a/multi_prompt_probe.sh b/multi_prompt_probe.sh new file mode 100755 index 000000000000..087cf6e919fa --- /dev/null +++ b/multi_prompt_probe.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -euo pipefail +cd "$(dirname "$0")" + +prompts="/tmp/real_rendered_prompt.txt /tmp/ddtree_prompt_task0_tok17350.txt /tmp/ddtree_prompt_task129_tok18108.txt /tmp/ddtree_prompt_task297_tok18689.txt /tmp/ddtree_prompt_task528_tok18918.txt /tmp/ddtree_prompt_task629_tok19171.txt /tmp/ddtree_prompt_task761_tok21321.txt /tmp/ddtree_prompt_task995_tok56813.txt" + +for p in $prompts; do + echo "=== PROMPT $p ===" + out=$(AUTORESEARCH_PROMPT="$p" LLAMA_DDTREE_PROPOSAL_TEMP=0.7 ./autoresearch.sh 2>&1 || true) + tps=$(echo "$out" | python3 -c "import sys,re; m=re.search(r'METRIC tps=([0-9.]+)', sys.stdin.read()); print(m.group(1) if m else '0')") + same=$(echo "$out" | python3 -c "import sys,re; m=re.search(r'batched_exact_same=(\d+)', sys.stdin.read()); print(m.group(1) if m else '0')") + diff=$(echo "$out" | python3 -c "import sys,re; m=re.search(r'batched_exact_diff=(\d+)', sys.stdin.read()); print(m.group(1) if m else '0')") + steps=$(echo "$out" | python3 -c "import sys,re; m=re.search(r'steps=(\d+)', sys.stdin.read()); print(m.group(1) if m else '0')") + committed=$(echo "$out" | python3 -c "import sys,re; m=re.search(r'committed=(\d+)', sys.stdin.read()); print(m.group(1) if m else '0')") + echo "result: prompt=$p tps=$tps same=$same diff=$diff steps=$steps committed=$committed" +done +echo "METRIC tps=0" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 121c21fed957..f1be2a7ec06a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -55,6 +55,7 @@ add_library(llama models/command-r.cpp models/dbrx.cpp models/deci.cpp + models/dflash-draft.cpp models/deepseek.cpp models/deepseek2.cpp models/delta-net-base.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6904b9c1a645..49d0f9e7796a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -132,6 +132,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_DFLASH_DRAFT, "dflash-draft" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -545,6 +546,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + // dflash-draft top-level tensors + { LLM_TENSOR_DFLASH_FC, "fc" }, + { LLM_TENSOR_DFLASH_HIDDEN_NORM, "hidden_norm" }, + { LLM_TENSOR_DFLASH_OUT_NORM, "out_norm" }, }; // declare information about the model weight tensors: @@ -765,6 +770,10 @@ static const std::map LLM_TENSOR_INFOS = { // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // dflash-draft + {LLM_TENSOR_DFLASH_FC, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DFLASH_HIDDEN_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DFLASH_OUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index c4aabab7e0cf..15183f18d88b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -136,6 +136,7 @@ enum llm_arch { LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_DFLASH_DRAFT, LLM_ARCH_UNKNOWN, }; @@ -552,6 +553,10 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + // dflash-draft top-level tensors + LLM_TENSOR_DFLASH_FC, // "fc" [5*hidden, hidden] + LLM_TENSOR_DFLASH_HIDDEN_NORM, // "hidden_norm" [hidden] + LLM_TENSOR_DFLASH_OUT_NORM, // "out_norm" [hidden] }; enum llm_tensor_layer { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cddc..8aac5c7cce24 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -224,6 +224,7 @@ bool llama_batch_allocr::init( /*.seq_id_unq =*/ this->seq_id_unq.data(), /*.seq_idx =*/ this->seq_idx.data(), /*.output =*/ batch.logits, + /*.parent_id =*/ batch.parent_id, /*.data =*/ {}, }; @@ -428,6 +429,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t /*.seq_id_unq =*/ udata->seq_id_unq.data(), /*.seq_idx =*/ udata->seq_idx.data(), /*.output =*/ udata->output.data(), + /*.parent_id =*/ nullptr, /*.data =*/ std::move(udata), }; @@ -683,6 +685,11 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u assert(n_tokens%n_seqs == 0); + // tree-mode parent_ids must not be split across ubatches: the ubatch must cover + // every token of the public batch in a single emission. + GGML_ASSERT((batch.parent_id == nullptr || (size_t) n_tokens == (size_t) batch.n_tokens) && + "tree-mode batch with parent_id must fit in a single ubatch"); + auto udata = std::make_shared(); const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0; @@ -722,6 +729,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; + if (batch.parent_id) { + udata->parent_id.push_back(batch.parent_id[idxs[i]]); + } + for (int s = 0; s < udata->n_seq_id[i]; ++s) { const llama_seq_id seq_id = batch.seq_id[idxs[i]][s]; @@ -747,6 +758,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } } + int32_t * parent_id_ptr = udata->parent_id.empty() ? nullptr : udata->parent_id.data(); + llama_ubatch res { /*.b_equal_seqs =*/ equal_seqs, /*.n_tokens =*/ n_tokens, @@ -763,6 +776,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u /*.seq_id_unq =*/ udata->seq_id_unq.data(), /*.seq_idx =*/ udata->seq_idx.data(), /*.output =*/ udata->output.data(), + /*.parent_id =*/ parent_id_ptr, /*.data =*/ std::move(udata), }; @@ -864,25 +878,27 @@ struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens) { return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*parent_id =*/ nullptr, }; } struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*parent_id =*/ nullptr, }; if (embd) { @@ -904,6 +920,17 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ return batch; } +struct llama_batch llama_batch_init_tree(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = llama_batch_init(n_tokens_alloc, embd, n_seq_max); + + batch.parent_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.parent_id[i] = -1; + } + + return batch; +} + void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); @@ -916,4 +943,5 @@ void llama_batch_free(struct llama_batch batch) { free(batch.seq_id); } if (batch.logits) free(batch.logits); + if (batch.parent_id) free(batch.parent_id); } diff --git a/src/llama-batch.h b/src/llama-batch.h index f77520e86c3c..6319170db00d 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -50,6 +50,7 @@ struct llama_ubatch { llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx int8_t * output; // [n_tokens] | i | - + int32_t * parent_id; // [n_tokens] | i | parent index, -1 = root; NULL = chain mode struct data_t { std::vector token; @@ -60,6 +61,7 @@ struct llama_ubatch { std::vector seq_id_unq; std::vector seq_idx; std::vector output; + std::vector parent_id; std::vector seq_id_data; }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ee0c29235cd2..288b62f5a94a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,9 @@ #include "llama-batch.h" #include "llama-io.h" #include "llama-memory.h" +#include "llama-kv-cache.h" +#include "llama-memory-recurrent.h" +#include "llama-memory-hybrid.h" #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" @@ -13,10 +16,16 @@ #include #include +#include #include #include #include +static bool llama_dflash_fast_rollback_enabled() { + const char * e = std::getenv("LLAMA_DDTREE_FAST_ROLLBACK"); + return e != nullptr && e[0] == '1'; +} + // // llama_context // @@ -990,6 +999,27 @@ size_t llama_context::get_sampled_probs_count(int32_t idx) { } } +bool llama_context::get_dflash_draft_top_k( + const float ** top_logits, + const llama_token ** top_token_ids, + int32_t * n_rows, + int32_t * k) { + if (top_logits) { + *top_logits = dflash_draft_top_logits.empty() ? nullptr : dflash_draft_top_logits.data(); + } + if (top_token_ids) { + *top_token_ids = dflash_draft_top_token_ids.empty() ? nullptr : dflash_draft_top_token_ids.data(); + } + if (n_rows) { + *n_rows = dflash_draft_top_rows; + } + if (k) { + *k = dflash_draft_top_k; + } + return dflash_draft_top_rows > 0 && dflash_draft_top_k > 0 && + !dflash_draft_top_logits.empty() && !dflash_draft_top_token_ids.empty(); +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, @@ -1052,6 +1082,213 @@ void llama_context::set_causal_attn(bool value) { sched_need_reserve = true; } +void llama_context::set_capture_hidden(bool enable) { + LLAMA_LOG_DEBUG("%s: enable = %d\n", __func__, enable); + if (capture_hidden == enable) { + return; + } + capture_hidden = enable; + sched_need_reserve = true; // graph topology changes when capture is toggled +} + +ggml_tensor * llama_context::get_hidden_capture() const { + if (gf_res_prev && gf_res_prev->t_hidden_capture) { + return gf_res_prev->t_hidden_capture; + } + return nullptr; +} + +void llama_context::set_dflash_draft_top_k(int32_t k) { + k = std::max(0, k); + if (dflash_draft_top_k_req == k) { + return; + } + dflash_draft_top_k_req = k; + sched_need_reserve = true; +} + +const float * llama_context::get_hidden_capture_data(int64_t * out_ne0, int64_t * out_ne1) const { + ggml_tensor * t_cap = get_hidden_capture(); + if (t_cap == nullptr || t_cap->buffer == nullptr) { + if (out_ne0) *out_ne0 = 0; + if (out_ne1) *out_ne1 = 0; + return nullptr; + } + + const size_t cap_n = ggml_nelements(t_cap); + if (!hidden_capture_host_valid || hidden_capture_host.size() < cap_n || + hidden_capture_ne0 != t_cap->ne[0] || hidden_capture_ne1 != t_cap->ne[1]) { + if (hidden_capture_host.size() < cap_n) { + hidden_capture_host.resize(cap_n); + } + hidden_capture_ne0 = t_cap->ne[0]; + hidden_capture_ne1 = t_cap->ne[1]; + ggml_backend_tensor_get(t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float)); + hidden_capture_host_valid = true; + } + + if (out_ne0) *out_ne0 = hidden_capture_ne0; + if (out_ne1) *out_ne1 = hidden_capture_ne1; + return hidden_capture_host.data(); +} + +ggml_tensor * llama_context::dflash_get_persist_inter(int32_t il) const { + if (il < 0 || il >= (int32_t)dflash_persist_inter_l.size()) { + return nullptr; + } + return dflash_persist_inter_l[il]; +} + +ggml_tensor * llama_context::dflash_get_persist_conv(int32_t il) const { + if (il < 0 || il >= (int32_t)dflash_persist_conv_l.size()) { + return nullptr; + } + return dflash_persist_conv_l[il]; +} + +void llama_context::ensure_dflash_persist_capacity(int64_t n_tokens) { + if (model.arch != LLM_ARCH_QWEN35) { + return; // only Qwen3.5 uses delta-net recurrent layers + } + if (n_tokens <= dflash_persist_max_n_tokens) { + return; // already large enough + } + if (n_tokens <= dflash_persist_failed_n_tokens) { + return; // allocation already failed for this size in this context + } + + // Derive SSM dimensions from hparams (same as build_layer_attn_linear). + const auto & hparams = model.hparams; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t num_v_heads = hparams.ssm_dt_rank; // H_v + const int64_t head_v_dim = d_inner / num_v_heads; // S_v + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t conv_channels = d_inner + 2 * (int64_t)hparams.ssm_n_group * (int64_t)hparams.ssm_d_state; + const int32_t n_layer = (int32_t)hparams.n_layer; + + // Release existing allocation before reallocating. + dflash_persist_inter_l.clear(); + dflash_persist_conv_l.clear(); + dflash_persist_inter_buf.reset(); + dflash_persist_inter_ctx.reset(); + dflash_persist_ctxs_bufs.clear(); + + dflash_persist_inter_l.resize(n_layer, nullptr); + dflash_persist_conv_l.resize(n_layer, nullptr); + + // Persist tensors have to live next to each layer's recurrent state. With + // partial offload, CPU and CUDA recurrent layers coexist; one shared buffer + // would make CUDA layers write persist state into CPU memory or vice versa. + auto * raw_mem = memory.get(); + auto * mem_recr = dynamic_cast(raw_mem); + if (!mem_recr) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem_recr = hyb->get_mem_recr(); + } + } + + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + + // dflash Phase 2.4 fix: allocate each layer's persist tensors in a separate + // context so they get separate backend buffers. This lets small (~26 MiB) + // per-layer allocations fit into fragmented GPU memory where one large + // (~1.7 GiB) contiguous block would fail. + size_t total_bytes = 0; + for (int il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { + continue; // full-attn layer — no persist buffer needed + } + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + if (mem_recr && il < (int)mem_recr->s_l.size() && + mem_recr->s_l[il] != nullptr && mem_recr->s_l[il]->buffer != nullptr) { + buft = ggml_backend_buffer_get_type(mem_recr->s_l[il]->buffer); + } + + struct ggml_init_params init_params = { + /* mem_size = */ ggml_tensor_overhead() * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(init_params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for persist buffers\n", __func__); + dflash_persist_inter_l.clear(); + dflash_persist_conv_l.clear(); + dflash_persist_ctxs_bufs.clear(); + dflash_persist_failed_n_tokens = std::max(dflash_persist_failed_n_tokens, n_tokens); + return; + } + + // SSM persist: [S_v, S_v, H_v, n_tokens]. CPU layers keep F32 for exact + // equivalence. CUDA layers use F16 to keep DDTree rollback feasible on + // 24 GiB cards when the draft model is also resident. + const ggml_type persist_s_type = + (buft == ggml_backend_cpu_buffer_type()) ? GGML_TYPE_F32 : GGML_TYPE_F16; + ggml_tensor * ts = ggml_new_tensor_4d(ctx, + persist_s_type, head_v_dim, head_v_dim, num_v_heads, n_tokens); + ggml_format_name(ts, "dflash_persist_il%d", il); + dflash_persist_inter_l[il] = ts; + + // Conv persist: [K-1, conv_channels, n_tokens] F32 — matches the live + // r_l[il] layout (K-1 fastest, then conv_channels) per token. + ggml_tensor * tc = ggml_new_tensor_3d(ctx, + GGML_TYPE_F32, d_conv - 1, conv_channels, n_tokens); + ggml_format_name(tc, "dflash_persist_conv_il%d", il); + dflash_persist_conv_l[il] = tc; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate persist buffer for layer %d (n_tokens=%lld)\n", + __func__, il, (long long)n_tokens); + dflash_persist_inter_l.clear(); + dflash_persist_conv_l.clear(); + dflash_persist_ctxs_bufs.clear(); + dflash_persist_failed_n_tokens = std::max(dflash_persist_failed_n_tokens, n_tokens); + return; + } + ggml_backend_buffer_clear(buf, 0); + total_bytes += ggml_backend_buffer_get_size(buf); + dflash_persist_ctxs_bufs.emplace_back(ggml_context_ptr(ctx), buf); + } + dflash_persist_max_n_tokens = n_tokens; + dflash_persist_failed_n_tokens = 0; + + LLAMA_LOG_INFO("%s: allocated dflash persist buffers: %d layers, %lld tokens, %.2f MiB across %zu backend buffers\n", + __func__, n_layer, (long long)n_tokens, + (double)total_bytes / (1024.0 * 1024.0), + dflash_persist_ctxs_bufs.size()); +} + +void llama_context::set_target_feat_raw(const float * data, int64_t n_embd_fc, int64_t ctx_len, + int64_t committed_pos) { + // Stash non-owning pointer and dims; read by llm_graph_input_target_feat::set_input(). + pending_target_feat_raw = data; + pending_target_feat_n_embd_fc = n_embd_fc; + pending_target_feat_ctx_len = ctx_len; + pending_draft_committed_pos = committed_pos; + pending_target_feat_fused = false; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; +} + +void llama_context::set_target_feat_fused(const float * data, int64_t n_embd, int64_t ctx_len, + int64_t committed_pos) { + pending_target_feat_raw = data; + pending_target_feat_n_embd_fc = n_embd; + pending_target_feat_ctx_len = ctx_len; + pending_draft_committed_pos = committed_pos; + pending_target_feat_fused = true; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; +} + void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1168,21 +1405,87 @@ bool llama_context::set_adapter_cvec( return res; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset); +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers); +static void copy_tensor_async_ints( + const std::map & tensor_map, + const buffer_view & sampled, + const std::map & seq_to_row, + ggml_backend_sched_t sched); +static void copy_tensor_async_floats( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched); +static void copy_tensor_async_candidates( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched); + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + const bool profile_dflash = + model.arch == LLM_ARCH_DFLASH_DRAFT && + std::getenv("LLAMA_DDTREE_PROFILE") != nullptr; + + const int64_t t_total_start_us = profile_dflash ? ggml_time_us() : 0; + int64_t t_apply_us = 0; + int64_t t_build_alloc_us = 0; + int64_t t_set_inputs_us = 0; + int64_t t_compute_us = 0; + bool reused_graph = false; + + int64_t t0_us = profile_dflash ? ggml_time_us() : 0; if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } + if (profile_dflash) { + t_apply_us = ggml_time_us() - t0_us; + } + + auto ensure_dflash_runtime = [&](llm_graph_result_ptr & res_ptr, ggml_backend_sched_ptr & sched_ptr) { + if (!res_ptr) { + res_ptr.reset(new llm_graph_result(this->graph_max_nodes(std::min(cparams.n_ctx, cparams.n_ubatch)))); + } + if (!sched_ptr) { + sched_ptr.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), + res_ptr->get_max_nodes(), cparams.pipeline_parallel, + cparams.op_offload)); + } + }; - auto * res = gf_res_prev.get(); + llm_graph_result * res = gf_res_prev.get(); + ggml_backend_sched_t sched_use = sched.get(); + if (model.arch == LLM_ARCH_DFLASH_DRAFT) { + if (pending_dflash_kv_update_only) { + ensure_dflash_runtime(dflash_res_kv, dflash_sched_kv); + res = dflash_res_kv.get(); + sched_use = dflash_sched_kv.get(); + } else if (pending_dflash_fuse_only) { + ensure_dflash_runtime(dflash_res_fuse, dflash_sched_fuse); + res = dflash_res_fuse.get(); + sched_use = dflash_sched_fuse.get(); + } else if (dflash_draft_top_k_req > 0 && pending_target_feat_tensor != nullptr) { + ensure_dflash_runtime(dflash_res_draft, dflash_sched_draft); + res = dflash_res_draft.get(); + sched_use = dflash_sched_draft.get(); + } + } auto * gf = res->get_gf(); // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, sched_use); - if (!graph_reuse_disable && res->can_reuse(gparams)) { + const bool force_rebuild = model.arch == LLM_ARCH_DFLASH_DRAFT && + (pending_dflash_kv_update_only || pending_dflash_fuse_only); + if (!force_rebuild && !graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); // with pipeline parallelism, the previous graph_compute_async may still be running @@ -1193,11 +1496,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } n_reused++; + reused_graph = true; } else { + t0_us = profile_dflash ? ggml_time_us() : 0; res->reset(); - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + ggml_backend_sched_reset(sched_use); + ggml_backend_sched_set_eval_callback(sched_use, cparams.cb_eval, cparams.cb_eval_user_data); //const auto t_start_us = ggml_time_us(); @@ -1211,24 +1516,33 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + if (!ggml_backend_sched_alloc_graph(sched_use, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } + if (profile_dflash) { + t_build_alloc_us = ggml_time_us() - t0_us; + } } // set the input data for the input tensors { - //const auto t_start_us = ggml_time_us(); + t0_us = profile_dflash ? ggml_time_us() : 0; // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); - //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + if (profile_dflash) { + t_set_inputs_us = ggml_time_us() - t0_us; + } } - const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); + t0_us = profile_dflash ? ggml_time_us() : 0; + const auto status = graph_compute(sched_use, res->get_gf(), ubatch.n_tokens > 1); + if (profile_dflash) { + t_compute_us = ggml_time_us() - t0_us; + } if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -1237,6 +1551,19 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ret = GGML_STATUS_SUCCESS; + if (profile_dflash) { + LLAMA_LOG_INFO("dflash_draft_ubatch_timing: tokens=%u outputs=%d ctx_len=%" PRId64 " reused=%d apply=%.3f build_alloc=%.3f set_inputs=%.3f compute=%.3f total=%.3f ms\n", + ubatch.n_tokens, + n_outputs, + pending_target_feat_ctx_len, + reused_graph ? 1 : 0, + t_apply_us / 1000.0, + t_build_alloc_us / 1000.0, + t_set_inputs_us / 1000.0, + t_compute_us / 1000.0, + (ggml_time_us() - t_total_start_us) / 1000.0); + } + return res; } @@ -1316,7 +1643,7 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits.data && t_logits) { + if (logits.data && t_logits && dflash_draft_top_k_req <= 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits.data != nullptr); @@ -1324,86 +1651,1192 @@ int llama_context::encode(const llama_batch & batch_inp) { ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (!sampling.samplers.empty() && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || + !res->t_sampled_logits.empty() || !res->t_candidates.empty())) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, 0); + const auto stride = n_vocab; + + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); + copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, + seq_to_output_row, sched.get()); + copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, + seq_to_output_row, sched.get()); + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, + seq_to_output_row, sched.get()); + } + + // dflash-draft top-K output: pull compact [K, n_tokens] tensors into host memory. + if (model.arch == LLM_ARCH_DFLASH_DRAFT && res->get_dflash_top_logits() != nullptr && + res->get_dflash_top_ids() != nullptr) { + ggml_tensor * t_top_logits = res->get_dflash_top_logits(); + ggml_tensor * t_top_ids = res->get_dflash_top_ids(); + ggml_backend_t backend_logits = ggml_backend_sched_get_tensor_backend(sched.get(), t_top_logits); + ggml_backend_t backend_ids = ggml_backend_sched_get_tensor_backend(sched.get(), t_top_ids); + GGML_ASSERT(backend_logits != nullptr); + GGML_ASSERT(backend_ids != nullptr); + + const int64_t top_k = t_top_logits->ne[0]; + const int64_t rows = t_top_logits->ne[1]; + GGML_ASSERT(t_top_ids->ne[0] == top_k && t_top_ids->ne[1] == rows); + + dflash_draft_top_k = (int32_t) top_k; + dflash_draft_top_rows = (int32_t) rows; + dflash_draft_top_logits.resize((size_t) top_k * rows); + dflash_draft_top_token_ids.resize((size_t) top_k * rows); + + ggml_backend_tensor_get_async(backend_logits, t_top_logits, dflash_draft_top_logits.data(), 0, + ggml_nbytes(t_top_logits)); + ggml_backend_tensor_get_async(backend_ids, t_top_ids, dflash_draft_top_token_ids.data(), 0, + ggml_nbytes(t_top_ids)); + } else { + dflash_draft_top_k = 0; + dflash_draft_top_rows = 0; + dflash_draft_top_logits.clear(); + dflash_draft_top_token_ids.clear(); + } + + if (capture_hidden && res->t_hidden_capture != nullptr) { + ggml_tensor * t_cap = res->t_hidden_capture; + hidden_capture_ne0 = t_cap->ne[0]; + hidden_capture_ne1 = t_cap->ne[1]; + const char * direct = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT"); + if (direct != nullptr && direct[0] == '1') { + hidden_capture_host_valid = false; + } else { + ggml_backend_t backend_cap = ggml_backend_sched_get_tensor_backend(sched.get(), t_cap); + GGML_ASSERT(backend_cap != nullptr); + const size_t cap_n = ggml_nelements(t_cap); + if (hidden_capture_host.size() < cap_n) { + hidden_capture_host.resize(cap_n); + } + ggml_backend_tensor_get_async(backend_cap, t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float)); + hidden_capture_host_valid = true; + } + } + // extract embeddings if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd.data != nullptr); - const uint32_t n_embd_out = hparams.n_embd_out(); + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); + + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // TODO: hacky solution + if (model.arch == LLM_ARCH_T5 && t_embd) { + //cross.t_embd = t_embd; + + synchronize(); + + cross.n_embd = t_embd->ne[0]; + cross.n_enc = t_embd->ne[1]; + cross.v_embd.resize(cross.n_embd*cross.n_enc); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); + + const auto & batch = balloc->get_batch(); + + // remember the sequence ids used during the encoding - needed for cross attention later + cross.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + cross.seq_ids_enc[i].clear(); + + for (int s = 0; s < batch.n_seq_id[i]; s++) { + const llama_seq_id seq_id = batch.seq_id[i][s]; + + cross.seq_ids_enc[i].insert(seq_id); + } + } + } + + return 0; +} + +int llama_context::dflash_draft_fuse_target_feat( + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || target_feat_fused == nullptr || + n_embd_fc <= 0 || ctx_len <= 0) { + return -1; + } + + set_target_feat_raw(target_feat_raw, n_embd_fc, ctx_len, 0); + pending_dflash_fuse_only = true; + set_dflash_draft_top_k(0); + + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd_inp(); + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + pending_dflash_fuse_only = false; + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != ctx_len) { + return -3; + } + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_fused); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t_fused, target_feat_fused, 0, ggml_nbytes(t_fused)); + synchronize(); + + return 0; +} + +static ggml_backend_t dflash_tensor_backend( + const std::vector & backends, + ggml_tensor * t) { + if (t == nullptr || t->buffer == nullptr) { + return nullptr; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(t->buffer); + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + + for (const auto & backend : backends) { + ggml_backend_t be = backend.get(); + if (be != nullptr && ggml_backend_get_device(be) == dev && ggml_backend_supports_buft(be, buft)) { + return be; + } + } + + return nullptr; +} + +static ggml_backend_buffer_type_t dflash_preferred_cache_buft( + const std::vector & backends, + ggml_backend_buffer_type_t fallback) { + if (!ggml_backend_buft_is_host(fallback)) { + return fallback; + } + + for (const auto & backend : backends) { + ggml_backend_t be = backend.get(); + if (be == nullptr) { + continue; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(be); + if (buft != nullptr && !ggml_backend_buft_is_host(buft)) { + return buft; + } + } + + return fallback; +} + +static bool dflash_graph_copy_1d( + ggml_backend_t backend, + ggml_tensor * src, + ggml_tensor * dst, + int64_t ne, + size_t src_off, + size_t dst_off) { + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * 8 + ggml_graph_overhead_custom(8, false), + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), 8, false); + ggml_tensor * src_view = ggml_view_1d(ctx.get(), src, ne, src_off); + ggml_tensor * dst_view = ggml_view_1d(ctx.get(), dst, ne, dst_off); + ggml_tensor * out = ggml_cpy(ctx.get(), src_view, dst_view); + ggml_build_forward_expand(gf, out); + + return ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS; +} + +bool llama_context::dflash_draft_ensure_fused_cache_tensor( + int64_t n_embd, + int64_t cap, + ggml_backend_buffer_type_t buft) { + if (dflash_fused_cache != nullptr && dflash_fused_cache_n_embd == n_embd && + dflash_fused_cache_cap == cap && dflash_fused_cache->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_fused_cache->buffer) == buft) { + return true; + } + + dflash_fused_cache = nullptr; + dflash_fused_cache_ctx.reset(); + dflash_fused_cache_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, cap); + ggml_set_name(cache, "dflash_fused_target_feat_cache"); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_fused_cache_ctx.reset(ctx); + dflash_fused_cache_buf.reset(buf); + dflash_fused_cache = cache; + dflash_fused_cache_n_embd = n_embd; + dflash_fused_cache_cap = cap; + return true; +} + +bool llama_context::dflash_draft_ensure_packed_target_feat_tensor( + int64_t n_embd, + int64_t ctx_len, + ggml_backend_buffer_type_t buft) { + if (dflash_packed_target_feat != nullptr && dflash_packed_target_feat_n_embd == n_embd && + dflash_packed_target_feat_ctx_len == ctx_len && dflash_packed_target_feat->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_packed_target_feat->buffer) == buft) { + return true; + } + + dflash_packed_target_feat = nullptr; + dflash_packed_target_feat_ctx.reset(); + dflash_packed_target_feat_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * packed = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ctx_len); + ggml_set_name(packed, "dflash_packed_target_feat"); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_packed_target_feat_ctx.reset(ctx); + dflash_packed_target_feat_buf.reset(buf); + dflash_packed_target_feat = packed; + dflash_packed_target_feat_n_embd = n_embd; + dflash_packed_target_feat_ctx_len = ctx_len; + return true; +} + +bool llama_context::dflash_draft_ensure_kv_cache_tensors( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t cap, + ggml_backend_buffer_type_t buft) { + const int64_t n_layer = model.hparams.n_layer; + if ((int64_t)dflash_k_cache_l.size() == n_layer && dflash_kv_cache_head_dim == n_embd_head && + dflash_kv_cache_n_head_kv == n_head_kv && dflash_kv_cache_cap == cap && + !dflash_k_cache_l.empty() && dflash_k_cache_l[0] != nullptr && dflash_k_cache_l[0]->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_k_cache_l[0]->buffer) == buft) { + return true; + } + + dflash_k_cache_l.clear(); + dflash_v_cache_l.clear(); + dflash_kv_cache_ctx.reset(); + dflash_kv_cache_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (size_t)n_layer * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + dflash_k_cache_l.resize((size_t)n_layer, nullptr); + dflash_v_cache_l.resize((size_t)n_layer, nullptr); + for (int64_t il = 0; il < n_layer; ++il) { + dflash_k_cache_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, cap); + dflash_v_cache_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, cap); + ggml_format_name(dflash_k_cache_l[(size_t)il], "dflash_k_cache_%lld", (long long)il); + ggml_format_name(dflash_v_cache_l[(size_t)il], "dflash_v_cache_%lld", (long long)il); + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + dflash_k_cache_l.clear(); + dflash_v_cache_l.clear(); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_kv_cache_ctx.reset(ctx); + dflash_kv_cache_buf.reset(buf); + dflash_kv_cache_head_dim = n_embd_head; + dflash_kv_cache_n_head_kv = n_head_kv; + dflash_kv_cache_cap = cap; + return true; +} + +bool llama_context::dflash_draft_ensure_packed_kv_tensors( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + ggml_backend_buffer_type_t buft) { + const int64_t n_layer = model.hparams.n_layer; + if ((int64_t)dflash_k_packed_l.size() == n_layer && dflash_kv_packed_head_dim == n_embd_head && + dflash_kv_packed_n_head_kv == n_head_kv && dflash_kv_packed_ctx_len == ctx_len && + !dflash_k_packed_l.empty() && dflash_k_packed_l[0] != nullptr && dflash_k_packed_l[0]->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_k_packed_l[0]->buffer) == buft) { + return true; + } + + dflash_k_packed_l.clear(); + dflash_v_packed_l.clear(); + dflash_kv_packed_ctx.reset(); + dflash_kv_packed_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (size_t)n_layer * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + dflash_k_packed_l.resize((size_t)n_layer, nullptr); + dflash_v_packed_l.resize((size_t)n_layer, nullptr); + for (int64_t il = 0; il < n_layer; ++il) { + dflash_k_packed_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, ctx_len); + dflash_v_packed_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, ctx_len); + ggml_format_name(dflash_k_packed_l[(size_t)il], "dflash_k_packed_%lld", (long long)il); + ggml_format_name(dflash_v_packed_l[(size_t)il], "dflash_v_packed_%lld", (long long)il); + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + dflash_k_packed_l.clear(); + dflash_v_packed_l.clear(); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_kv_packed_ctx.reset(ctx); + dflash_kv_packed_buf.reset(buf); + dflash_kv_packed_head_dim = n_embd_head; + dflash_kv_packed_n_head_kv = n_head_kv; + dflash_kv_packed_ctx_len = ctx_len; + return true; +} + +bool llama_context::dflash_draft_ensure_top_output_tensors( + int64_t top_k, + int64_t rows, + ggml_backend_buffer_type_t buft) { + if (dflash_top_logits_fixed != nullptr && dflash_top_ids_fixed != nullptr && + dflash_top_output_k == top_k && dflash_top_output_rows == rows && + dflash_top_logits_fixed->buffer != nullptr && dflash_top_ids_fixed->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_top_logits_fixed->buffer) == buft) { + return true; + } + + dflash_top_logits_fixed = nullptr; + dflash_top_ids_fixed = nullptr; + dflash_top_output_ctx.reset(); + dflash_top_output_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * top_logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, top_k, rows); + ggml_tensor * top_ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, top_k, rows); + ggml_set_name(top_logits, "dflash_top_logits_fixed"); + ggml_set_name(top_ids, "dflash_top_ids_fixed"); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_top_output_ctx.reset(ctx); + dflash_top_output_buf.reset(buf); + dflash_top_logits_fixed = top_logits; + dflash_top_ids_fixed = top_ids; + dflash_top_output_k = top_k; + dflash_top_output_rows = rows; + return true; +} + +int llama_context::dflash_draft_update_fused_cache( + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || + n_embd_fc <= 0 || n_new <= 0 || cap <= 0 || n_embd_fc % 5 != 0) { + return -1; + } + + if (n_new > cap) { + const int64_t skip = n_new - cap; + target_feat_raw += (size_t)skip * n_embd_fc; + first_pos += skip; + n_new = cap; + } + + set_target_feat_raw(target_feat_raw, n_embd_fc, n_new, 0); + pending_dflash_fuse_only = true; + set_dflash_draft_top_k(0); + + const int64_t n_embd = n_embd_fc / 5; + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_fuse_only = false; + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != n_new || t_fused->buffer == nullptr) { + return -3; + } + + ggml_backend_buffer_type_t buft = dflash_preferred_cache_buft(backends, ggml_backend_buffer_get_type(t_fused->buffer)); + if (!dflash_draft_ensure_fused_cache_tensor(n_embd, cap, buft)) { + LLAMA_LOG_ERROR("%s: failed to allocate fused target_feat cache\n", __func__); + return -2; + } + + ggml_backend_t src_backend = dflash_tensor_backend(backends, t_fused); + ggml_backend_t dst_backend = dflash_tensor_backend(backends, dflash_fused_cache); + std::vector bounce; + + int64_t copied = 0; + while (copied < n_new) { + const int64_t dst_col = (first_pos + copied) % cap; + const int64_t n_part = std::min(n_new - copied, cap - dst_col); + const size_t src_off = (size_t) copied * n_embd * sizeof(float); + const size_t dst_off = (size_t) dst_col * n_embd * sizeof(float); + const size_t nbytes = (size_t) n_embd * n_part * sizeof(float); + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_copy_1d(src_backend, t_fused, dflash_fused_cache, n_embd * n_part, src_off, dst_off)) { + LLAMA_LOG_ERROR("%s: failed to copy fused target_feat into device cache\n", __func__); + return -3; + } + } else { + bounce.resize((size_t) n_embd * n_part); + ggml_backend_tensor_get(t_fused, bounce.data(), src_off, nbytes); + ggml_backend_tensor_set(dflash_fused_cache, bounce.data(), dst_off, nbytes); + } + copied += n_part; + } + + if (dflash_draft_update_kv_cache(nullptr, n_embd, n_new, first_pos, cap) != 0) { + LLAMA_LOG_ERROR("%s: failed to update draft K/V cache\n", __func__); + return -3; + } + + return 0; +} + +int llama_context::dflash_draft_update_fused_cache_from_capture( + llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_ctx == nullptr || n_dfs <= 0 || cap <= 0) { + return -1; + } + + ggml_tensor * t_cap = target_ctx->get_hidden_capture(); + if (t_cap == nullptr || t_cap->buffer == nullptr || t_cap->ne[0] <= 0 || t_cap->ne[1] <= 0 || t_cap->ne[1] % 5 != 0) { + return -1; + } + + const int64_t n_embd = t_cap->ne[0]; + const int64_t n_tokens = t_cap->ne[1] / 5; + const int64_t n_embd_fc = 5 * n_embd; + const int64_t update_chunk = std::min(16, cap); + + ggml_backend_buffer_type_t src_buft = ggml_backend_buffer_get_type(t_cap->buffer); + ggml_backend_buffer_type_t buft = dflash_preferred_cache_buft(backends, src_buft); + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, 5*update_chunk, buft)) { + return -2; + } + if (!dflash_draft_ensure_fused_cache_tensor(n_embd, cap, buft)) { + return -2; + } + + ggml_backend_t src_backend = dflash_tensor_backend(target_ctx->backends, t_cap); + ggml_backend_t dst_backend = dflash_tensor_backend(backends, dflash_packed_target_feat); + if (dst_backend == nullptr) { + return -3; + } + + std::vector zeros((size_t)n_embd * 5 * update_chunk, 0.0f); + std::vector bounce((size_t)n_embd); + + int32_t done = 0; + while (done < n_dfs) { + const int64_t dst_col = (first_pos + done) % cap; + const int64_t valid = std::min({(int64_t)n_dfs - done, update_chunk, cap - dst_col}); + const int64_t width = std::min(update_chunk, cap - dst_col); + if (valid <= 0 || width <= 0) { + return -3; + } + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, 5*width, buft)) { + return -2; + } + dst_backend = dflash_tensor_backend(backends, dflash_packed_target_feat); + if (dst_backend == nullptr) { + return -3; + } + + ggml_backend_tensor_set(dflash_packed_target_feat, zeros.data(), 0, + (size_t)n_embd * 5 * width * sizeof(float)); + + if (dfs_indices == nullptr) { + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + done) * n_embd * sizeof(float); + const size_t dst_off = (size_t)l * width * n_embd * sizeof(float); + const size_t nbytes = (size_t)n_embd * valid * sizeof(float); + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_copy_1d(src_backend, t_cap, dflash_packed_target_feat, + n_embd * valid, src_off, dst_off)) { + return -3; + } + } else { + std::vector layer_bounce((size_t)n_embd * valid); + ggml_backend_tensor_get(t_cap, layer_bounce.data(), src_off, nbytes); + ggml_backend_tensor_set(dflash_packed_target_feat, layer_bounce.data(), dst_off, nbytes); + } + } + } else { + for (int64_t i = 0; i < valid; ++i) { + const int64_t src_col = (int64_t)dfs_indices[done + i]; + if (src_col < 0 || src_col >= n_tokens) { + return -1; + } + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + src_col) * n_embd * sizeof(float); + const size_t dst_off = ((size_t)l * width + (size_t)i) * n_embd * sizeof(float); + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_copy_1d(src_backend, t_cap, dflash_packed_target_feat, + n_embd, src_off, dst_off)) { + return -3; + } + } else { + ggml_backend_tensor_get(t_cap, bounce.data(), src_off, (size_t)n_embd * sizeof(float)); + ggml_backend_tensor_set(dflash_packed_target_feat, bounce.data(), dst_off, (size_t)n_embd * sizeof(float)); + } + } + } + } + + pending_target_feat_raw = nullptr; + pending_target_feat_n_embd_fc = n_embd_fc; + pending_target_feat_ctx_len = width; + pending_draft_committed_pos = 0; + pending_target_feat_fused = false; + pending_dflash_fuse_only = true; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = dflash_packed_target_feat; + set_dflash_draft_top_k(0); + + std::vector dummy_embd((size_t)n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_fuse_only = false; + pending_target_feat_tensor = nullptr; + return -1; + } + + const uint32_t n_batch_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_batch_tokens); + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_batch_tokens; + n_outputs = n_batch_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + pending_target_feat_tensor = nullptr; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != width || t_fused->buffer == nullptr) { + return -3; + } + + ggml_backend_t fused_backend = dflash_tensor_backend(backends, t_fused); + ggml_backend_t cache_backend = dflash_tensor_backend(backends, dflash_fused_cache); + if (fused_backend != nullptr && fused_backend == cache_backend) { + if (!dflash_graph_copy_1d(fused_backend, t_fused, dflash_fused_cache, + n_embd * width, 0, (size_t)dst_col * n_embd * sizeof(float))) { + return -3; + } + } else { + std::vector fused_bounce((size_t)n_embd * width); + ggml_backend_tensor_get(t_fused, fused_bounce.data(), 0, (size_t)n_embd * width * sizeof(float)); + ggml_backend_tensor_set(dflash_fused_cache, fused_bounce.data(), + (size_t)dst_col * n_embd * sizeof(float), + (size_t)n_embd * width * sizeof(float)); + } + + if (dflash_draft_update_kv_cache(nullptr, n_embd, width, first_pos + done, cap) != 0) { + return -3; + } + + done += (int32_t)valid; + } + + return 0; +} + +int llama_context::dflash_draft_update_kv_cache( + const float * target_feat_fused, + int64_t n_embd, + int64_t n_new, + int64_t first_pos, + int64_t cap) { + GGML_UNUSED(target_feat_fused); + + const int64_t n_embd_head = model.hparams.n_embd_head_k(); + const int64_t n_head_kv = model.hparams.n_head_kv(); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_fused_cache->buffer); + if (!dflash_draft_ensure_kv_cache_tensors(n_embd_head, n_head_kv, cap, buft)) { + LLAMA_LOG_ERROR("%s: failed to allocate draft K/V cache\n", __func__); + return -2; + } + + const int64_t update_chunk = std::min(16, cap); + int64_t done = 0; + while (done < n_new) { + const int64_t src_col = (first_pos + done) % cap; + const int64_t valid = std::min({n_new - done, update_chunk, cap - src_col}); + const int64_t width = std::min(update_chunk, cap - src_col); + if (valid <= 0 || width <= 0) { + return -3; + } + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, width, buft)) { + return -2; + } + + ggml_backend_t backend = dflash_tensor_backend(backends, dflash_fused_cache); + if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_packed_target_feat)) { + return -3; + } + if (!dflash_graph_copy_1d(backend, dflash_fused_cache, dflash_packed_target_feat, + n_embd * width, + (size_t)src_col * n_embd * sizeof(float), 0)) { + return -3; + } + + pending_target_feat_raw = nullptr; + pending_target_feat_n_embd_fc = n_embd; + pending_target_feat_ctx_len = width; + pending_draft_committed_pos = 0; + pending_target_feat_fused = true; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = true; + pending_dflash_kv_update_dst_pos = src_col; + pending_target_feat_tensor = dflash_packed_target_feat; + set_dflash_draft_top_k(0); + + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_kv_update_only = false; + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + cparams.causal_attn = causal_attn_org; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + done += valid; + } + + return 0; +} + +int llama_context::dflash_draft_pack_kv_cache( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + int64_t ring_start, + int64_t cap) { + if (dflash_k_cache_l.empty() || dflash_v_cache_l.empty()) { + return -1; + } + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_k_cache_l[0]->buffer); + if (!dflash_draft_ensure_packed_kv_tensors(n_embd_head, n_head_kv, ctx_len, buft)) { + return -2; + } + ggml_backend_t backend = dflash_tensor_backend(backends, dflash_k_cache_l[0]); + if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_k_packed_l[0])) { + return -3; + } + + const int64_t per_tok = n_embd_head * n_head_kv; + for (size_t il = 0; il < dflash_k_cache_l.size(); ++il) { + int64_t copied = 0; + while (copied < ctx_len) { + const int64_t src_col = (ring_start + copied) % cap; + const int64_t n_part = std::min(ctx_len - copied, cap - src_col); + const size_t src_off = (size_t) src_col * per_tok * sizeof(float); + const size_t dst_off = (size_t) copied * per_tok * sizeof(float); + const int64_t ne = per_tok * n_part; + if (!dflash_graph_copy_1d(backend, dflash_k_cache_l[il], dflash_k_packed_l[il], ne, src_off, dst_off) || + !dflash_graph_copy_1d(backend, dflash_v_cache_l[il], dflash_v_packed_l[il], ne, src_off, dst_off)) { + return -3; + } + copied += n_part; + } + } + return 0; +} + +int llama_context::dflash_draft_encode_top_k_cached( + const llama_batch & batch_inp, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || dflash_fused_cache == nullptr || + n_embd <= 0 || ctx_len <= 0 || cap <= 0 || top_k <= 0 || + n_embd != dflash_fused_cache_n_embd || cap != dflash_fused_cache_cap) { + return -1; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_fused_cache->buffer); + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, ctx_len, buft)) { + LLAMA_LOG_ERROR("%s: failed to allocate packed target_feat tensor\n", __func__); + return -2; + } + + ggml_backend_t backend = dflash_tensor_backend(backends, dflash_fused_cache); + if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_packed_target_feat)) { + LLAMA_LOG_ERROR("%s: packed target_feat backend mismatch\n", __func__); + return -3; + } + + int64_t copied = 0; + while (copied < ctx_len) { + const int64_t src_col = (ring_start + copied) % cap; + const int64_t n_part = std::min(ctx_len - copied, cap - src_col); + const size_t src_off = (size_t) src_col * n_embd * sizeof(float); + const size_t dst_off = (size_t) copied * n_embd * sizeof(float); + if (!dflash_graph_copy_1d(backend, dflash_fused_cache, dflash_packed_target_feat, + n_embd * n_part, src_off, dst_off)) { + LLAMA_LOG_ERROR("%s: failed to pack fused target_feat device window\n", __func__); + return -3; + } + copied += n_part; + } + + const int64_t n_embd_head = model.hparams.n_embd_head_k(); + const int64_t n_head_kv = model.hparams.n_head_kv(); + if (dflash_draft_pack_kv_cache(n_embd_head, n_head_kv, ctx_len, ring_start, cap) != 0) { + LLAMA_LOG_ERROR("%s: failed to pack draft K/V cache\n", __func__); + return -3; + } + + pending_target_feat_raw = nullptr; + pending_target_feat_n_embd_fc = n_embd; + pending_target_feat_ctx_len = ctx_len; + pending_draft_committed_pos = committed_pos; + pending_target_feat_fused = true; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = dflash_packed_target_feat; + + return dflash_draft_encode_top_k_pending(batch_inp, top_k); +} + +int llama_context::dflash_draft_encode_top_k( + const llama_batch & batch_inp, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || top_k <= 0) { + return -1; + } + + set_target_feat_raw(target_feat_raw, n_embd_fc, ctx_len, committed_pos); + return dflash_draft_encode_top_k_pending(batch_inp, top_k); +} + +int llama_context::dflash_draft_encode_top_k_fused( + const llama_batch & batch_inp, + const float * target_feat_fused, + int64_t n_embd, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_fused == nullptr || top_k <= 0) { + return -1; + } + + set_target_feat_fused(target_feat_fused, n_embd, ctx_len, committed_pos); + return dflash_draft_encode_top_k_pending(batch_inp, top_k); +} + +int llama_context::dflash_draft_encode_top_k_pending( + const llama_batch & batch_inp, + int32_t top_k) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || + (pending_target_feat_raw == nullptr && pending_target_feat_tensor == nullptr) || top_k <= 0) { + return -1; + } + + set_dflash_draft_top_k(top_k); + + const auto & hparams = model.hparams; + const int64_t n_embd_model = hparams.n_embd_inp(); + + if (pending_target_feat_fused && pending_target_feat_n_embd_fc != n_embd_model) { + LLAMA_LOG_ERROR("%s: fused target_feat width mismatch: got %lld expected %lld\n", + __func__, (long long)pending_target_feat_n_embd_fc, (long long)n_embd_model); + return -1; + } + + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd_model, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + + if (cparams.n_ubatch < n_tokens) { + LLAMA_LOG_ERROR("%s: encoder requires n_ubatch >= n_tokens\n", __func__); + return -1; + } - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings - auto & embd_seq_out = embd_seq; + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; - // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's - // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) - const uint32_t n_embd_out = hparams.n_embd_out(); - embd_seq_out[seq_id].resize(n_embd_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; - const uint32_t n_cls_out = hparams.n_cls_out; + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; + cparams.causal_attn = causal_attn_org; - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } - // TODO: hacky solution - if (model.arch == LLM_ARCH_T5 && t_embd) { - //cross.t_embd = t_embd; + if (res->get_dflash_top_logits() == nullptr || res->get_dflash_top_ids() == nullptr) { + dflash_draft_top_k = 0; + dflash_draft_top_rows = 0; + dflash_draft_top_logits.clear(); + dflash_draft_top_token_ids.clear(); + return -3; + } - synchronize(); + ggml_tensor * t_top_logits = res->get_dflash_top_logits(); + ggml_tensor * t_top_ids = res->get_dflash_top_ids(); + auto tensor_backend_from_any_sched = [&](ggml_tensor * t) -> ggml_backend_t { + ggml_backend_t be = ggml_backend_sched_get_tensor_backend(sched.get(), t); + if (be == nullptr && dflash_sched_draft) { + be = ggml_backend_sched_get_tensor_backend(dflash_sched_draft.get(), t); + } + if (be == nullptr && dflash_sched_fuse) { + be = ggml_backend_sched_get_tensor_backend(dflash_sched_fuse.get(), t); + } + if (be == nullptr && dflash_sched_kv) { + be = ggml_backend_sched_get_tensor_backend(dflash_sched_kv.get(), t); + } + return be; + }; + ggml_backend_t backend_logits = tensor_backend_from_any_sched(t_top_logits); + ggml_backend_t backend_ids = tensor_backend_from_any_sched(t_top_ids); + GGML_ASSERT(backend_logits != nullptr); + GGML_ASSERT(backend_ids != nullptr); - cross.n_embd = t_embd->ne[0]; - cross.n_enc = t_embd->ne[1]; - cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); + const int64_t graph_top_k = t_top_logits->ne[0]; + const int64_t rows = t_top_logits->ne[1]; + GGML_ASSERT(t_top_ids->ne[0] == graph_top_k && t_top_ids->ne[1] == rows); - const auto & batch = balloc->get_batch(); + ggml_backend_buffer_type_t out_buft = ggml_backend_buffer_get_type(t_top_logits->buffer); + if (!dflash_draft_ensure_top_output_tensors(graph_top_k, rows, out_buft)) { + return -2; + } - // remember the sequence ids used during the encoding - needed for cross attention later - cross.seq_ids_enc.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - cross.seq_ids_enc[i].clear(); + ggml_backend_t backend_fixed_logits = dflash_tensor_backend(backends, dflash_top_logits_fixed); + ggml_backend_t backend_fixed_ids = dflash_tensor_backend(backends, dflash_top_ids_fixed); + if (backend_fixed_logits == nullptr || backend_fixed_ids == nullptr) { + return -3; + } + if (backend_logits != nullptr && backend_logits == backend_fixed_logits) { + ggml_backend_tensor_copy_async(backend_logits, backend_fixed_logits, t_top_logits, dflash_top_logits_fixed); + } else { + std::vector top_bounce((size_t)graph_top_k * rows); + ggml_backend_tensor_get(t_top_logits, top_bounce.data(), 0, ggml_nbytes(t_top_logits)); + ggml_backend_tensor_set(dflash_top_logits_fixed, top_bounce.data(), 0, ggml_nbytes(t_top_logits)); + } + if (backend_ids != nullptr && backend_ids == backend_fixed_ids) { + ggml_backend_tensor_copy_async(backend_ids, backend_fixed_ids, t_top_ids, dflash_top_ids_fixed); + } else { + std::vector id_bounce((size_t)graph_top_k * rows); + ggml_backend_tensor_get(t_top_ids, id_bounce.data(), 0, ggml_nbytes(t_top_ids)); + ggml_backend_tensor_set(dflash_top_ids_fixed, id_bounce.data(), 0, ggml_nbytes(t_top_ids)); + } - for (int s = 0; s < batch.n_seq_id[i]; s++) { - const llama_seq_id seq_id = batch.seq_id[i][s]; + dflash_draft_top_k = (int32_t) graph_top_k; + dflash_draft_top_rows = (int32_t) rows; + dflash_draft_top_logits.resize((size_t) graph_top_k * rows); + dflash_draft_top_token_ids.resize((size_t) graph_top_k * rows); - cross.seq_ids_enc[i].insert(seq_id); - } - } - } + ggml_backend_tensor_get_async(backend_fixed_logits, dflash_top_logits_fixed, dflash_draft_top_logits.data(), 0, + ggml_nbytes(dflash_top_logits_fixed)); + ggml_backend_tensor_get_async(backend_fixed_ids, dflash_top_ids_fixed, dflash_draft_top_token_ids.data(), 0, + ggml_nbytes(dflash_top_ids_fixed)); return 0; } @@ -1538,6 +2971,19 @@ int llama_context::decode(const llama_batch & batch_inp) { return encode(batch_inp); } + // tree-mode batches are only supported for the Qwen3.5 hybrid architecture + if (batch_inp.parent_id != nullptr && model.arch != LLM_ARCH_QWEN35) { + LLAMA_LOG_ERROR("%s: parent_id (tree-mode batch) is only supported for LLM_ARCH_QWEN35, got arch=%d\n", + __func__, (int) model.arch); + return -1; + } + + // Phase 2.4: ensure SSM persist buffers are large enough for this tree batch. + // Must happen before graph_params() so the pointer is valid when building the graph. + if (batch_inp.parent_id != nullptr && llama_dflash_fast_rollback_enabled()) { + ensure_dflash_persist_capacity((int64_t)batch_inp.n_tokens); + } + if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; @@ -1749,6 +3195,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + if (capture_hidden && res->t_hidden_capture != nullptr) { + ggml_tensor * t_cap = res->t_hidden_capture; + hidden_capture_ne0 = t_cap->ne[0]; + hidden_capture_ne1 = t_cap->ne[1]; + const char * direct = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT"); + if (direct != nullptr && direct[0] == '1') { + hidden_capture_host_valid = false; + } else { + ggml_backend_t backend_cap = ggml_backend_sched_get_tensor_backend(sched.get(), t_cap); + GGML_ASSERT(backend_cap != nullptr); + const size_t cap_n = ggml_nelements(t_cap); + if (hidden_capture_host.size() < cap_n) { + hidden_capture_host.resize(cap_n); + } + ggml_backend_tensor_get_async(backend_cap, t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float)); + hidden_capture_host_valid = true; + } + } + // extract embeddings if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); @@ -2149,29 +3614,61 @@ llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + llm_graph_type gtype, + ggml_backend_sched_t sched_use) const { + if (sched_use == nullptr) { + sched_use = sched.get(); + } return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, /*.cparams =*/ cparams, /*.ubatch =*/ ubatch, /*.gtype =*/ gtype, - /*.sched =*/ sched.get(), + /*.sched =*/ sched_use, /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ cvec.get(), /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, - /*.n_outputs =*/ n_outputs, - /*.cb =*/ graph_get_cb(), - /*.res =*/ res, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, + /*.capture_hidden =*/ capture_hidden, + // Wire pending_target_feat pointers so build_inp_target_feat() can read them. + // These are non-null only when the caller invoked llama_set_target_feat_raw(). + /*.pending_target_feat_raw_ptr =*/ &pending_target_feat_raw, + /*.pending_target_feat_n_embd_fc_ptr =*/ &pending_target_feat_n_embd_fc, + /*.pending_target_feat_ctx_len_ptr =*/ &pending_target_feat_ctx_len, + /*.pending_draft_committed_pos_ptr =*/ &pending_draft_committed_pos, + /*.pending_target_feat_tensor_ptr =*/ &pending_target_feat_tensor, + /*.dflash_kv_cache_k_l =*/ pending_dflash_kv_update_only ? &dflash_k_cache_l : &dflash_k_packed_l, + /*.dflash_kv_cache_v_l =*/ pending_dflash_kv_update_only ? &dflash_v_cache_l : &dflash_v_packed_l, + /*.dflash_kv_cache_dst_pos =*/ pending_dflash_kv_update_dst_pos, + /*.dflash_target_feat_fused =*/ pending_target_feat_fused, + /*.dflash_kv_update_only =*/ pending_dflash_kv_update_only, + /*.dflash_fuse_only =*/ pending_dflash_fuse_only, + /*.dflash_draft_top_k =*/ dflash_draft_top_k_req, + // Phase 2.4: pass persist buffer vector when in tree mode (parent_id is set). + // Non-null only after ensure_dflash_persist_capacity() ran in decode(). + /*.dflash_persist_inter_l =*/ (!dflash_persist_inter_l.empty() && ubatch.parent_id != nullptr) + ? &dflash_persist_inter_l : nullptr, + /*.dflash_persist_conv_l =*/ (!dflash_persist_conv_l.empty() && ubatch.parent_id != nullptr) + ? &dflash_persist_conv_l : nullptr, }; } ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { + return graph_compute(sched.get(), gf, batched); +} + +ggml_status llama_context::graph_compute( + ggml_backend_sched_t sched_use, + ggml_cgraph * gf, + bool batched) { int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; @@ -2188,7 +3685,7 @@ ggml_status llama_context::graph_compute( set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); + auto status = ggml_backend_sched_graph_compute_async(sched_use, gf); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } @@ -3083,6 +4580,128 @@ void llama_set_warmup(llama_context * ctx, bool warmup) { ctx->set_warmup(warmup); } +void llama_set_capture_hidden(llama_context * ctx, bool enable) { + ctx->set_capture_hidden(enable); +} + +void llama_set_dflash_draft_top_k(llama_context * ctx, int32_t k) { + ctx->set_dflash_draft_top_k(k); +} + +ggml_tensor * llama_get_hidden_capture(llama_context * ctx) { + ctx->synchronize(); + return ctx->get_hidden_capture(); +} + +const float * llama_get_hidden_capture_data(llama_context * ctx, int64_t * out_ne0, int64_t * out_ne1) { + ctx->synchronize(); + return ctx->get_hidden_capture_data(out_ne0, out_ne1); +} + +void llama_set_target_feat_raw(llama_context * ctx, + const float * data, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos) { + ctx->set_target_feat_raw(data, n_embd_fc, ctx_len, committed_pos); +} + +int llama_dflash_draft_fuse_target_feat(llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused) { + return ctx->dflash_draft_fuse_target_feat(target_feat_raw, n_embd_fc, ctx_len, target_feat_fused); +} + +int llama_dflash_draft_encode_top_k(llama_context * ctx, + llama_batch batch, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k) { + return ctx->dflash_draft_encode_top_k(batch, target_feat_raw, n_embd_fc, ctx_len, committed_pos, top_k); +} + +int llama_dflash_draft_encode_top_k_fused(llama_context * ctx, + llama_batch batch, + const float * target_feat_fused, + int64_t n_embd, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k) { + return ctx->dflash_draft_encode_top_k_fused(batch, target_feat_fused, n_embd, ctx_len, committed_pos, top_k); +} + +int llama_dflash_draft_update_fused_cache(llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap) { + return ctx->dflash_draft_update_fused_cache(target_feat_raw, n_embd_fc, n_new, first_pos, cap); +} + +int llama_dflash_draft_update_fused_cache_from_capture(llama_context * draft_ctx, + llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap) { + return draft_ctx->dflash_draft_update_fused_cache_from_capture(target_ctx, dfs_indices, n_dfs, first_pos, cap); +} + +int llama_dflash_draft_encode_top_k_cached(llama_context * ctx, + llama_batch batch, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k) { + return ctx->dflash_draft_encode_top_k_cached(batch, n_embd, ctx_len, ring_start, cap, committed_pos, top_k); +} + +bool llama_dflash_rollback_ssm_to_dfs( + struct llama_context * ctx, + llama_seq_id seq_id, + int32_t accepted_dfs_node) { + ctx->synchronize(); // ensure the tree-mode decode kernel has completed + return ctx->dflash_rollback_ssm_to_dfs(seq_id, accepted_dfs_node); +} + +bool llama_dflash_set_recurrent_tail_pos( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos pos) { + if (ctx == nullptr) { + return false; + } + + auto * raw_mem = ctx->get_memory(); + auto * mem_recr = dynamic_cast(raw_mem); + if (!mem_recr) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem_recr = hyb->get_mem_recr(); + } + } + if (!mem_recr || seq_id < 0 || seq_id >= (llama_seq_id) mem_recr->cells.size()) { + return false; + } + + const int32_t cell_id = mem_recr->cells[seq_id].tail; + if (cell_id < 0 || cell_id >= (int32_t) mem_recr->cells.size()) { + return false; + } + if (!mem_recr->cells[cell_id].has_seq_id(seq_id)) { + return false; + } + + mem_recr->cells[cell_id].pos = pos; + return true; +} + void llama_synchronize(llama_context * ctx) { ctx->synchronize(); } @@ -3171,6 +4790,16 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { return static_cast(ctx->get_sampled_probs_count(i)); } +bool llama_get_dflash_draft_top_k( + llama_context * ctx, + const float ** logits, + const llama_token ** token_ids, + int32_t * n_rows, + int32_t * k) { + ctx->synchronize(); + return ctx->get_dflash_draft_top_k(logits, token_ids, n_rows, k); +} + struct ggml_cgraph * llama_graph_reserve( struct llama_context * ctx, uint32_t n_tokens, @@ -3317,6 +4946,345 @@ bool llama_memory_can_shift(llama_memory_t mem) { return mem->get_can_shift(); } +// snapshot/restore for recurrent (SSM + conv) state + +llama_mem_snapshot_id llama_seq_snapshot(struct llama_context * ctx, llama_seq_id seq_id) { + auto * raw_mem = ctx->get_memory(); + auto * mem = dynamic_cast(raw_mem); + if (!mem) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem = hyb->get_mem_recr(); + } + } + if (!mem) { + return LLAMA_MEM_SNAPSHOT_INVALID; + } + return mem->snapshot(seq_id); +} + +bool llama_seq_restore(struct llama_context * ctx, llama_mem_snapshot_id snap_id) { + auto * raw_mem = ctx->get_memory(); + auto * mem = dynamic_cast(raw_mem); + if (!mem) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem = hyb->get_mem_recr(); + } + } + if (!mem) { + return false; + } + return mem->restore(snap_id); +} + +void llama_seq_release(struct llama_context * ctx, llama_mem_snapshot_id snap_id) { + auto * raw_mem = ctx->get_memory(); + auto * mem = dynamic_cast(raw_mem); + if (!mem) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem = hyb->get_mem_recr(); + } + } + if (mem) { + mem->release(snap_id); + } +} + +bool llama_context::dflash_rollback_ssm_to_dfs(llama_seq_id seq_id, int32_t accepted_dfs_node) { + if (dflash_persist_inter_l.empty()) { + LLAMA_LOG_WARN("%s: persist buffers not allocated (no tree-mode decode has run)\n", __func__); + return false; + } + + // Resolve the recurrent memory module. + auto * raw_mem = memory.get(); + auto * mem_recr = dynamic_cast(raw_mem); + if (!mem_recr) { + if (auto * hyb = dynamic_cast(raw_mem)) { + mem_recr = hyb->get_mem_recr(); + } + } + if (!mem_recr) { + LLAMA_LOG_WARN("%s: no recurrent memory module; rollback is a no-op\n", __func__); + return false; + } + + const auto & hparams = model.hparams; + const int32_t n_layer = (int32_t)hparams.n_layer; + const int32_t cell_id = (seq_id >= 0 && seq_id < (int32_t)mem_recr->cells.size()) + ? mem_recr->cells[seq_id].tail : -1; + if (cell_id < 0) { + LLAMA_LOG_WARN("%s: seq_id=%d has no tail cell; rollback skipped\n", __func__, (int)seq_id); + return false; + } + + const int64_t n_embd_s = (int64_t)hparams.n_embd_s(); + const bool skip_s_rollback = []{ + const char * e = getenv("LLAMA_DDTREE_ROLLBACK_SKIP_S"); + return e && e[0] == '1'; + }(); + + // Fast path: execute the rollback as a tiny backend graph so CUDA layers do + // not bounce every persist column through host memory. Keep the host path + // below as the exact fallback for mixed/offloaded or unsupported layouts. + const bool graph_rollback_enabled = []{ + const char * e = getenv("LLAMA_DDTREE_ROLLBACK_GRAPH"); + return !e || e[0] != '0'; + }(); + + const bool skip_conv_rollback = []{ + const char * e = getenv("LLAMA_DDTREE_ROLLBACK_SKIP_CONV"); + return e && e[0] == '1'; + }(); + + struct dflash_rollback_copy { + ggml_tensor * src; + ggml_tensor * dst; + int64_t ne; + size_t src_off; + size_t dst_off; + }; + + auto tensor_backend = [&](const ggml_tensor * t) -> ggml_backend_t { + if (t == nullptr || t->buffer == nullptr) { + return nullptr; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(t->buffer); + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + + for (auto & backend : backends) { + ggml_backend_t be = backend.get(); + if (be != nullptr && ggml_backend_get_device(be) == dev && ggml_backend_supports_buft(be, buft)) { + return be; + } + } + + return nullptr; + }; + + auto try_graph_rollback = [&]() -> bool { + if (!graph_rollback_enabled) { + return false; + } + + std::vector copies; + copies.reserve((size_t)n_layer * 2); + + ggml_backend_t graph_backend = nullptr; + auto add_copy = [&](ggml_tensor * src, ggml_tensor * dst, int64_t ne, size_t src_off, size_t dst_off) -> bool { + ggml_backend_t src_backend = tensor_backend(src); + ggml_backend_t dst_backend = tensor_backend(dst); + if (src_backend == nullptr || src_backend != dst_backend) { + return false; + } + if (graph_backend == nullptr) { + graph_backend = src_backend; + } else if (graph_backend != src_backend) { + return false; + } + + copies.push_back({ src, dst, ne, src_off, dst_off }); + return true; + }; + + if (!skip_s_rollback) { + for (int il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { continue; } + ggml_tensor * persist = dflash_persist_inter_l[il]; + ggml_tensor * s_state = (il < (int32_t)mem_recr->s_l.size()) ? mem_recr->s_l[il] : nullptr; + if (!persist || !s_state) { continue; } + if (accepted_dfs_node >= dflash_persist_max_n_tokens) { return false; } + + if (!((persist->type == GGML_TYPE_F32 || persist->type == GGML_TYPE_F16) && + (s_state->type == GGML_TYPE_F32 || s_state->type == GGML_TYPE_F16))) { + return false; + } + + const size_t src_col_bytes = ggml_row_size(persist->type, n_embd_s); + const size_t dst_row_bytes = ggml_row_size(s_state->type, n_embd_s); + if (!add_copy(persist, s_state, n_embd_s, + (size_t)accepted_dfs_node * src_col_bytes, + (size_t)cell_id * dst_row_bytes)) { + return false; + } + } + } + + if (!skip_conv_rollback && !dflash_persist_conv_l.empty()) { + const int64_t n_embd_r = (int64_t)hparams.n_embd_r(); + for (int il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { continue; } + ggml_tensor * persist_conv = (il < (int32_t)dflash_persist_conv_l.size()) + ? dflash_persist_conv_l[il] : nullptr; + ggml_tensor * r_state = (il < (int32_t)mem_recr->r_l.size()) + ? mem_recr->r_l[il] : nullptr; + if (!persist_conv || !r_state) { continue; } + if (accepted_dfs_node >= dflash_persist_max_n_tokens) { return false; } + if (persist_conv->type != GGML_TYPE_F32 || r_state->type != GGML_TYPE_F32) { + return false; + } + + const size_t conv_col_bytes = (size_t)n_embd_r * sizeof(float); + const size_t r_row_bytes = ggml_row_size(r_state->type, n_embd_r); + if (!add_copy(persist_conv, r_state, n_embd_r, + (size_t)accepted_dfs_node * conv_col_bytes, + (size_t)cell_id * r_row_bytes)) { + return false; + } + } + } + + if (copies.empty()) { + return true; + } + + const size_t graph_size = copies.size() * 4 + 16; + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (copies.size() * 4 + 16) + + ggml_graph_overhead_custom(graph_size, false), + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_size, false); + for (const auto & copy : copies) { + ggml_tensor * src = ggml_view_1d(ctx.get(), copy.src, copy.ne, copy.src_off); + ggml_tensor * dst = ggml_view_1d(ctx.get(), copy.dst, copy.ne, copy.dst_off); + ggml_tensor * out = ggml_cpy(ctx.get(), src, dst); + ggml_build_forward_expand(gf, out); + } + + ggml_backend_sched_synchronize(sched.get()); + const ggml_status status = ggml_backend_graph_compute(graph_backend, gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_WARN("%s: graph rollback failed with status %d; falling back to host copy\n", + __func__, (int)status); + return false; + } + + return true; + }; + + if (try_graph_rollback()) { + return true; + } + + // Persist tensor may be F32 (correctness baseline) or F16 (memory-saving + // variant). s_state may be F32 (Qwen3.5 hybrid stores SSM in F32). + std::vector bounce_f16((size_t)n_embd_s); + std::vector bounce_f32((size_t)n_embd_s); + + if (!skip_s_rollback) { + for (int il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { continue; } + ggml_tensor * persist = dflash_persist_inter_l[il]; + ggml_tensor * s_state = (il < (int32_t)mem_recr->s_l.size()) ? mem_recr->s_l[il] : nullptr; + if (!persist || !s_state) { continue; } + + if (accepted_dfs_node >= dflash_persist_max_n_tokens) { + LLAMA_LOG_WARN("%s: accepted_dfs_node=%d >= persist capacity=%lld at il=%d\n", + __func__, (int)accepted_dfs_node, + (long long)dflash_persist_max_n_tokens, il); + continue; + } + + const size_t state_row_bytes = ggml_row_size(s_state->type, n_embd_s); + const size_t state_offset = (size_t)cell_id * state_row_bytes; + + if (persist->type == GGML_TYPE_F32) { + const size_t persist_col_bytes = (size_t)n_embd_s * sizeof(float); + const size_t persist_offset = (size_t)accepted_dfs_node * persist_col_bytes; + ggml_backend_tensor_get(persist, bounce_f32.data(), persist_offset, persist_col_bytes); + + if (s_state->type == GGML_TYPE_F32) { + ggml_backend_tensor_set(s_state, bounce_f32.data(), state_offset, state_row_bytes); + } else if (s_state->type == GGML_TYPE_F16) { + ggml_fp32_to_fp16_row(bounce_f32.data(), bounce_f16.data(), n_embd_s); + ggml_backend_tensor_set(s_state, bounce_f16.data(), state_offset, state_row_bytes); + } else { + GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported s_state type"); + } + } else if (persist->type == GGML_TYPE_F16) { + const size_t persist_col_bytes = (size_t)n_embd_s * sizeof(ggml_fp16_t); + const size_t persist_offset = (size_t)accepted_dfs_node * persist_col_bytes; + ggml_backend_tensor_get(persist, bounce_f16.data(), persist_offset, persist_col_bytes); + + if (s_state->type == GGML_TYPE_F16) { + ggml_backend_tensor_set(s_state, bounce_f16.data(), state_offset, state_row_bytes); + } else if (s_state->type == GGML_TYPE_F32) { + ggml_fp16_to_fp32_row(bounce_f16.data(), bounce_f32.data(), n_embd_s); + ggml_backend_tensor_set(s_state, bounce_f32.data(), state_offset, state_row_bytes); + } else { + GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported s_state type"); + } + } else { + GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported persist type"); + } + } + } + + // Phase 5 fix: also roll the conv state (r_l[il]) back to accepted_dfs_node. + // Without this, the conv window stays at the DFS-last node and pollutes the + // root forward of the next spec step. + if (!skip_conv_rollback && !dflash_persist_conv_l.empty()) { + const int64_t n_embd_r = (int64_t)hparams.n_embd_r(); + std::vector bounce_conv((size_t)n_embd_r); + for (int il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { continue; } + ggml_tensor * persist_conv = (il < (int32_t)dflash_persist_conv_l.size()) + ? dflash_persist_conv_l[il] : nullptr; + ggml_tensor * r_state = (il < (int32_t)mem_recr->r_l.size()) + ? mem_recr->r_l[il] : nullptr; + if (!persist_conv || !r_state) { continue; } + if (accepted_dfs_node >= dflash_persist_max_n_tokens) { continue; } + + GGML_ASSERT(persist_conv->type == GGML_TYPE_F32); + // persist_conv layout: [K-1, conv_channels, n_tokens] F32 contiguous; + // each token col is exactly n_embd_r elements (K-1 * conv_channels). + const size_t conv_col_bytes = (size_t)n_embd_r * sizeof(float); + const size_t conv_off_src = (size_t)accepted_dfs_node * conv_col_bytes; + ggml_backend_tensor_get(persist_conv, bounce_conv.data(), conv_off_src, conv_col_bytes); + + // Live r_state row layout: ggml_row_size(type, n_embd_r) per cell. + const size_t r_row_bytes = ggml_row_size(r_state->type, n_embd_r); + const size_t r_off_dst = (size_t)cell_id * r_row_bytes; + if (r_state->type == GGML_TYPE_F32) { + ggml_backend_tensor_set(r_state, bounce_conv.data(), r_off_dst, r_row_bytes); + } else { + GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported r_state type"); + } + } + } + + return true; +} + +void llama_kv_cache_seq_compact_tree( + struct llama_context * ctx, + llama_seq_id seq_id, + const int32_t * accepted_dfs, + int32_t n_accepted, + int32_t commit_n, + int32_t spine_start) { + auto * raw_mem = ctx->get_memory(); + llama_kv_cache * kv = dynamic_cast(raw_mem); + if (!kv) { + if (auto * hyb = dynamic_cast(raw_mem)) { + kv = hyb->get_mem_attn(); + } + } + if (!kv) { + // non-KV memory (pure SSM) — no cache compaction needed + return; + } + std::vector dfs_vec(accepted_dfs, accepted_dfs + n_accepted); + kv->seq_compact_tree(seq_id, dfs_vec, commit_n, spine_start); +} + // llama state API // deprecated diff --git a/src/llama-context.h b/src/llama-context.h index e0d0085c1c3f..a308d106b16a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -91,6 +91,12 @@ struct llama_context { const llama_token * get_sampled_candidates_ith(int32_t idx); size_t get_sampled_candidates_count(int32_t idx); + bool get_dflash_draft_top_k( + const float ** top_logits, + const llama_token ** top_token_ids, + int32_t * n_rows, + int32_t * k); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -105,6 +111,92 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); + // dflash hidden capture API + void set_capture_hidden(bool enable); + ggml_tensor * get_hidden_capture() const; + void set_dflash_draft_top_k(int32_t k); + + // dflash Phase 2.4: persist-based SSM rollback. + // Copies the SSM state stored in dflash_persist_inter_l[il] at DFS column + // accepted_dfs_node back into the live s_l[il] tensor at the seq's tail cell. + // Returns false if the context has no recurrent memory or buffers are unallocated. + bool dflash_rollback_ssm_to_dfs(llama_seq_id seq_id, int32_t accepted_dfs_node); + + // Host-side accessor: returns pointer into hidden_capture_host (always CPU). + // Returns nullptr if capture is disabled or no decode has run. + const float * get_hidden_capture_data(int64_t * out_ne0, int64_t * out_ne1) const; + + // dflash draft target_feat injection API (Task 1). + // Stashes host pointer + dims so the next llama_decode() on this draft context + // can copy the data into the dflash_target_feat_raw GGML input tensor. + // committed_pos is the number of tokens committed in the target context so far. + void set_target_feat_raw(const float * data, int64_t n_embd_fc, int64_t ctx_len, + int64_t committed_pos); + void set_target_feat_fused(const float * data, int64_t n_embd, int64_t ctx_len, + int64_t committed_pos); + int dflash_draft_fuse_target_feat(const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused); + int dflash_draft_update_fused_cache(const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap); + int dflash_draft_update_fused_cache_from_capture(llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap); + int dflash_draft_encode_top_k_cached(const llama_batch & batch_inp, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k); + bool dflash_draft_ensure_fused_cache_tensor(int64_t n_embd, + int64_t cap, + ggml_backend_buffer_type_t buft); + bool dflash_draft_ensure_packed_target_feat_tensor(int64_t n_embd, + int64_t ctx_len, + ggml_backend_buffer_type_t buft); + bool dflash_draft_ensure_kv_cache_tensors(int64_t n_embd_head, + int64_t n_head_kv, + int64_t cap, + ggml_backend_buffer_type_t buft); + bool dflash_draft_ensure_packed_kv_tensors(int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + ggml_backend_buffer_type_t buft); + bool dflash_draft_ensure_top_output_tensors(int64_t top_k, + int64_t rows, + ggml_backend_buffer_type_t buft); + int dflash_draft_update_kv_cache(const float * target_feat_fused, + int64_t n_embd, + int64_t n_new, + int64_t first_pos, + int64_t cap); + int dflash_draft_pack_kv_cache(int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + int64_t ring_start, + int64_t cap); + int dflash_draft_encode_top_k(const llama_batch & batch_inp, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + int dflash_draft_encode_top_k_fused(const llama_batch & batch_inp, + const float * target_feat_fused, + int64_t n_embd, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + int dflash_draft_encode_top_k_pending(const llama_batch & batch_inp, + int32_t top_k); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); @@ -227,6 +319,7 @@ struct llama_context { // returns the result of ggml_backend_sched_graph_compute_async execution ggml_status graph_compute(ggml_cgraph * gf, bool batched); + ggml_status graph_compute(ggml_backend_sched_t sched_use, ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve( @@ -239,7 +332,8 @@ struct llama_context { llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + ggml_backend_sched_t sched_use = nullptr) const; llm_graph_cb graph_get_cb() const; @@ -291,6 +385,12 @@ struct llama_context { sampling_info sampling; + int32_t dflash_draft_top_k_req = 0; + std::vector dflash_draft_top_logits; + std::vector dflash_draft_top_token_ids; + int32_t dflash_draft_top_rows = 0; + int32_t dflash_draft_top_k = 0; + // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; @@ -334,12 +434,107 @@ struct llama_context { llm_graph_result_ptr gf_res_prev; llm_graph_result_ptr gf_res_reserve; + llm_graph_result_ptr dflash_res_fuse; + llm_graph_result_ptr dflash_res_kv; + llm_graph_result_ptr dflash_res_draft; + ggml_backend_sched_ptr dflash_sched_fuse; + ggml_backend_sched_ptr dflash_sched_kv; + ggml_backend_sched_ptr dflash_sched_draft; // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; bool has_evaluated_once = false; + // dflash hidden capture: when true, qwen35 forward writes captured hidden states + // into a graph output tensor; accessible via get_hidden_capture() after decode. + bool capture_hidden = false; + + // dflash Phase 2.4: per-layer SSM intermediate-state persist buffers. + // Allocated on first tree-mode decode; one tensor per delta-net layer (nullptr for + // full-attn layers). Shape: [S_v, S_v, H_v, n_tokens] F16, contiguous. + // After tree verify, llama_dflash_rollback_ssm_to_dfs() copies column[accepted_dfs_node] + // back into the live SSM state, replacing the snapshot/restore/replay path. + std::vector dflash_persist_inter_l; // [n_layer], nullptr for non-recurrent + ggml_context_ptr dflash_persist_inter_ctx; // ggml context owning the tensors + ggml_backend_buffer_ptr dflash_persist_inter_buf; // backend buffer owning the data + std::vector> + dflash_persist_ctxs_bufs; // per backend buffer type for mixed CPU/GPU offload + int64_t dflash_persist_max_n_tokens = 0; // current capacity + int64_t dflash_persist_failed_n_tokens = 0; // suppress repeated OOM retries + + // dflash Phase 5 fix: per-token conv post-state persist buffer used by + // ggml_ssm_conv_tree_persist. One tensor per delta-net layer; shape + // [K_conv-1, conv_channels, n_tokens] F32. Read by dflash_rollback_ssm_to_dfs + // to roll the live conv state (r_l[il]) back to the accepted DFS node. + std::vector dflash_persist_conv_l; // [n_layer], nullptr for non-recurrent + + // Ensure the persist buffers can hold n_tokens columns; reallocates if needed. + void ensure_dflash_persist_capacity(int64_t n_tokens); + + // Returns the per-layer SSM/conv persist tensors for layer il, or nullptr if not + // a recurrent layer or the buffers have not yet been allocated. + ggml_tensor * dflash_get_persist_inter(int32_t il) const; + ggml_tensor * dflash_get_persist_conv (int32_t il) const; + + // host-side mirror of t_hidden_capture, populated via ggml_backend_tensor_get_async + // after each decode. get_hidden_capture_data() returns into this buffer so callers + // don't dereference device pointers. + mutable std::vector hidden_capture_host; + mutable int64_t hidden_capture_ne0 = 0; + mutable int64_t hidden_capture_ne1 = 0; + mutable bool hidden_capture_host_valid = false; + + // dflash draft target_feat injection: stashed by llama_set_target_feat_raw() before + // llama_decode() on the draft context. The dflash-draft graph input reads from these + // fields in set_input() and copies them into the host-pinned GGML input tensors. + // Non-owning pointer — caller (speculative-tree-driver) owns the lifetime. + // Mutable so graph_params() (a const method) can take their address for the param struct. + mutable const float * pending_target_feat_raw = nullptr; + mutable int64_t pending_target_feat_n_embd_fc = 0; + mutable int64_t pending_target_feat_ctx_len = 0; + mutable int64_t pending_draft_committed_pos = 0; + mutable bool pending_target_feat_fused = false; + mutable bool pending_dflash_fuse_only = false; + mutable bool pending_dflash_kv_update_only = false; + mutable int64_t pending_dflash_kv_update_dst_pos = 0; + mutable ggml_tensor * pending_target_feat_tensor = nullptr; + + ggml_context_ptr dflash_fused_cache_ctx; + ggml_backend_buffer_ptr dflash_fused_cache_buf; + ggml_tensor * dflash_fused_cache = nullptr; + int64_t dflash_fused_cache_n_embd = 0; + int64_t dflash_fused_cache_cap = 0; + + ggml_context_ptr dflash_packed_target_feat_ctx; + ggml_backend_buffer_ptr dflash_packed_target_feat_buf; + ggml_tensor * dflash_packed_target_feat = nullptr; + int64_t dflash_packed_target_feat_n_embd = 0; + int64_t dflash_packed_target_feat_ctx_len = 0; + + ggml_context_ptr dflash_kv_cache_ctx; + ggml_backend_buffer_ptr dflash_kv_cache_buf; + std::vector dflash_k_cache_l; + std::vector dflash_v_cache_l; + int64_t dflash_kv_cache_head_dim = 0; + int64_t dflash_kv_cache_n_head_kv = 0; + int64_t dflash_kv_cache_cap = 0; + + ggml_context_ptr dflash_kv_packed_ctx; + ggml_backend_buffer_ptr dflash_kv_packed_buf; + std::vector dflash_k_packed_l; + std::vector dflash_v_packed_l; + int64_t dflash_kv_packed_head_dim = 0; + int64_t dflash_kv_packed_n_head_kv = 0; + int64_t dflash_kv_packed_ctx_len = 0; + + ggml_context_ptr dflash_top_output_ctx; + ggml_backend_buffer_ptr dflash_top_output_buf; + ggml_tensor * dflash_top_logits_fixed = nullptr; + ggml_tensor * dflash_top_ids_fixed = nullptr; + int64_t dflash_top_output_k = 0; + int64_t dflash_top_output_rows = 0; + // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8e2b6ab8e7e1..5a1ae0a3225e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -783,6 +783,17 @@ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { return true; } +void llm_graph_input_tree::set_input(const llama_ubatch * ubatch) { + GGML_ASSERT(ubatch->parent_id != nullptr); + GGML_ASSERT(inp_parent_ids != nullptr); + + const int32_t n_tokens = (int32_t) ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_parent_ids->buffer)); + int32_t * data = (int32_t *) inp_parent_ids->data; + memcpy(data, ubatch->parent_id, n_tokens * sizeof(int32_t)); +} + // // llm_graph_result // @@ -801,9 +812,12 @@ int64_t llm_graph_result::get_max_nodes() const { void llm_graph_result::reset() { t_inp_tokens = nullptr; t_inp_embd = nullptr; - t_logits = nullptr; - t_embd = nullptr; - t_embd_pooled = nullptr; + t_logits = nullptr; + t_embd = nullptr; + t_embd_pooled = nullptr; + t_hidden_capture = nullptr; + t_dflash_top_logits = nullptr; + t_dflash_top_ids = nullptr; t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -842,6 +856,15 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_hidden_capture != nullptr) { + ggml_set_output(t_hidden_capture); + } + if (t_dflash_top_logits != nullptr) { + ggml_set_output(t_dflash_top_logits); + } + if (t_dflash_top_ids != nullptr) { + ggml_set_output(t_dflash_top_ids); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -896,6 +919,81 @@ bool llm_graph_result::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_target_feat::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + // host_data, n_embd_fc, ctx_len are stashed by llama_set_target_feat_raw() before decode. + const float * data = *host_data_ptr; + const int64_t fc = *n_embd_fc_ptr; + const int64_t ctx_len = *ctx_len_ptr; + + if (inp_target_feat_raw && inp_target_feat_raw->buffer != nullptr && data != nullptr) { + GGML_ASSERT(inp_target_feat_raw->ne[0] == fc); + GGML_ASSERT(inp_target_feat_raw->ne[1] == ctx_len); + ggml_backend_tensor_set(inp_target_feat_raw, data, 0, (size_t)fc * ctx_len * sizeof(float)); + } + + // pos_q is local to the draft attention window, not the target's global + // sequence position. The draft attends over target_feat[0..ctx_len) plus + // the block's noise tokens, matching standalone DFlash's draft_ctx+i. + if (inp_pos_q && inp_pos_q->buffer != nullptr) { + const int64_t block_size = inp_pos_q->ne[0]; + std::vector pos_q(block_size); + for (int64_t i = 0; i < block_size; ++i) { + pos_q[i] = (int32_t)(ctx_len + i); + } + ggml_backend_tensor_set(inp_pos_q, pos_q.data(), 0, block_size * sizeof(int32_t)); + } + + // pos_k: [0 .. ctx_len + block_size) + if (inp_pos_k && inp_pos_k->buffer != nullptr) { + const int64_t total_k = inp_pos_k->ne[0]; + std::vector pos_k(total_k); + for (int64_t i = 0; i < total_k; ++i) { + pos_k[i] = (int32_t)i; + } + ggml_backend_tensor_set(inp_pos_k, pos_k.data(), 0, total_k * sizeof(int32_t)); + } +} + +bool llm_graph_input_target_feat::can_reuse(const llm_graph_params & params) { + if (params.pending_target_feat_raw_ptr == nullptr || + params.pending_target_feat_n_embd_fc_ptr == nullptr || + params.pending_target_feat_ctx_len_ptr == nullptr) { + return false; + } + + const int64_t fc = *params.pending_target_feat_n_embd_fc_ptr; + const int64_t ctx_len = *params.pending_target_feat_ctx_len_ptr; + const int64_t n_tokens = params.ubatch.n_tokens; + + if (fc <= 0 || ctx_len <= 0 || n_tokens <= 0) { + return false; + } + + bool res = true; + res &= inp_target_feat_raw != nullptr; + res &= inp_pos_q != nullptr; + res &= inp_pos_k != nullptr; + + if (inp_target_feat_raw) { + res &= inp_target_feat_raw->ne[0] == fc; + res &= inp_target_feat_raw->ne[1] == ctx_len; + } + if (inp_pos_q) { + res &= inp_pos_q->ne[0] == n_tokens; + } + if (inp_pos_k) { + res &= inp_pos_k->ne[0] == ctx_len + n_tokens; + } + + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: can reuse dflash target_feat graph input = %d\n", __func__, res); + } + + return res; +} + llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) { inputs.emplace_back(std::move(input)); return inputs.back().get(); @@ -948,6 +1046,21 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : samplers (params.samplers), cb_func (params.cb), res (params.res), + capture_hidden (params.capture_hidden), + dflash_persist_inter_l(params.dflash_persist_inter_l), + dflash_target_feat_fused(params.dflash_target_feat_fused), + dflash_kv_update_only(params.dflash_kv_update_only), + dflash_fuse_only(params.dflash_fuse_only), + dflash_draft_top_k(params.dflash_draft_top_k), + dflash_persist_conv_l (params.dflash_persist_conv_l), + pending_target_feat_raw_ptr (params.pending_target_feat_raw_ptr), + pending_target_feat_n_embd_fc_ptr(params.pending_target_feat_n_embd_fc_ptr), + pending_target_feat_ctx_len_ptr (params.pending_target_feat_ctx_len_ptr), + pending_draft_committed_pos_ptr (params.pending_draft_committed_pos_ptr), + pending_target_feat_tensor_ptr (params.pending_target_feat_tensor_ptr), + dflash_kv_cache_k_l (params.dflash_kv_cache_k_l), + dflash_kv_cache_v_l (params.dflash_kv_cache_v_l), + dflash_kv_cache_dst_pos (params.dflash_kv_cache_dst_pos), ctx0 (res->get_ctx()), gf (res->get_gf()) { res->set_params(params); @@ -1719,6 +1832,35 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { return cur; } +llm_graph_input_target_feat * llm_graph_context::build_inp_target_feat(int64_t n_embd_fc, int64_t ctx_len) const { + // The graph context holds non-owning pointers into the llama_context's pending_target_feat + // fields, propagated via llm_graph_params. The llama_context outlives every graph invocation, + // so the pointer lifetime is safe within a single decode call. + GGML_ASSERT(pending_target_feat_raw_ptr != nullptr && + "build_inp_target_feat called without pending pointers wired in llm_graph_params"); + auto inp = std::make_unique( + pending_target_feat_raw_ptr, + pending_target_feat_n_embd_fc_ptr, + pending_target_feat_ctx_len_ptr, + pending_draft_committed_pos_ptr); + + const int64_t block_size = n_tokens; // == dflash_block_size at draft invocation time + + inp->inp_target_feat_raw = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_fc, ctx_len); + ggml_set_name(inp->inp_target_feat_raw, "dflash_target_feat_raw"); + ggml_set_input(inp->inp_target_feat_raw); + + inp->inp_pos_q = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, block_size); + ggml_set_name(inp->inp_pos_q, "dflash_pos_q"); + ggml_set_input(inp->inp_pos_q); + + inp->inp_pos_k = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len + block_size); + ggml_set_name(inp->inp_pos_k, "dflash_pos_k"); + ggml_set_input(inp->inp_pos_k); + + return (llm_graph_input_target_feat *) res->add_input(std::move(inp)); +} + ggml_tensor * llm_graph_context::build_inp_attn_scale() const { auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); @@ -1805,6 +1947,23 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { return cur; } +void llm_graph_context::build_inp_tree() const { + auto inp = std::make_unique(); + + inp->inp_parent_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->inp_parent_ids); + ggml_set_name(inp->inp_parent_ids, "parent_ids"); + + // The ancestor-only attention mask is written directly into the standard + // kq_mask buffer in llama_kv_cache::set_input_kq_mask when ubatch->parent_id + // is set, so we don't allocate a separate tree_mask graph input here. + // TODO: phase-1 leaves pos as 1D — M-RoPE 4-axis is UNKNOWN-3 in roadmap + + const_cast(this)->parent_ids = inp->inp_parent_ids; + + res->add_input(std::move(inp)); +} + ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { auto inp = std::make_unique(hparams); diff --git a/src/llama-graph.h b/src/llama-graph.h index 29e78451fbb4..c0a87e6e4ef6 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -511,6 +511,52 @@ class llm_graph_input_sampling : public llm_graph_input_i { std::map samplers; }; +// Input for tree-mode forward. Holds parent_ids; ancestor mask is written +// into the standard kq_mask by llama_kv_cache::set_input_kq_mask when +// ubatch->parent_id is non-null, so no separate mask tensor is allocated here. +class llm_graph_input_tree : public llm_graph_input_i { +public: + llm_graph_input_tree() = default; + virtual ~llm_graph_input_tree() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * inp_parent_ids = nullptr; // I32 [n_tokens] +}; + +// Graph input class for the dflash draft model's target_feat_raw, pos_q, and pos_k tensors. +// The host side stashes the data via llama_set_target_feat_raw() before calling llama_decode() +// on the draft context. set_input() memcpy's the stashed data into the GGML input tensors. +class llm_graph_input_target_feat : public llm_graph_input_i { +public: + // host_data, n_embd_fc, ctx_len are non-owning; they point into llama_context's pending fields. + // committed_pos is the number of tokens already committed before this draft step. + llm_graph_input_target_feat( + const float ** host_data_ptr, // pointer to context's pending_target_feat_raw field + const int64_t * n_embd_fc_ptr, // pointer to context's pending_target_feat_n_embd_fc + const int64_t * ctx_len_ptr, // pointer to context's pending_target_feat_ctx_len + const int64_t * committed_pos_ptr) // pointer to context's pending_draft_committed_pos + : host_data_ptr(host_data_ptr), n_embd_fc_ptr(n_embd_fc_ptr), + ctx_len_ptr(ctx_len_ptr), committed_pos_ptr(committed_pos_ptr) {} + virtual ~llm_graph_input_target_feat() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + // [5*n_embd, ctx_len] F32 — stacked hidden captures from target layers + ggml_tensor * inp_target_feat_raw = nullptr; + // [block_size] I32 — Q positions: [committed_pos .. committed_pos + block_size) + ggml_tensor * inp_pos_q = nullptr; + // [ctx_len + block_size] I32 — K positions: [0 .. ctx_len + block_size) + ggml_tensor * inp_pos_k = nullptr; + +private: + const float ** host_data_ptr; + const int64_t * n_embd_fc_ptr; + const int64_t * ctx_len_ptr; + const int64_t * committed_pos_ptr; +}; + // // llm_graph_result // @@ -567,6 +613,39 @@ struct llm_graph_params { llm_graph_result * res; + // If true, qwen35 forward writes hidden states at dflash_target_capture_layers + // into t_hidden_capture on the result. No-op (zero overhead) when false. + bool capture_hidden = false; + + // dflash draft target_feat injection (Task 1). + // Non-owning pointers into llama_context's pending_target_feat fields. + // Non-null only when running the dflash-draft graph; graph inputs use them in set_input(). + const float ** pending_target_feat_raw_ptr = nullptr; + const int64_t * pending_target_feat_n_embd_fc_ptr = nullptr; + const int64_t * pending_target_feat_ctx_len_ptr = nullptr; + const int64_t * pending_draft_committed_pos_ptr = nullptr; + ggml_tensor * const * pending_target_feat_tensor_ptr = nullptr; + const std::vector * dflash_kv_cache_k_l = nullptr; + const std::vector * dflash_kv_cache_v_l = nullptr; + int64_t dflash_kv_cache_dst_pos = 0; + + bool dflash_target_feat_fused = false; + bool dflash_kv_update_only = false; + bool dflash_fuse_only = false; + int32_t dflash_draft_top_k = 0; + + // dflash Phase 2.4: per-layer SSM intermediate-state persist buffers. + // Non-owning pointer into llama_context::dflash_persist_inter_l. Null when not in + // tree mode or when the buffers have not yet been allocated. Graph builder reads + // (*dflash_persist_inter_l)[il] for each recurrent layer and passes it to + // build_delta_net_tree() as the persist_inter argument. + const std::vector * dflash_persist_inter_l = nullptr; + + // dflash Phase 5: per-layer conv post-state persist buffers (paired with + // dflash_persist_inter_l). Read by ggml_ssm_conv_tree_persist; rolled back + // into r_l[il] after spec verify. + const std::vector * dflash_persist_conv_l = nullptr; + // return true if the "other" params would result in a graph with the same topology as with the current params // having the same topology allows us to reuse the graph in some cases bool allow_reuse(const llm_graph_params & other) const { @@ -625,11 +704,18 @@ struct llm_graph_params { return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross; + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross && + capture_hidden == other.capture_hidden && + dflash_target_feat_fused == other.dflash_target_feat_fused && + dflash_kv_update_only == other.dflash_kv_update_only && + dflash_fuse_only == other.dflash_fuse_only && + dflash_draft_top_k == other.dflash_draft_top_k && + (dflash_persist_inter_l != nullptr) == (other.dflash_persist_inter_l != nullptr) && + (dflash_persist_conv_l != nullptr) == (other.dflash_persist_conv_l != nullptr); } }; @@ -640,9 +726,12 @@ class llm_graph_result { virtual ~llm_graph_result() = default; ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } - ggml_tensor * get_logits() const { return t_logits; } - ggml_tensor * get_embd() const { return t_embd; } - ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_logits() const { return t_logits; } + ggml_tensor * get_embd() const { return t_embd; } + ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_hidden_capture() const { return t_hidden_capture; } + ggml_tensor * get_dflash_top_logits() const { return t_dflash_top_logits; } + ggml_tensor * get_dflash_top_ids() const { return t_dflash_top_ids; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -666,11 +755,17 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_inp_tokens = nullptr; - ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] - ggml_tensor * t_logits = nullptr; - ggml_tensor * t_embd = nullptr; - ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] + ggml_tensor * t_logits = nullptr; + ggml_tensor * t_embd = nullptr; + ggml_tensor * t_embd_pooled = nullptr; + // dflash hidden capture: [5*n_embd, n_tokens] F32, populated when capture_hidden=true in graph_params + ggml_tensor * t_hidden_capture = nullptr; + + // dflash-draft top-K graph outputs: [K, n_tokens] + ggml_tensor * t_dflash_top_logits = nullptr; + ggml_tensor * t_dflash_top_ids = nullptr; std::map t_sampled_logits; std::map t_candidates; @@ -758,9 +853,41 @@ struct llm_graph_context { llm_graph_result * res; + // dflash hidden capture: propagated from llm_graph_params::capture_hidden + const bool capture_hidden; + + // dflash Phase 2.4: per-layer SSM intermediate-state persist buffer pointers. + // Non-owning pointer into llama_context::dflash_persist_inter_l (via graph_params). + // Null when not in tree mode. Indexed by layer index il. + const std::vector * dflash_persist_inter_l; + + bool dflash_target_feat_fused; + bool dflash_kv_update_only; + bool dflash_fuse_only; + int32_t dflash_draft_top_k; + + // dflash Phase 5: per-layer conv post-state persist buffer pointers + // (paired with dflash_persist_inter_l). + const std::vector * dflash_persist_conv_l; + + // dflash draft target_feat injection: propagated from llm_graph_params. + // Non-owning; valid only for the dflash-draft graph builder. + const float ** pending_target_feat_raw_ptr; + const int64_t * pending_target_feat_n_embd_fc_ptr; + const int64_t * pending_target_feat_ctx_len_ptr; + const int64_t * pending_draft_committed_pos_ptr; + ggml_tensor * const * pending_target_feat_tensor_ptr; + const std::vector * dflash_kv_cache_k_l; + const std::vector * dflash_kv_cache_v_l; + int64_t dflash_kv_cache_dst_pos; + ggml_context * ctx0 = nullptr; ggml_cgraph * gf = nullptr; + // tree-mode field: non-null when batch.parent_id is set; consumed by + // ggml_ssm_conv_tree / ggml_gated_delta_net_tree on hybrid layers. + ggml_tensor * parent_ids = nullptr; // [n_tokens] i32 + llm_graph_context(const llm_graph_params & params); virtual ~llm_graph_context() = default; @@ -866,9 +993,17 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; + // build tree-mode input tensors (parent_ids + tree_mask); sets this->parent_ids and this->tree_mask + void build_inp_tree() const; + ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; ggml_tensor * build_inp_pos_bucket_dec() const; + + // Build input tensors for the dflash draft model (target_feat_raw, pos_q, pos_k). + // Returns the target_feat_raw tensor (already registered as graph input). + // pos_q and pos_k are accessible via the returned llm_graph_input_target_feat*. + llm_graph_input_target_feat * build_inp_target_feat(int64_t n_embd_fc, int64_t ctx_len) const; ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; // diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c2000c77c37a..965e5549c1b5 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -212,6 +212,13 @@ struct llama_hparams { // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; + // dflash-draft specific (27B-specific for now: 5-layer draft targeting Qwen3.5-27B) + // target_capture_layers: indices into target model layers to capture hidden states from + std::array dflash_target_capture_layers = {1, 16, 31, 46, 61}; + uint32_t dflash_target_n_embd = 5120; // target hidden dim (n_embd of target model) + uint32_t dflash_mask_token_id = 248070; // mask token id for noise embedding lookup + uint32_t dflash_block_size = 16; // number of noise tokens per spec step + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 09102f549c8e..ea7c2a7646c3 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -290,12 +290,14 @@ llama_kv_cache::llama_kv_cache( !attn_rot_disable && n_embd_head_k_all > 0 && ggml_is_quantized(type_k) && + type_k != GGML_TYPE_TQ3_0 && hparams.n_embd_head_k() % 64 == 0; attn_rot_v = !attn_rot_disable && n_embd_head_v_all > 0 && ggml_is_quantized(type_v) && + type_v != GGML_TYPE_TQ3_0 && hparams.n_embd_head_v() % 64 == 0; LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); @@ -606,6 +608,109 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } +void llama_kv_cache::seq_compact_tree( + llama_seq_id seq_id, + const std::vector & accepted_dfs, + int32_t commit_n, + int32_t spine_start) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + GGML_ASSERT(commit_n >= 0 && commit_n <= (int32_t) accepted_dfs.size()); + GGML_ASSERT(spine_start >= 0); + + if (commit_n == 0) { + return; + } + + const uint32_t strm = seq_to_stream[seq_id]; + auto & cells = v_cells[strm]; + + // accepted_dfs[i] is the tree-local DFS index of the i-th accepted node. + // The tree was placed at slots [spine_start, spine_start + N), so absolute + // src/dst slot indices need spine_start added. + // + // Copy K/V rows via ggml_backend_tensor_get/set with explicit offsets + // (ggml_backend_tensor_copy doesn't follow view_src->buffer). + std::vector bounce; + + for (int32_t i = 0; i < commit_n; ++i) { + const int32_t src_slot = spine_start + accepted_dfs[i]; + const int32_t dst_slot = spine_start + i; + + if (src_slot == dst_slot) { + continue; + } + + GGML_ASSERT(src_slot >= 0 && (uint32_t) src_slot < cells.size()); + GGML_ASSERT(dst_slot >= 0 && (uint32_t) dst_slot < cells.size()); + + for (auto & layer : layers) { + if (layer.k) { + const size_t k_row_bytes = ggml_row_size(layer.k->type, layer.k->ne[0]); + const size_t k_row_stride = layer.k->nb[1]; + const size_t k_stride_stream = layer.k->nb[2]; + + if (bounce.size() < k_row_bytes) bounce.resize(k_row_bytes); + + const size_t src_off = strm * k_stride_stream + (size_t) src_slot * k_row_stride; + const size_t dst_off = strm * k_stride_stream + (size_t) dst_slot * k_row_stride; + + ggml_backend_tensor_get(layer.k, bounce.data(), src_off, k_row_bytes); + ggml_backend_tensor_set(layer.k, bounce.data(), dst_off, k_row_bytes); + } + + if (layer.v && !v_trans) { + const size_t v_row_bytes = ggml_row_size(layer.v->type, layer.v->ne[0]); + const size_t v_row_stride = layer.v->nb[1]; + const size_t v_stride_stream = layer.v->nb[2]; + + if (bounce.size() < v_row_bytes) bounce.resize(v_row_bytes); + + const size_t src_off = strm * v_stride_stream + (size_t) src_slot * v_row_stride; + const size_t dst_off = strm * v_stride_stream + (size_t) dst_slot * v_row_stride; + + ggml_backend_tensor_get(layer.v, bounce.data(), src_off, v_row_bytes); + ggml_backend_tensor_set(layer.v, bounce.data(), dst_off, v_row_bytes); + } + } + } + + // Update cell metadata: only touch the tree region [spine_start, spine_start+N). + // Past prompt cells (slots < spine_start) are left untouched. + // + // Snapshot positions from the accepted source slots first to avoid aliasing + // when src_slot < dst_slot. + std::vector accepted_pos(commit_n); + for (int32_t i = 0; i < commit_n; ++i) { + const uint32_t src = (uint32_t) (spine_start + accepted_dfs[i]); + accepted_pos[i] = cells.is_empty(src) ? -1 : cells.pos_get(src); + } + + // Clear all tree slots used by seq_id (i.e. cells with slot >= spine_start + // and pos >= the tree start position). To be conservative, scan the entire + // tree region width: assume the tree had at most max(accepted_dfs)+1 nodes, + // but we don't know N here — use the max of accepted_dfs as a lower bound + // and rely on the caller passing the correct spine_start. Clear all cells + // that belong to this seq with slot >= spine_start. + const uint32_t kv_size = cells.size(); + for (uint32_t slot = (uint32_t) spine_start; slot < kv_size; ++slot) { + if (!cells.is_empty(slot) && cells.seq_has(slot, seq_id)) { + cells.rm(slot); + } + } + + // Set the spine slots [spine_start, spine_start+commit_n) with accepted positions + for (int32_t i = 0; i < commit_n; ++i) { + if (accepted_pos[i] >= 0) { + const uint32_t slot = (uint32_t) (spine_start + i); + cells.pos_set(slot, accepted_pos[i]); + cells.seq_add(slot, seq_id); + } + } + + // Search head moves to just past the spine. + v_heads[strm] = (uint32_t) (spine_start + commit_n); +} + std::map llama_kv_cache::memory_breakdown() const { std::map ret; for (const auto & [ctx, buf] : ctxs_bufs) { @@ -1605,7 +1710,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } } -void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); @@ -1619,6 +1724,60 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u // n_tps == n_tokens_per_stream const int64_t n_tps = n_tokens/n_stream; + // Tree-mode mask: each query node attends to all past (committed) KV cells + // unconditionally, plus its exact tree ancestors in the current ubatch. + // Do not match tree nodes by position: siblings share the same depth/pos. + if (ubatch->parent_id != nullptr) { + GGML_ASSERT(n_stream == 1 && "tree-mode requires n_stream == 1 in Phase 4"); + GGML_ASSERT(sinfo.n_stream() == 1 && sinfo.size() == n_tokens); + + // Find the boundary between past KV and the current tree ubatch. + llama_pos tree_min_pos = std::numeric_limits::max(); + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + tree_min_pos = std::min(tree_min_pos, ubatch->pos[i]); + } + + std::fill(data, data + n_kv * n_tps, -INFINITY); + + const llama_seq_id seq0 = ubatch->seq_id[0][0]; + const auto & cells = v_cells.at(seq_to_stream[seq0]); + + for (int64_t i = 0; i < (int64_t) n_tokens; ++i) { + // Collect exact KV cell indices for query i's tree ancestors, + // including the node itself. + uint32_t ancestor_slot[64]; + int n_anc = 0; + int32_t cur = (int32_t) i; + while (cur >= 0 && n_anc < 64) { + ancestor_slot[n_anc++] = sinfo.idxs[0][cur]; + const int32_t p = ubatch->parent_id[cur]; + if (p < 0) break; + cur = p; + } + + for (int64_t j = 0; j < n_kv; ++j) { + if (cells.is_empty(j) || !cells.seq_has(j, seq0)) { + continue; + } + const llama_pos p0 = cells.pos_get(j); + bool visible = false; + if (p0 < tree_min_pos) { + // Past KV (prompt or earlier accepted tokens): always visible. + visible = true; + } else { + // Tree region: visible only for the exact ancestor cells. + for (int k = 0; k < n_anc; ++k) { + if ((uint32_t) j == ancestor_slot[k]) { visible = true; break; } + } + } + if (visible) { + data[i * n_kv + j] = 0.0f; + } + } + } + return; + } + //const int64_t t_start = ggml_time_us(); const args_set_input_kq_mask args = { @@ -2488,7 +2647,7 @@ void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_uba } void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - kv->set_input_kq_mask(dst, ubatch, causal_attn); + kv->set_input_kq_mask(dst, ubatch, causal_attn, sinfos[i_cur]); } void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0b62dc7b2320..ad7311cb2829 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -179,6 +179,16 @@ class llama_kv_cache : public llama_memory_i { bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info); + // After a tree-verify forward fills KV slots [spine_start, spine_start+N), + // compact the accepted spine: K/V rows from slot (spine_start + accepted_dfs[i]) + // are moved to slot (spine_start + i) for i in [0, commit_n). Cells outside + // [spine_start, spine_start+kv_size) are untouched (preserves prompt prefill). + // Only operates on the stream assigned to seq_id. + void seq_compact_tree(llama_seq_id seq_id, + const std::vector & accepted_dfs, + int32_t commit_n, + int32_t spine_start); + // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous // return empty slot_info on failure @@ -202,7 +212,7 @@ class llama_kv_cache : public llama_memory_i { void set_input_k_shift(ggml_tensor * dst) const; - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, const slot_info & sinfo) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_k_rot(ggml_tensor * dst) const; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 9287fe45e963..77c6bca5e5d7 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -12,6 +12,7 @@ #include #include #include +#include // // llama_memory_recurrent @@ -604,7 +605,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { const int32_t cell_id = s + min; auto & cell = cells[cell_id]; - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens && ubatch.parent_id == nullptr) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", @@ -666,6 +667,147 @@ bool llama_memory_recurrent::get_can_shift() const { return true; } +llama_mem_snapshot_id llama_memory_recurrent::snapshot(llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= size) { + return LLAMA_MEM_SNAPSHOT_INVALID; + } + + const int32_t n_layer = (int32_t) r_l.size(); + + snapshot_entry entry; + entry.seq_id = seq_id; + entry.r_backup.resize(n_layer, nullptr); + entry.s_backup.resize(n_layer, nullptr); + + // group backup tensors by buffer type, matching the main cache allocation pattern + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u * n_layer * ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { return nullptr; } + ctx_map.emplace(buft, ctx); + return ctx; + } + return it->second.get(); + }; + + for (int il = 0; il < n_layer; ++il) { + if (r_l[il] == nullptr) { continue; } + + ggml_backend_buffer_t main_buf = r_l[il]->buffer; + ggml_backend_buffer_type_t buft = main_buf + ? ggml_backend_buffer_get_type(main_buf) + : ggml_backend_cpu_buffer_type(); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { return LLAMA_MEM_SNAPSHOT_INVALID; } + + // one cell worth of r and s + ggml_tensor * rb = ggml_new_tensor_1d(ctx, r_l[il]->type, hparams.n_embd_r()); + ggml_tensor * sb = ggml_new_tensor_1d(ctx, s_l[il]->type, hparams.n_embd_s()); + entry.r_backup[il] = rb; + entry.s_backup[il] = sb; + } + + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); + if (!buf) { return LLAMA_MEM_SNAPSHOT_INVALID; } + ggml_backend_buffer_clear(buf, 0); + entry.ctxs_bufs.emplace_back(std::move(ctx), buf); + } + + // copy current cell state into the backup tensors via a host bounce buffer. + // ggml_backend_tensor_copy doesn't follow view_src->buffer, but tensor_get/set do. + const int32_t cell_id = cells[seq_id].tail; + entry.cell_id = cell_id; + if (cell_id >= 0) { + entry.cell_pos = cells[cell_id].pos; + entry.cell_src = cells[cell_id].src; + } + if (cell_id >= 0) { + std::vector bounce; + for (int il = 0; il < n_layer; ++il) { + if (r_l[il] == nullptr) { continue; } + + const size_t r_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); + const size_t s_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); + + if (bounce.size() < std::max(r_row, s_row)) { + bounce.resize(std::max(r_row, s_row)); + } + + ggml_backend_tensor_get(r_l[il], bounce.data(), (size_t) cell_id * r_row, r_row); + ggml_backend_tensor_set(entry.r_backup[il], bounce.data(), 0, r_row); + + ggml_backend_tensor_get(s_l[il], bounce.data(), (size_t) cell_id * s_row, s_row); + ggml_backend_tensor_set(entry.s_backup[il], bounce.data(), 0, s_row); + } + } + + llama_mem_snapshot_id snap_id = next_snap_id++; + snapshots.emplace(snap_id, std::move(entry)); + return snap_id; +} + +bool llama_memory_recurrent::restore(llama_mem_snapshot_id snap_id) { + auto it = snapshots.find(snap_id); + if (it == snapshots.end()) { return false; } + + const snapshot_entry & entry = it->second; + const llama_seq_id seq_id = entry.seq_id; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { return false; } + + // restore the seq's tail to the cell that was snapshotted + const int32_t cell_id = entry.cell_id; + if (cell_id < 0) { + return true; // no live cell at snapshot time — nothing to restore + } + + cells[seq_id].tail = cell_id; + cells[cell_id].pos = entry.cell_pos; + cells[cell_id].src = entry.cell_src; + + const int32_t n_layer = (int32_t) r_l.size(); + + std::vector bounce; + for (int il = 0; il < n_layer; ++il) { + if (r_l[il] == nullptr) { continue; } + if (entry.r_backup[il] == nullptr) { continue; } + + const size_t r_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); + const size_t s_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); + + if (bounce.size() < std::max(r_row, s_row)) { + bounce.resize(std::max(r_row, s_row)); + } + + ggml_backend_tensor_get(entry.r_backup[il], bounce.data(), 0, r_row); + ggml_backend_tensor_set(r_l[il], bounce.data(), (size_t) cell_id * r_row, r_row); + + ggml_backend_tensor_get(entry.s_backup[il], bounce.data(), 0, s_row); + ggml_backend_tensor_set(s_l[il], bounce.data(), (size_t) cell_id * s_row, s_row); + } + + return true; +} + +void llama_memory_recurrent::release(llama_mem_snapshot_id snap_id) { + snapshots.erase(snap_id); +} + size_t llama_memory_recurrent::total_size() const { size_t size = 0; for (const auto & [_, buf] : ctxs_bufs) { diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d739124..81ed749aed75 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -6,6 +6,7 @@ #include #include +#include #include // @@ -60,6 +61,12 @@ class llama_memory_recurrent : public llama_memory_i { bool get_can_shift() const override; + // snapshot/restore API for recurrent state (SSM + conv) + // snap_id < 0 indicates failure + llama_mem_snapshot_id snapshot(llama_seq_id seq_id); + bool restore (llama_mem_snapshot_id snap_id); + void release (llama_mem_snapshot_id snap_id); + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; @@ -117,6 +124,25 @@ class llama_memory_recurrent : public llama_memory_i { size_t size_r_bytes() const; size_t size_s_bytes() const; + // per-snapshot: the seq_id that was snapshotted, plus backup tensors for each layer + // and the cell metadata at snapshot time so restore can roll back the position counter. + struct snapshot_entry { + llama_seq_id seq_id; + // backup tensors: one per layer, same type/shape as r_l[il] / s_l[il] for that seq cell + // r_backup[il] and s_backup[il] are null for filtered (null) layers + std::vector r_backup; // [n_layer] + std::vector s_backup; // [n_layer] + // ggml contexts and backend buffers that own the backup tensors + std::vector> ctxs_bufs; + // cell bookkeeping captured at snapshot time + int32_t cell_id = -1; + llama_pos cell_pos = -1; + int32_t cell_src = -1; + }; + + llama_mem_snapshot_id next_snap_id = 0; + std::unordered_map snapshots; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d8..6a0abaccf750 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -394,6 +394,9 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + // Explicit instantiation for dflash-draft: target_capture_layers is a fixed 5-element array. + template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required); + template bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { auto it = kv_overrides.find(key); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b265394ef736..234a8a681a84 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -16,6 +16,7 @@ #include "models/models.h" #include "ggml.h" +#include "ggml-backend.h" #include "ggml-cpp.h" #include @@ -2933,6 +2934,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DFLASH_DRAFT: + { + // Standard transformer hparams are read from GGUF normally (n_layer=5, + // n_embd=5120, n_head=32, n_head_kv=8, n_embd_head=128, n_ff=2048). + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // dflash-draft specific metadata keys + ml.get_arr("dflash_draft.target_capture_layers", + hparams.dflash_target_capture_layers, false); + ml.get_key("dflash_draft.target_n_embd", hparams.dflash_target_n_embd, false); + ml.get_key("dflash_draft.mask_token_id", hparams.dflash_mask_token_id, false); + ml.get_key("dflash_draft.block_size", hparams.dflash_block_size, false); + + type = LLM_TYPE_UNKNOWN; // draft is a small auxiliary model, no standard type + } break; default: throw std::runtime_error("unsupported model architecture: " + arch_name()); } @@ -7971,6 +7987,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_DFLASH_DRAFT: + { + // dflash-draft: 5-layer non-causal speculative decoder. + // token_embd is NOT loaded — token embeddings are looked up from the + // target model at runtime via llama_model_token_embd_lookup. + // lm_head can be shared from the target model. The draft GGUF may + // still contain output.weight for standalone tests, but in server + // mode loading another copy costs about 1 GiB on Qwen3.5-27B. + // out_norm maps to model.output_norm; fc and hidden_norm are stored in dflash_fc / dflash_hidden_norm. + const int64_t n_draft_fc_in = (int64_t)5 * n_embd; // 5 * hidden = 25600 for 27B + + // Top-level: feature-fusion projection and norms + dflash_fc = create_tensor(tn(LLM_TENSOR_DFLASH_FC, "weight"), {n_draft_fc_in, n_embd}, 0); + dflash_hidden_norm = create_tensor(tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DFLASH_OUT_NORM, "weight"), {n_embd}, 0); + + const llama_model * target_model = params.target_model; + if (target_model != nullptr && target_model->output != nullptr && + target_model->output->ne[0] == n_embd && + target_model->output->ne[1] == n_vocab) { + (void) create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_SKIP); + output = target_model->output; + LLAMA_LOG_INFO("%s: dflash-draft: sharing target output.weight; skipped draft lm_head allocation\n", __func__); + } else { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // Per-layer pre-norms + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // Attention projections; Q projects to n_head * n_embd_head_k + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Per-head Q/K norms (qwen3-style) + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // SwiGLU FFN; intermediate = n_ff (2048 for 27B draft) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -8629,6 +8695,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: + case LLM_ARCH_DFLASH_DRAFT: // non-causal draft: no autoregressive KV cache needed { res = nullptr; } break; @@ -9259,6 +9326,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DFLASH_DRAFT: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -9289,6 +9360,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, + /*.target_model =*/ nullptr, /*.n_gpu_layers =*/ -1, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, @@ -9511,6 +9583,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_DFLASH_DRAFT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -9658,6 +9731,57 @@ bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } +int llama_model_token_embd_lookup( + const llama_model * model, + llama_token token, + float * out, + int64_t out_n) { + const ggml_tensor * t = model->tok_embd; + if (!t) { + return -1; + } + + const int64_t n_embd = t->ne[0]; // rows in ggml layout = embedding dim + const int64_t n_vocab = t->ne[1]; + + if (token < 0 || (int64_t)token >= n_vocab) { + return -1; + } + if (out_n < n_embd) { + return -1; + } + + const ggml_type dtype = t->type; + + if (dtype == GGML_TYPE_F32) { + const size_t row_bytes = (size_t)n_embd * sizeof(float); + ggml_backend_tensor_get(t, out, (size_t)token * row_bytes, row_bytes); + return 0; + } + + if (dtype == GGML_TYPE_F16) { + const size_t row_bytes = (size_t)n_embd * sizeof(ggml_fp16_t); + std::vector tmp(n_embd); + ggml_backend_tensor_get(t, tmp.data(), (size_t)token * row_bytes, row_bytes); + ggml_fp16_to_fp32_row(tmp.data(), out, n_embd); + return 0; + } + + // Quantized rows: fetch the raw bytes for one row and dequantize via the + // type's to_float trait. Required for Q4_K_M / Q5_K / etc. target models + // where tok_embd is stored quantized. + const auto * traits = ggml_get_type_traits(dtype); + if (traits != nullptr && traits->to_float != nullptr) { + const size_t row_bytes = ggml_row_size(dtype, n_embd); + std::vector tmp(row_bytes); + ggml_backend_tensor_get(t, tmp.data(), (size_t)token * row_bytes, row_bytes); + traits->to_float(tmp.data(), out, n_embd); + return 0; + } + + return -1; +} + const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } diff --git a/src/llama-model.h b/src/llama-model.h index bba70012e118..5273c977e8b5 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -546,6 +546,10 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // dflash-draft top-level tensors + struct ggml_tensor * dflash_fc = nullptr; // "fc" [5*n_embd, n_embd] + struct ggml_tensor * dflash_hidden_norm = nullptr; // "hidden_norm" [n_embd] + // gemma3n altup struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c95099..cbc823bbff5d 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -428,6 +428,17 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * b, ggml_tensor * s, int il) { + // dispatch to tree variant when parent_ids are available + if (parent_ids != nullptr) { + // Phase 2.4: fetch per-layer persist buffer from graph context if allocated + ggml_tensor * persist_inter = nullptr; + if (dflash_persist_inter_l != nullptr && il >= 0 && + il < (int32_t)dflash_persist_inter_l->size()) { + persist_inter = (*dflash_persist_inter_l)[il]; + } + return build_delta_net_tree(q, k, v, g, b, s, parent_ids, persist_inter, il); + } + const int64_t n_seq_tokens = q->ne[2]; if (n_seq_tokens == 1) { @@ -443,3 +454,43 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +std::pair llm_build_delta_net_base::build_delta_net_tree( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + ggml_tensor * par_ids, + ggml_tensor * persist_inter, + int il) { + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // ggml_gated_delta_net_tree has the same packed output layout as ggml_gated_delta_net + ggml_tensor * result; + if (persist_inter != nullptr) { + result = ggml_gated_delta_net_tree_persist(ctx0, q, k, v, g, b, s, par_ids, persist_inter); + } else { + result = ggml_gated_delta_net_tree(ctx0, q, k, v, g, b, s, par_ids); + } + cb(result, "fgdn_tree", il); + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; +} diff --git a/src/models/dflash-draft.cpp b/src/models/dflash-draft.cpp new file mode 100644 index 000000000000..df4718e661ea --- /dev/null +++ b/src/models/dflash-draft.cpp @@ -0,0 +1,309 @@ +// dflash-draft.cpp — Graph builder for the z-lab/Qwen3.5-27B-DFlash speculative draft model. +// +// Architecture: 5-layer non-causal transformer. +// Inputs (host-provided, set via ggml_set_input): +// - noise_embed : [n_embd, block_size] — pre-looked-up rows from target tok_embd +// - target_feat_raw: [5*n_embd, ctx_len] — stacked hidden captures from target layers +// Forward pass (mirrors qwen3_dflash_graph.cpp:53-164): +// 1. fc(target_feat_raw) → rms_norm(hidden_norm) → target_feat [n_embd, ctx_len] +// 2. For each of 5 layers: +// - Q: from noise_embed only (attn_norm, wq, q_norm, RoPE-NEOX) +// - K/V: from concat(target_feat, noise), then wk/wv, k_norm, RoPE-NEOX +// - Non-causal FlashAttention (mask=nullptr), GQA 32:8 +// - SwiGLU FFN +// 3. out_norm + shared lm_head → logits [vocab, block_size] + +#include "models.h" +#include "llama-impl.h" // LLAMA_TENSOR_NAME_FATTN +#include "llama-graph.h" // llm_graph_input_target_feat, build_inp_target_feat + +#include + +llm_build_dflash_draft::llm_build_dflash_draft( + const llama_model & model, + const llm_graph_params & params) : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_k(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + + // draft constants derived from hparams + const int64_t n_embd_fc = (int64_t)5 * n_embd; // 5*hidden for fc input + + // rope_theta = 10M for draft (matches DFLASH27B_ROPE_THETA) + const float draft_rope_theta = 10000000.0f; + const float scale = 1.0f / sqrtf((float)n_embd_head); + + // ── Draft-specific target-feature input ─────────────────────────────────── + // target_feat_raw / pos_q / pos_k: registered as graph inputs via build_inp_target_feat. + // The host stashes data with llama_set_target_feat_raw() before llama_decode(); the + // graph input class copies it into these GGML tensors at set_input() time. + // + // ctx_len is read directly from the pending context length stashed by the driver before + // llama_decode(). At reservation time (graph preheating), the pointer may be null/zero, + // in which case we fall back to n_ctx as the worst-case upper bound. + const int64_t ctx_len = (pending_target_feat_ctx_len_ptr && *pending_target_feat_ctx_len_ptr > 0) + ? *pending_target_feat_ctx_len_ptr + : n_ctx; + ggml_tensor * cached_target_feat = (pending_target_feat_tensor_ptr != nullptr) + ? *pending_target_feat_tensor_ptr + : nullptr; + const bool use_cached_target_feat = cached_target_feat != nullptr; + const int64_t target_feat_width = dflash_target_feat_fused ? n_embd : n_embd_fc; + llm_graph_input_target_feat * inp_tf = build_inp_target_feat(target_feat_width, ctx_len); + + ggml_tensor * target_feat_in = use_cached_target_feat ? cached_target_feat : inp_tf->inp_target_feat_raw; + ggml_tensor * pos_q = inp_tf->inp_pos_q; + ggml_tensor * pos_k = inp_tf->inp_pos_k; + + // ── Step 1: feature fusion ──────────────────────────────────────────────── + // target_feat = rms_norm(fc @ target_feat_raw, hidden_norm) + // The dedicated draft runtime can pass a cached fused target_feat directly. + ggml_tensor * target_feat = target_feat_in; + if (!dflash_target_feat_fused) { + if (use_cached_target_feat && target_feat_in->ne[0] == n_embd && target_feat_in->ne[1] == 5*ctx_len) { + ggml_tensor * packed = nullptr; + for (int l = 0; l < 5; ++l) { + ggml_tensor * layer = ggml_view_2d(ctx0, target_feat_in, n_embd, ctx_len, + (size_t)n_embd * ggml_element_size(target_feat_in), + (size_t)l * ctx_len * n_embd * ggml_element_size(target_feat_in)); + packed = packed == nullptr ? layer : ggml_concat(ctx0, packed, layer, 0); + } + target_feat_in = packed; + } + // fc: [n_embd_fc, n_embd] (ggml: ne[0]=n_embd_fc, ne[1]=n_embd) + // target_feat_raw: [n_embd_fc, ctx_len] + // Result: [n_embd, ctx_len] + target_feat = ggml_mul_mat(ctx0, model.dflash_fc, target_feat_in); + cb(target_feat, "dflash_fc_out", -1); + + target_feat = ggml_rms_norm(ctx0, target_feat, hparams.f_norm_rms_eps); + target_feat = ggml_mul(ctx0, target_feat, model.dflash_hidden_norm); + } + GGML_ASSERT(target_feat->ne[0] == n_embd); + GGML_ASSERT(target_feat->ne[1] == ctx_len); + cb(target_feat, "dflash_target_feat", -1); + + if (dflash_fuse_only && !dflash_kv_update_only) { + res->t_embd = target_feat; + ggml_build_forward_expand(gf, target_feat); + return; + } + + if (dflash_kv_update_only) { + GGML_ASSERT(dflash_kv_cache_k_l != nullptr); + GGML_ASSERT(dflash_kv_cache_v_l != nullptr); + GGML_ASSERT((int64_t) dflash_kv_cache_k_l->size() >= n_layer); + GGML_ASSERT((int64_t) dflash_kv_cache_v_l->size() >= n_layer); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + ggml_tensor * K = ggml_mul_mat(ctx0, layer.wk, target_feat); + K = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, ctx_len); + K = ggml_rms_norm(ctx0, K, hparams.f_norm_rms_eps); + K = ggml_mul(ctx0, K, layer.attn_k_norm); + cb(K, "dflash_k_cache_update", il); + + ggml_tensor * V = ggml_mul_mat(ctx0, layer.wv, target_feat); + V = ggml_reshape_3d(ctx0, V, n_embd_head, n_head_kv, ctx_len); + cb(V, "dflash_v_cache_update", il); + + ggml_tensor * Kdst = (*dflash_kv_cache_k_l)[il]; + ggml_tensor * Vdst = (*dflash_kv_cache_v_l)[il]; + GGML_ASSERT(Kdst != nullptr && Vdst != nullptr); + const int64_t n_el = n_embd_head * n_head_kv * ctx_len; + const size_t dst_off = (size_t) dflash_kv_cache_dst_pos * n_embd_head * n_head_kv * sizeof(float); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, K, n_el, 0), + ggml_view_1d(ctx0, Kdst, n_el, dst_off))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, V, n_el, 0), + ggml_view_1d(ctx0, Vdst, n_el, dst_off))); + } + res->t_embd = target_feat; + return; + } + + // noise_embed: pre-computed embedding rows [n_embd, block_size] — host fills this + // through ubatch.embd. It must be registered as a graph input; merely calling + // ggml_set_input() is not enough for llama_decode() to populate it. + auto inp_noise = std::make_unique(n_embd); + inp_noise->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_name(inp_noise->embd, "dflash_noise_embed"); + ggml_set_input(inp_noise->embd); + ggml_tensor * noise_embed = inp_noise->embd; + res->add_input(std::move(inp_noise)); + + // ── Step 2: position tensors ────────────────────────────────────────────── + // Q positions: [ctx_len .. ctx_len + block_size) in draft-window-local + // coordinates, matching standalone DFlash. + // K positions: [0 .. ctx_len + block_size) + // Both tensors were created and registered by build_inp_target_feat() above. + // set_input() fills them from pending_draft_committed_pos before each decode. + const int64_t total_k = ctx_len + n_tokens; + + // ── Step 3: 5-layer decoder ─────────────────────────────────────────────── + ggml_tensor * h = noise_embed; // [n_embd, block_size] + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // -- Attention pre-norm on noise hidden state + ggml_tensor * hn = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + hn = ggml_mul(ctx0, hn, layer.attn_norm); // layer.attn_norm: [n_embd] + cb(hn, "attn_norm", il); + + // -- Q from noise only: wq [n_embd, n_head*n_embd_head], reshaped, q_norm, RoPE + ggml_tensor * Q = ggml_mul_mat(ctx0, layer.wq, hn); // [n_head*n_embd_head, block_size] + Q = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens); // [n_embd_head, n_head, block_size] + Q = ggml_rms_norm(ctx0, Q, hparams.f_norm_rms_eps); // per-head rms_norm along n_embd_head + Q = ggml_mul(ctx0, Q, layer.attn_q_norm); // broadcast [n_embd_head] + cb(Q, "Q_normed", il); + + // Q RoPE-NEOX + Q = ggml_rope_ext(ctx0, Q, pos_q, nullptr, + (int)n_embd_head, + (int)LLAMA_ROPE_TYPE_NEOX, + /*n_ctx_orig=*/0, + draft_rope_theta, + /*freq_scale=*/1.0f, + /*ext_factor=*/0.0f, + /*attn_factor=*/1.0f, + /*beta_fast=*/0.0f, + /*beta_slow=*/0.0f); + cb(Q, "Q_rope", il); + + // -- K and V from concat(target_feat, noise) + ggml_tensor * K = nullptr; + ggml_tensor * V = nullptr; + const bool use_cached_kv = dflash_kv_cache_k_l != nullptr && dflash_kv_cache_v_l != nullptr && + (int64_t) dflash_kv_cache_k_l->size() > il && + (int64_t) dflash_kv_cache_v_l->size() > il && + (*dflash_kv_cache_k_l)[il] != nullptr && + (*dflash_kv_cache_v_l)[il] != nullptr; + if (use_cached_kv) { + ggml_tensor * Kctx = (*dflash_kv_cache_k_l)[il]; + ggml_tensor * Vctx = (*dflash_kv_cache_v_l)[il]; + GGML_ASSERT(Kctx->ne[0] == n_embd_head && Kctx->ne[1] == n_head_kv && Kctx->ne[2] == ctx_len); + GGML_ASSERT(Vctx->ne[0] == n_embd_head && Vctx->ne[1] == n_head_kv && Vctx->ne[2] == ctx_len); + + ggml_tensor * Kn = ggml_mul_mat(ctx0, layer.wk, hn); + Kn = ggml_reshape_3d(ctx0, Kn, n_embd_head, n_head_kv, n_tokens); + Kn = ggml_rms_norm(ctx0, Kn, hparams.f_norm_rms_eps); + Kn = ggml_mul(ctx0, Kn, layer.attn_k_norm); + + ggml_tensor * Vn = ggml_mul_mat(ctx0, layer.wv, hn); + Vn = ggml_reshape_3d(ctx0, Vn, n_embd_head, n_head_kv, n_tokens); + + K = ggml_concat(ctx0, Kctx, Kn, 2); + V = ggml_concat(ctx0, Vctx, Vn, 2); + cb(K, "K_normed", il); + } else { + // First compute K/V from target_feat (ctx_len tokens) + ggml_tensor * Kctx = ggml_mul_mat(ctx0, layer.wk, target_feat); // [n_head_kv*n_embd_head, ctx_len] + ggml_tensor * Vctx = ggml_mul_mat(ctx0, layer.wv, target_feat); + + // Then from noise (block_size tokens) + ggml_tensor * Kn = ggml_mul_mat(ctx0, layer.wk, hn); // [n_head_kv*n_embd_head, block_size] + ggml_tensor * Vn = ggml_mul_mat(ctx0, layer.wv, hn); + + // Concat along sequence dimension (ne[1]) + K = ggml_concat(ctx0, Kctx, Kn, 1); // [n_head_kv*n_embd_head, total_k] + V = ggml_concat(ctx0, Vctx, Vn, 1); + + // Per-head K norm + K = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, total_k); + K = ggml_rms_norm(ctx0, K, hparams.f_norm_rms_eps); + K = ggml_mul(ctx0, K, layer.attn_k_norm); + cb(K, "K_normed", il); + + V = ggml_reshape_3d(ctx0, V, n_embd_head, n_head_kv, total_k); + } + + // K RoPE-NEOX + K = ggml_rope_ext(ctx0, K, pos_k, nullptr, + (int)n_embd_head, + (int)LLAMA_ROPE_TYPE_NEOX, + 0, + draft_rope_theta, + 1.0f, 0.0f, 1.0f, 0.0f, 0.0f); + cb(K, "K_rope", il); + + // Permute into flash_attn_ext layout + // Q: [n_embd_head, n_head, block_size, 1] + // K: [n_embd_head, n_head_kv, total_k, 1] + // V: [n_embd_head, n_head_kv, total_k, 1] (not transposed) + Q = ggml_permute(ctx0, Q, 0, 2, 1, 3); + Q = ggml_cont(ctx0, Q); + K = ggml_permute(ctx0, K, 0, 2, 1, 3); + K = ggml_cont(ctx0, K); + V = ggml_permute(ctx0, V, 0, 2, 1, 3); + V = ggml_cont(ctx0, V); + + // Non-causal flash attention; mask=nullptr, GQA broadcast handled internally. + ggml_tensor * attn = ggml_flash_attn_ext(ctx0, Q, K, V, + /*mask=*/nullptr, + scale, + /*max_bias=*/0.0f, + /*logit_softcap=*/0.0f); + // Name the FA tensor so sched_reserve's auto_fa name-prefix assert passes. + cb(attn, LLAMA_TENSOR_NAME_FATTN, il); + // attn: [n_embd_head, n_head, block_size, 1] + attn = ggml_reshape_2d(ctx0, attn, n_embd_head * n_head, n_tokens); + cb(attn, "attn_out", il); + + // Output projection + residual + ggml_tensor * attn_proj = ggml_mul_mat(ctx0, layer.wo, attn); + h = ggml_add(ctx0, h, attn_proj); + cb(h, "attn_residual", il); + + // -- FFN pre-norm + ggml_tensor * hf = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + hf = ggml_mul(ctx0, hf, layer.ffn_norm); + cb(hf, "ffn_norm", il); + + // SwiGLU: down(silu(gate(x)) * up(x)) + ggml_tensor * g = ggml_mul_mat(ctx0, layer.ffn_gate, hf); + g = ggml_silu(ctx0, g); + ggml_tensor * u = ggml_mul_mat(ctx0, layer.ffn_up, hf); + ggml_tensor * gu = ggml_mul(ctx0, g, u); + ggml_tensor * ffn_out = ggml_mul_mat(ctx0, layer.ffn_down, gu); + cb(ffn_out, "ffn_out", il); + + h = ggml_add(ctx0, h, ffn_out); + cb(h, "l_out", il); + } + + // ── Step 4: final norm + lm_head ───────────────────────────────────────── + ggml_tensor * out = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + out = ggml_mul(ctx0, out, model.output_norm); // model.output_norm == dflash out_norm + cb(out, "result_norm", -1); + res->t_embd = out; + + // lm_head is shared from target model — it must be provided via model.output. + // If not yet wired (Phase 3 test mode), output is the hidden state only. + if (model.output != nullptr) { + ggml_tensor * logits = ggml_mul_mat(ctx0, model.output, out); + cb(logits, "result_output", -1); + + if (dflash_draft_top_k > 0) { + const int top_k = std::min(dflash_draft_top_k, logits->ne[0]); + + ggml_tensor * top_ids = ggml_top_k(ctx0, logits, top_k); + cb(top_ids, "dflash_top_ids", -1); + + ggml_tensor * logits_rows = ggml_reshape_3d(ctx0, logits, 1, logits->ne[0], n_tokens); + ggml_tensor * top_logits = ggml_get_rows(ctx0, logits_rows, top_ids); + top_logits = ggml_reshape_2d(ctx0, top_logits, top_k, n_tokens); + cb(top_logits, "dflash_top_logits", -1); + + res->t_dflash_top_ids = top_ids; + res->t_dflash_top_logits = top_logits; + ggml_build_forward_expand(gf, top_ids); + ggml_build_forward_expand(gf, top_logits); + } else { + res->t_logits = logits; + ggml_build_forward_expand(gf, logits); + } + } else { + ggml_build_forward_expand(gf, out); + } +} diff --git a/src/models/models.h b/src/models/models.h index a6682ebb287d..e2d4e2cc6303 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -56,6 +56,7 @@ struct llm_build_delta_net_base : public llm_graph_context { int il); // choose one of two implementations above based on the number of tokens + // if parent_ids != nullptr, dispatches to build_delta_net_tree std::pair build_delta_net( ggml_tensor * q, ggml_tensor * k, @@ -64,6 +65,20 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // tree-mode variant: uses ggml_ssm_conv_tree + ggml_gated_delta_net_tree + // When persist_inter != nullptr, calls ggml_gated_delta_net_tree_persist to write + // per-token intermediate states to the provided external buffer. + std::pair build_delta_net_tree( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + ggml_tensor * par_ids, // [n_tokens] i32 + ggml_tensor * persist_inter, // optional, may be null + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -182,6 +197,10 @@ struct llm_build_dbrx : public llm_graph_context { llm_build_dbrx(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_dflash_draft : public llm_graph_context { + llm_build_dflash_draft(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_deci : public llm_graph_context { llm_build_deci(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 28df353050b9..a52ace8c2327 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -14,10 +14,34 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_tensor * cur; ggml_tensor * inpL; + // dflash hidden capture: per-slot tensors collected during the forward pass. + // Slots are ggml_concat'd into a single [n_embd, 5*n_tokens] tensor AFTER + // the layer loop and registered as t_hidden_capture (OUTPUT). This ensures + // the concat node is a real compute graph output that gallocr / the sched + // execute and sync back to the host — avoiding the INPUT-leaf + cpy-to-view + // pattern which silently produces all-zeros on GPU (cpy dst is CPU-pinned + // while the src lives on the device backend). + ggml_tensor * cap_slots[5] = {nullptr, nullptr, nullptr, nullptr, nullptr}; + inpL = build_inp_embd(model.tok_embd); cb(inpL, "model.input_embed", -1); + // build tree-mode inputs when parent_ids are present in the ubatch. + // LLAMA_DDTREE_FORCE_CHAIN_KERNEL=1 skips the tree input wiring; downstream + // conv/delta-net dispatch then falls back to the chain kernel (parent_ids + // member stays null). Diagnostic only — sibling/cousin tokens are wrong, + // root token stays equivalent to chain. + if (ubatch.parent_id != nullptr) { + static const bool s_ddtree_force_chain_kernel = []{ + const char * e = getenv("LLAMA_DDTREE_FORCE_CHAIN_KERNEL"); + return e && e[0] == '1'; + }(); + if (!s_ddtree_force_chain_kernel) { + build_inp_tree(); + } + } + auto * inp = build_inp_mem_hybrid(); ggml_tensor * inp_pos = build_inp_pos(); @@ -64,6 +88,21 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_ffn", il); + // dflash hidden capture: stash cur in cap_slots[k] for later concat. + // Critical invariant: this block is NOT entered when capture_hidden==false, + // so the baseline qwen35 forward is byte-for-byte unchanged. + if (capture_hidden) { + const auto & cl = hparams.dflash_target_capture_layers; + for (int k = 0; k < 5; ++k) { + if ((int)cl[k] == il) { + // ggml_cont ensures the slot is a standalone contiguous node + // (cur may be a non-owning view after certain ops). + cap_slots[k] = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens)); + break; + } + } + } + cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -72,6 +111,24 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa } cur = inpL; + // dflash hidden capture: concat the 5 collected slots along dim 1 into + // [n_embd, 5*n_tokens] and register as t_hidden_capture (OUTPUT). + // The concat result is a regular compute node — gallocr schedules it on the + // device backend and the sched syncs it to host after the forward pass. + if (capture_hidden) { + for (int k = 0; k < 5; ++k) { + GGML_ASSERT(cap_slots[k] != nullptr && + "dflash_target_capture_layers must cover all 5 slots; check hparams"); + } + ggml_tensor * cap = ggml_concat(ctx0, cap_slots[0], cap_slots[1], 1); + cap = ggml_concat(ctx0, cap, cap_slots[2], 1); + cap = ggml_concat(ctx0, cap, cap_slots[3], 1); + cap = ggml_concat(ctx0, cap, cap_slots[4], 1); + ggml_set_name(cap, "dflash_hidden_capture"); + ggml_build_forward_expand(gf, cap); + res->t_hidden_capture = cap; + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -280,7 +337,22 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); - ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + // use tree conv when parent_ids are set; identical output shape to ggml_ssm_conv. + // Phase 5 fix: when a per-layer conv-persist buffer is allocated, use the + // _persist variant so each token writes its post-state for SSM rollback. + ggml_tensor * conv_persist = nullptr; + if (parent_ids != nullptr && dflash_persist_conv_l != nullptr && + il >= 0 && il < (int32_t)dflash_persist_conv_l->size()) { + conv_persist = (*dflash_persist_conv_l)[il]; + } + ggml_tensor * conv_output_proper; + if (parent_ids != nullptr) { + conv_output_proper = (conv_persist != nullptr) + ? ggml_ssm_conv_tree_persist(ctx0, conv_input, conv_kernel, parent_ids, conv_persist) + : ggml_ssm_conv_tree (ctx0, conv_input, conv_kernel, parent_ids); + } else { + conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + } cb(conv_output_proper, "conv_output_raw", il); ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); diff --git a/tests/.gitignore b/tests/.gitignore index 52b292b1f878..aea409dd1b9c 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,6 +1,8 @@ * !*.* !snapshots/ +!fixtures/ +!fixtures/** *.o ggml-common.h **/*.swp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cd4bc5ef1d31..0e3dcd61c728 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -210,6 +210,8 @@ llama_build_and_test( peg-parser/tests.h ) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-speculative-tree.cpp) +llama_build_and_test(test-speculative-draft-backend.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf") @@ -298,3 +300,59 @@ if (TARGET gguf-model-data) target_link_libraries(export-graph-ops PRIVATE gguf-model-data) target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH) endif() + +# DDTree Phase 1 acceptance test. +# Requires a 16+ GB Qwen3.5-27B GGUF not available in CI. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON ... +option(LLAMA_BUILD_TESTS_QWEN35_TREE "Build DDTree tree-mode acceptance test (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_TREE) + llama_build(test-qwen35-tree.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree diagnostic: chain forward vs single tree-mode root forward. +# Used to isolate whether the long-prompt ddtree bug is in the tree kernels +# or above (driver / KV slot / capture). Same large GGUF requirement. +option(LLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN "Build DDTree root-vs-chain diagnostic (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN) + llama_build(test-qwen35-root-vs-chain.cpp) + # Manual run only. +endif() + +# DDTree Phase 2 acceptance test: snapshot/restore symmetry. +# Requires a 16+ GB Qwen3.5-27B GGUF not available in CI. +# Requires the llama_seq_snapshot / llama_seq_restore / llama_seq_release API +# delivered by the Phase 2 implementation agent. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON ... +option(LLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK "Build DDTree snapshot/restore acceptance test (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK) + llama_build(test-qwen35-tree-rollback.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree Phase 3 acceptance tests: dflash-draft forward + hidden-state capture. +# Requires a 16+ GB Qwen3.5-27B GGUF and a converted dflash-draft GGUF. +# Requires Phase 3 API from the implementation agent: +# llama_model_token_embd_lookup, llama_set_capture_hidden, llama_get_hidden_capture. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON ... +option(LLAMA_BUILD_TESTS_DFLASH_DRAFT "Build DDTree Phase 3 dflash-draft and chain-capture acceptance tests (requires large GGUFs)" OFF) +if (LLAMA_BUILD_TESTS_DFLASH_DRAFT) + llama_build(test-dflash-draft.cpp) + llama_build(test-qwen35-chain-capture.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree Phase 4 end-to-end speculative decode acceptance test. +# Requires Qwen3.5-27B GGUF (~16 GB) and dflash-draft GGUF. Not available in CI. +# Requires Phase 4 API from the implementation agent: +# llama_speculative_tree_driver_init / _step / _free (common/speculative-tree-driver.h) +# llama_set_target_feat_raw (Phase 3 gap, llama.h) +# Enable manually: cmake -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON ... +option(LLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E + "Build DDTree Phase 4 end-to-end speculative decode acceptance test (requires large GGUFs)" + OFF) +if (LLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E) + llama_build(test-speculative-tree-e2e.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. + # Both --target-model and --draft-model are required at runtime. +endif() diff --git a/tests/fixtures/ddtree/README.md b/tests/fixtures/ddtree/README.md new file mode 100644 index 000000000000..de6a84d11a7d --- /dev/null +++ b/tests/fixtures/ddtree/README.md @@ -0,0 +1,332 @@ +# DDTree Phase 1 test fixtures + +## Files + +- `tree_5node.json` — 5-node tree fixture (1 root + 4 verify nodes). Used by `test-qwen35-tree --mode tree`. Token IDs are mid-range Qwen3.5 placeholders; regenerate from real sampled tokens as needed. + +- `short_prompt.bin` — 16 int32 LE token IDs for chain-mode warm-up (Test 1.A). Hardcoded Qwen3.5 BOS/system-prompt tokens. Regenerate with `make_short_prompt.py`. + +- `make_short_prompt.py` — standalone script; reads a text string and writes int32 LE tokens to `short_prompt.bin`. Requires `llama-cpp-python`. Falls back to the hardcoded 16-token fixture if the package is absent. + +## Build + +The test binary is gated behind a CMake option (off by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON .. +make test-qwen35-tree +``` + +Requires a Qwen3.5-27B GGUF (~16 GB). Not available in CI. + +## Running (castle only) + +Test 1.A — chain mode does not regress vs current fork master: + +```bash +./test-qwen35-tree --mode chain \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/chain.bin + +# Compare against a golden dump produced by an unmodified fork build: +python3 scripts/compare_logits.py /tmp/chain_golden.bin /tmp/chain.bin +``` + +Test 1.B — tree mode aligns with test_dflash (blocked, see below): + +```bash +./test-qwen35-tree --mode tree \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --tree-fixture fixtures/ddtree/tree_5node.json \ + --out-logits /tmp/tree.bin + +python3 scripts/compare_logits.py /tmp/test_dflash_golden.bin /tmp/tree.bin +``` + +## Blocker: Test 1.B prerequisite + +Test 1.B requires `test_dflash` to support `--dump-verify-logits`, which dumps the per-node logits produced by the dflash tree forward. This flag is a Phase 0 prerequisite listed in roadmap section 7.2 and is not yet implemented. Until it lands, Test 1.B cannot produce a golden reference and cannot be run end-to-end. + +--- + +# DDTree Phase 2 test fixtures + +## Build + +Gated behind a CMake option (off by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON .. +make test-qwen35-tree-rollback +``` + +Requires Phase 2 `llama_seq_snapshot` / `llama_seq_restore` / `llama_seq_release` API from the implementation agent. + +## Test 2.A — snapshot/restore symmetry + +```bash +./test-qwen35-tree-rollback \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --gen 8 \ + --out-logits-pre /tmp/pre.bin \ + --out-logits-post /tmp/post.bin + +# Both logit dumps must be bit-equal: +python3 scripts/compare_logits.py /tmp/pre.bin /tmp/post.bin --abs-tol 0 --rel-tol 0 +``` + +## Test 2.B — BLOCKED + +Requires `test_dflash --dump-state-at-commit`, a Phase 0 prerequisite not yet implemented. Until that flag lands, Test 2.B (tree partial-accept vs sequential golden state) cannot be run. + +## Test 2.C — deferred + +Long-prompt OOM stress test deferred to Phase 5 server integration. + +--- + +# DDTree Phase 3 test fixtures + +## Files + +- `dflash_draft_metadata_smoke.json` — expected GGUF metadata fields for a + converted dflash-draft model. Used by `check_dflash_draft_gguf.py`. + +## Build + +Both Phase 3 test binaries are gated behind a single CMake option (off by +default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON .. +make test-dflash-draft test-qwen35-chain-capture +``` + +Requires: +- A Qwen3.5-27B GGUF (~16 GB). +- A converted dflash-draft GGUF (see conversion step below). +- Phase 3 implementation API: `llama_model_token_embd_lookup`, + `llama_set_capture_hidden`, `llama_get_hidden_capture`. + +## Converting safetensors to dflash-draft GGUF + +The conversion script is written by the implementation agent in parallel. +Once it lands at `repo/dflash/scripts/convert_dflash_draft.py`: + +```bash +python repo/dflash/scripts/convert_dflash_draft.py \ + /path/to/dflash_draft/model.safetensors \ + -o /path/to/draft.gguf +``` + +Until the script lands this step is a TODO. + +## Validating the converted GGUF (Test 3.A) + +```bash +python repo/scripts/check_dflash_draft_gguf.py \ + /path/to/draft.gguf \ + tests/fixtures/ddtree/dflash_draft_metadata_smoke.json +# Exit 0: PASS. Exit 1: one line per discrepant field on stderr. +``` + +## Test 3.B — Draft forward bit-equal vs dflash reference + +BLOCKED on Phase 0 prerequisite: `test_dflash --dump-draft-output` flag is +not yet implemented. Until that flag lands, no golden reference exists and +the end-to-end comparison cannot be run. + +The driver (`test-dflash-draft`) can still be used standalone to inspect +draft logits: + +```bash +./test-dflash-draft \ + --target-model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --draft-model /path/to/draft.gguf \ + --last-tok 12345 \ + --target-feat-bin /path/to/target_feat.bin \ + --out-logits /tmp/draft_logits.bin + +# Once the Phase 0 flag lands, compare against the dflash reference dump: +python3 scripts/compare_logits.py /tmp/dflash_draft_golden.bin /tmp/draft_logits.bin +``` + +## Test 3.C — Hidden capture does not break chain mode + +Run both capture and no-capture modes in a single invocation. The driver +asserts logits are bit-equal and that the capture buffer contains valid +(non-NaN, non-zero) values. + +```bash +./test-qwen35-chain-capture \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/capture_logits.bin \ + --out-capture /tmp/capture_buf.bin +# Exit 0: both assertions passed. +``` + +Regression-only mode (skips Mode A, writes no-capture logits for external +comparison against a Phase 1 chain golden dump): + +```bash +./test-qwen35-chain-capture \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/nocapture_logits.bin \ + --out-capture /dev/null \ + --no-capture + +# Compare against Phase 1 chain baseline: +python3 scripts/compare_logits.py /tmp/chain_golden.bin /tmp/nocapture_logits.bin --abs-tol 0 --rel-tol 0 +``` + +--- + +# DDTree Phase 4 test fixtures + +## Build + +The Phase 4 end-to-end test binary is gated behind its own CMake option (off +by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON .. +make test-speculative-tree-e2e +``` + +Requires: +- A Qwen3.5-27B GGUF (~16 GB). +- A converted dflash-draft GGUF. +- Phase 4 implementation API: + `llama_speculative_tree_driver_init` / `_step` / `_free` + (`common/speculative-tree-driver.h`) and `llama_set_target_feat_raw` + (Phase 3 gap; `llama.h`). + +## Test 4.A — Spec-decode token trajectory matches chain reference + +This is the canonical Phase 4 acceptance test. With `--temp 0` (greedy), +DDTree speculative decoding is lossless: the target verifies each draft token +against its own argmax before accepting it. The resulting token sequence MUST +be bit-equal to a plain greedy chain decode from the same prompt. + +```bash +./test-speculative-tree-e2e \ + --target-model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --draft-model /path/to/draft.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --gen 64 \ + --out-chain /tmp/chain.tokens \ + --out-spec /tmp/spec.tokens \ + --ddtree-budget 22 \ + --temp 0 + +# The driver prints: chain_n=X spec_n=Y first_divergence=none bytes_match=Z/Z +# Exit 0 = PASS. + +# Optional: offline comparison using the script: +python3 scripts/compare_tokens.py /tmp/chain.tokens /tmp/spec.tokens +# Exit 0 = all positions match AND n_a == n_b. +``` + +**Note**: `--temp 0` is required for the bit-equal guarantee. Non-zero +temperature introduces stochastic sampling, which makes the two sequences +non-deterministic relative to each other. With non-zero temp the comparison +is informational only (the driver does not assert bit-equality). + +## Test 4.B — BLOCKED + +Test 4.B (comparison of the spec-decode token sequence against the output of +the `test_dflash` daemon) is blocked on a Phase 0 prerequisite: the +`test_dflash` daemon mode interface (`--daemon` flag) is not yet implemented. +Until that flag lands, no golden `test_dflash` token stream can be produced +for comparison. + +The chain-reference comparison in Test 4.A gives strong independent functional +verification and is sufficient for Phase 4 sign-off. + +--- + +# DDTree Phase 5 test fixtures + +Phase 5 integrates the DDTree driver into `llama-server` as a selectable +speculative-decode mode (`--speculative-mode ddtree`). There are no new +binary test fixtures; validation is done via two shell scripts in the +super-repo `scripts/` directory that run against the server on Castle. + +## New CLI flags (impl agent deliverables) + +| Flag | Type | Default | Notes | +|------|------|---------|-------| +| `--speculative-mode {chain,ddtree}` | string | chain | selects speculative backend | +| `--ddtree-budget N` | int | 22 | max draft tokens per tree step | +| `--ddtree-temp F` | float | 0.0 | draft sampling temperature | +| `--ddtree-no-chain-seed` | bool flag | off | disable chain-seed warmup | + +These flags are parsed in `common/arg.cpp`. The HTTP API surface is +unchanged: same OpenAI-compatible `/v1/chat/completions` and +`/v1/messages` endpoints, SSE streaming, `tool_use`, and +`reasoning_content` all work identically to chain mode. + +Only `--parallel 1` (single slot) is supported in Phase 5. + +## Test 5.A — Smoke test (primary acceptance) + +Run from the local mac: + +```bash +# Default (port 8003, single prompt) +./repo/scripts/run_server_ddtree_castle.sh + +# Custom port and prompt +./repo/scripts/run_server_ddtree_castle.sh 8003 "Write a haiku." +``` + +What the script does: + +1. Verifies the `llama-server` binary exists on Castle. +2. Kills any leftover DDTree-mode server (idempotent). +3. Starts the server via `nohup` on Castle, logging to `/tmp/ddtree_server.log`. +4. Polls `/health` up to 60 s (2 s interval). +5. Sends one non-streaming `POST /v1/chat/completions` and validates + `choices[0].message.content` is non-empty. +6. Sends one streaming request and confirms SSE `data:` lines arrive. +7. Prints the last 50 lines of the server log. +8. Stops the server. +9. Exits 0 (SMOKE PASS) or non-zero (SMOKE FAIL). + +## Test 5.B — Mode comparison (optional / informational) + +```bash +./repo/scripts/compare_server_modes_castle.sh +# or with a custom prompt: +./repo/scripts/compare_server_modes_castle.sh "Describe the sky in exactly 32 tokens." +``` + +Starts both a chain-mode server (port 8001) and a DDTree-mode server +(port 8003) on Castle, sends the same greedy (`temperature: 0`, +`max_tokens: 32`) prompt to each, and reports the first word-level +divergence index. + +**Expected outcome**: divergence at word index >= 17. This matches the +Phase 4 finding that chain and DDTree outputs are bit-equal up to +approximately 17 tokens per speculative-step boundary, then diverge due +to KV-cache / conversation-state differences in the server slot state +machine. Divergence at or above that threshold is not a regression. +Early divergence (word index < 17) warrants investigation. + +The script always exits 0; the comparison is informational. + +## Phase 5 acceptance criteria + +Phase 5 acceptance is **smoke level only**: + +- `run_server_ddtree_castle.sh` exits 0 (SMOKE PASS). +- Non-streaming completion returns valid JSON with non-empty content. +- SSE streaming delivers at least one `data:` chunk. + +Full production replacement of `dflash/scripts/server.py` (pointing +Claude Code at `http://castle.local:8002/v1`) is the user's **manual** +validation step and is outside automated testing scope. diff --git a/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json b/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json new file mode 100644 index 000000000000..dc8099a2000d --- /dev/null +++ b/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json @@ -0,0 +1,11 @@ +{ + "expected_arch": "dflash-draft", + "expected_n_layer": 5, + "expected_n_embd": 5120, + "expected_n_head": 32, + "expected_n_head_kv": 8, + "expected_target_n_embd": 5120, + "expected_mask_token_id": 248070, + "expected_block_size": 16, + "expected_capture_layers": [1, 16, 31, 46, 61] +} diff --git a/tests/fixtures/ddtree/make_short_prompt.py b/tests/fixtures/ddtree/make_short_prompt.py new file mode 100755 index 000000000000..e07eecf79b55 --- /dev/null +++ b/tests/fixtures/ddtree/make_short_prompt.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +make_short_prompt.py + +Tokenize a text string with a Qwen3.5 GGUF and write int32 LE token IDs to +short_prompt.bin next to this script. + +Usage: + python3 make_short_prompt.py --text "You are a helpful assistant." \ + --model-path /path/to/Qwen3.5-27B-Q4_K_M.gguf + +Requires the `gguf` Python package: + pip install gguf + +If the package is unavailable, the script falls back to the hardcoded 16-token +fixture already committed in short_prompt.bin and prints a warning. +""" + +import argparse +import pathlib +import struct +import sys + +SCRIPT_DIR = pathlib.Path(__file__).parent + +FALLBACK_TOKENS = [ + 151644, 8948, 198, 2610, 525, 264, 10950, 17847, + 13, 151645, 198, 151644, 872, 198, 2610, 7291, +] + + +def tokenize_via_gguf(text: str, model_path: str): + try: + from gguf import GGUFReader # type: ignore + except ImportError: + return None + + # GGUFReader gives access to metadata but not a tokenizer runtime. + # For actual tokenization we need llama-cpp-python or similar. + try: + from llama_cpp import Llama # type: ignore + except ImportError: + return None + + llm = Llama(model_path=model_path, vocab_only=True, verbose=False) + tokens = llm.tokenize(text.encode(), add_bos=True, special=True) + return tokens + + +def main(): + parser = argparse.ArgumentParser(description="Generate short_prompt.bin from text.") + parser.add_argument("--text", default="You are a helpful assistant.") + parser.add_argument("--model-path", default="") + parser.add_argument("--out", default=str(SCRIPT_DIR / "short_prompt.bin")) + args = parser.parse_args() + + tokens = None + if args.model_path: + tokens = tokenize_via_gguf(args.text, args.model_path) + if tokens is None: + print( + "WARNING: llama_cpp Python package not available. " + "Writing hardcoded fallback fixture.", + file=sys.stderr, + ) + + if tokens is None: + tokens = FALLBACK_TOKENS + + out_path = pathlib.Path(args.out) + data = struct.pack("<" + "i" * len(tokens), *tokens) + out_path.write_bytes(data) + print(f"wrote {len(tokens)} tokens to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/fixtures/ddtree/short_prompt.bin b/tests/fixtures/ddtree/short_prompt.bin new file mode 100644 index 000000000000..bdea24d7dac1 Binary files /dev/null and b/tests/fixtures/ddtree/short_prompt.bin differ diff --git a/tests/fixtures/ddtree/tree_5node.json b/tests/fixtures/ddtree/tree_5node.json new file mode 100644 index 000000000000..acdc159e4204 --- /dev/null +++ b/tests/fixtures/ddtree/tree_5node.json @@ -0,0 +1,11 @@ +{ + "committed_offset": 0, + "comment": "Phase 1 fixture. 5 nodes: 1 root + 4 verify nodes. Mix of chain + 1 sibling for branching coverage. Token IDs are placeholders in the Qwen3.5 mid-range vocab; real test runs may regenerate via make_short_prompt.py with actual sampled tokens.", + "nodes": [ + { "flat_idx": 0, "token_id": 13374, "parent_idx": -1, "depth": 0 }, + { "flat_idx": 1, "token_id": 7922, "parent_idx": 0, "depth": 1 }, + { "flat_idx": 2, "token_id": 42073, "parent_idx": 1, "depth": 2 }, + { "flat_idx": 3, "token_id": 18435, "parent_idx": 1, "depth": 2 }, + { "flat_idx": 4, "token_id": 30722, "parent_idx": 2, "depth": 3 } + ] +} diff --git a/tests/test-dflash-draft.cpp b/tests/test-dflash-draft.cpp new file mode 100644 index 000000000000..e90c6b1b9aca --- /dev/null +++ b/tests/test-dflash-draft.cpp @@ -0,0 +1,350 @@ +// test-dflash-draft.cpp +// +// Phase 3 acceptance test (Test 3.B) for the dflash-draft model forward pass. +// Loads a Qwen3.5-27B target model and a dflash-draft GGUF, looks up token +// embeddings for a 16-token batch ([last_tok, MASK_TOKEN_ID*15]), reads a +// target-hidden-state feature binary, runs a single draft forward, and dumps +// the 16-position logits. +// +// Build: requires -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON (not added to ctest). +// +// API assumptions (implementation agent deliverables): +// LLM_ARCH_DFLASH_DRAFT -- arch string "dflash-draft" +// llama_model_token_embd_lookup(model, tokens, n, out_buf, embd_dim) +// -- fills out_buf with n rows of embd_dim F32 values from model's token +// embedding table. out_buf must be caller-allocated (n * embd_dim floats). +// llama_set_capture_hidden(ctx, bool) -- opt-in to hidden-capture in target +// model; not needed here but shares the header. +// llama_get_hidden_capture(ctx) -- not used here; see test-qwen35-chain-capture. +// +// Draft forward with target_feat injection: +// The draft graph builder (llm_build_dflash_draft) reads a named graph +// input tensor "dflash_target_feat" of shape [ctx_len, n_embd*5] from the +// batch's embd pointer. The caller feeds it by creating a batch with +// embd != NULL, where embd points to: +// [ token_embd_row (embd_dim floats) * 16 (token rows) ] +// [ target_feat (5*embd_dim floats) * ctx_len (feature rows) ] +// The exact layout is defined by the implementation agent. If that layout +// differs from the above, the user will reconcile before compiling. +// +// Output binary format (--out-logits): +// int32_t n_tokens (= 16) +// int32_t vocab_size +// float logits[16 * vocab_size] (row-major, little-endian) + +#include "llama.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// constants +// --------------------------------------------------------------------------- + +static constexpr int32_t DRAFT_BATCH_SIZE = 16; +static constexpr int32_t DEFAULT_MASK_TOK_ID = 248070; + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --target-model PATH (Qwen3.5-27B GGUF; source of token_embd lookup; required)\n" + " --draft-model PATH (dflash-draft GGUF; required)\n" + " --last-tok N (int32 token id; required)\n" + " --target-feat-bin PATH (F32 binary [ctx_len * 5 * embd_dim]; required)\n" + " --ctx-len N (number of positions in target-feat-bin; 0 = derive from file)\n" + " --out-logits PATH (F32 binary output for 16 positions; required)\n" + " --mask-token-id N (override mask token id; default %d)\n" + " --n-gpu-layers N (default 99)\n", + prog, DEFAULT_MASK_TOK_ID); +} + +// Read a raw F32 binary file into a host buffer. +static std::vector read_f32_bin(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open binary file: " + path); + } + f.seekg(0, std::ios::end); + auto sz = f.tellg(); + f.seekg(0, std::ios::beg); + if (sz % sizeof(float) != 0) { + throw std::runtime_error("binary file size not a multiple of 4: " + path); + } + std::vector buf(sz / sizeof(float)); + f.read(reinterpret_cast(buf.data()), sz); + return buf; +} + +static void write_logits(const std::string & path, + const std::vector & data, + int32_t n_tokens, + int32_t vocab_size) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open out-logits for writing: " + path); + } + f.write(reinterpret_cast(&n_tokens), sizeof(int32_t)); + f.write(reinterpret_cast(&vocab_size), sizeof(int32_t)); + f.write(reinterpret_cast(data.data()), + (std::streamsize)(data.size() * sizeof(float))); +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string target_model_path; + std::string draft_model_path; + std::string target_feat_path; + std::string out_logits_path; + int32_t last_tok = -1; + int32_t ctx_len = 0; + int32_t mask_tok_id = DEFAULT_MASK_TOK_ID; + int32_t n_gpu_layers = 99; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--target-model" && i + 1 < argc) { + target_model_path = argv[++i]; + } else if (arg == "--draft-model" && i + 1 < argc) { + draft_model_path = argv[++i]; + } else if (arg == "--last-tok" && i + 1 < argc) { + last_tok = std::atoi(argv[++i]); + } else if (arg == "--target-feat-bin" && i + 1 < argc) { + target_feat_path = argv[++i]; + } else if (arg == "--ctx-len" && i + 1 < argc) { + ctx_len = std::atoi(argv[++i]); + } else if (arg == "--out-logits" && i + 1 < argc) { + out_logits_path = argv[++i]; + } else if (arg == "--mask-token-id" && i + 1 < argc) { + mask_tok_id = std::atoi(argv[++i]); + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (target_model_path.empty()) { fprintf(stderr, "--target-model is required\n"); return 1; } + if (draft_model_path.empty()) { fprintf(stderr, "--draft-model is required\n"); return 1; } + if (last_tok < 0) { fprintf(stderr, "--last-tok is required\n"); return 1; } + if (target_feat_path.empty()) { fprintf(stderr, "--target-feat-bin is required\n"); return 1; } + if (out_logits_path.empty()) { fprintf(stderr, "--out-logits is required\n"); return 1; } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + // ------------------------------------------------------------------ + // Step 1: load target model (for token_embd lookup only) + // ------------------------------------------------------------------ + llama_model * target_model = llama_model_load_from_file(target_model_path.c_str(), mparams); + if (!target_model) { + LOG_ERR("failed to load target model: %s\n", target_model_path.c_str()); + llama_backend_free(); + return 1; + } + + // ------------------------------------------------------------------ + // Step 2: load draft model + // ------------------------------------------------------------------ + llama_model * draft_model = llama_model_load_from_file(draft_model_path.c_str(), mparams); + if (!draft_model) { + LOG_ERR("failed to load draft model: %s\n", draft_model_path.c_str()); + llama_model_free(target_model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + // ------------------------------------------------------------------ + // Step 3: derive embd_dim from target model + // ------------------------------------------------------------------ + // llama_model_n_embd returns the embedding dimension of the model. + const int32_t embd_dim = llama_model_n_embd(target_model); + LOG_INF("embd_dim = %d\n", embd_dim); + + // ------------------------------------------------------------------ + // Step 4: look up token embeddings for the 16-token batch + // tokens: [last_tok, mask_tok_id, mask_tok_id, ..., mask_tok_id] + // (1 + 15 = 16 tokens) + // ------------------------------------------------------------------ + std::vector batch_tokens(DRAFT_BATCH_SIZE); + batch_tokens[0] = (llama_token)last_tok; + for (int i = 1; i < DRAFT_BATCH_SIZE; ++i) { + batch_tokens[i] = (llama_token)mask_tok_id; + } + + // out_embd: [DRAFT_BATCH_SIZE * embd_dim] floats + // Public API is one token per call; loop over the batch. + std::vector token_embd((size_t)DRAFT_BATCH_SIZE * embd_dim, 0.0f); + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + const int rc = llama_model_token_embd_lookup( + target_model, + batch_tokens[i], + token_embd.data() + (size_t)i * embd_dim, + embd_dim); + if (rc != 0) { + throw std::runtime_error( + "llama_model_token_embd_lookup failed for token " + + std::to_string(batch_tokens[i])); + } + } + LOG_INF("token_embd lookup done (%d tokens x %d dim)\n", DRAFT_BATCH_SIZE, embd_dim); + + // ------------------------------------------------------------------ + // Step 5: read target_feat binary + // ------------------------------------------------------------------ + std::vector target_feat = read_f32_bin(target_feat_path); + + // Derive or validate ctx_len. + // Expected layout: [ctx_len * 5 * embd_dim] floats + const int32_t feat_width = 5 * embd_dim; + if (ctx_len == 0) { + if ((int32_t)target_feat.size() % feat_width != 0) { + throw std::runtime_error( + "target-feat-bin size not divisible by 5*embd_dim=" + + std::to_string(feat_width)); + } + ctx_len = (int32_t)(target_feat.size() / feat_width); + LOG_INF("derived ctx_len = %d from target-feat-bin\n", ctx_len); + } else { + const size_t expected = (size_t)ctx_len * feat_width; + if (target_feat.size() != expected) { + throw std::runtime_error( + "target-feat-bin has " + std::to_string(target_feat.size()) + + " floats, expected " + std::to_string(expected) + + " (ctx_len=" + std::to_string(ctx_len) + + " * feat_width=" + std::to_string(feat_width) + ")"); + } + } + LOG_INF("target_feat: %d positions x %d floats\n", ctx_len, feat_width); + + // ------------------------------------------------------------------ + // Step 6: init draft context + // + // n_ctx must cover both the draft batch (16) and the target feat + // positions (ctx_len). Use the larger of the two. + // ------------------------------------------------------------------ + const int32_t n_ctx_draft = std::max(ctx_len, DRAFT_BATCH_SIZE) + 64; + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx_draft; + cparams.n_batch = (uint32_t)DRAFT_BATCH_SIZE; + + llama_context * draft_ctx = llama_init_from_model(draft_model, cparams); + if (!draft_ctx) { + throw std::runtime_error("failed to create draft context"); + } + + // ------------------------------------------------------------------ + // Step 7: build the embedding input buffer for the draft forward. + // + // The draft graph builder expects an embd batch where the embd pointer + // contains the concatenation of: + // [token_embd rows: DRAFT_BATCH_SIZE * embd_dim floats] + // [target_feat : ctx_len * 5 * embd_dim floats ] + // + // The batch is created with embd != 0 so llama_decode dispatches the + // embd path. Token IDs are left unset (embd takes precedence). + // + // NOTE: This layout is the current best guess from the roadmap. If + // the implementation agent uses a different mechanism (e.g., a separate + // set_target_feat() call), the user will reconcile and update this + // driver before compiling. + // ------------------------------------------------------------------ + const size_t embd_buf_floats = + (size_t)DRAFT_BATCH_SIZE * embd_dim + + (size_t)ctx_len * feat_width; + + std::vector embd_buf(embd_buf_floats); + // Copy token embeddings first + memcpy(embd_buf.data(), + token_embd.data(), + (size_t)DRAFT_BATCH_SIZE * embd_dim * sizeof(float)); + // Then target_feat + memcpy(embd_buf.data() + (size_t)DRAFT_BATCH_SIZE * embd_dim, + target_feat.data(), + (size_t)ctx_len * feat_width * sizeof(float)); + + // Build a batch that feeds embeddings directly. + // embd = 1 tells llama_batch_init to allocate an embd array; however + // we want to point at our own buffer, so we create the struct manually. + llama_batch batch; + memset(&batch, 0, sizeof(batch)); + batch.n_tokens = DRAFT_BATCH_SIZE; + // embd points to our concatenated buffer + batch.embd = embd_buf.data(); + + // Allocate ancillary arrays on the stack/heap. + std::vector pos_arr(DRAFT_BATCH_SIZE); + std::vector n_seq_id_arr(DRAFT_BATCH_SIZE, 1); + std::vector seq_id_val(DRAFT_BATCH_SIZE, 0); + std::vector seq_id_arr(DRAFT_BATCH_SIZE); + std::vector logits_arr(DRAFT_BATCH_SIZE, 1); + + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + pos_arr[i] = (llama_pos)i; + seq_id_arr[i] = &seq_id_val[i]; + } + batch.pos = pos_arr.data(); + batch.n_seq_id = n_seq_id_arr.data(); + batch.seq_id = seq_id_arr.data(); + batch.logits = logits_arr.data(); + + // ------------------------------------------------------------------ + // Step 8: run draft forward + // ------------------------------------------------------------------ + if (llama_decode(draft_ctx, batch) != 0) { + llama_free(draft_ctx); + throw std::runtime_error("llama_decode (draft) failed"); + } + + // ------------------------------------------------------------------ + // Step 9: collect logits for all 16 positions and dump + // ------------------------------------------------------------------ + const auto * vocab = llama_model_get_vocab(draft_model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + std::vector logits_out((size_t)DRAFT_BATCH_SIZE * vocab_size); + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + const float * row = llama_get_logits_ith(draft_ctx, i); + memcpy(&logits_out[(size_t)i * vocab_size], row, + vocab_size * sizeof(float)); + } + + llama_free(draft_ctx); + write_logits(out_logits_path, logits_out, DRAFT_BATCH_SIZE, vocab_size); + LOG_INF("draft forward done: wrote %d x %d logits to %s\n", + DRAFT_BATCH_SIZE, vocab_size, out_logits_path.c_str()); + ret = 0; + + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_model_free(draft_model); + llama_model_free(target_model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-qwen35-chain-capture.cpp b/tests/test-qwen35-chain-capture.cpp new file mode 100644 index 000000000000..78fdf6910c9b --- /dev/null +++ b/tests/test-qwen35-chain-capture.cpp @@ -0,0 +1,354 @@ +// test-qwen35-chain-capture.cpp +// +// Phase 3 acceptance test (Test 3.C) for hidden-state capture. +// +// Two modes are run in a single invocation using the same model/context: +// +// Mode A (capture): +// - Calls llama_set_capture_hidden(ctx, true). +// - Decodes the prompt as a chain batch. +// - Dumps logits for the last token to --out-logits. +// - Reads the hidden capture buffer via llama_get_hidden_capture(). +// - Dumps the capture buffer to --out-capture. +// - Asserts: shape is [5 * hidden_dim, n_tokens]; no NaN/Inf; not all-zero. +// +// Mode B (regression, --no-capture): +// - Calls llama_set_capture_hidden(ctx, false) then re-decodes same prompt. +// - Asserts logits are BIT-EQUAL to Mode A output (same values, not just close). +// - Skips capture dump. +// +// Both modes run inside a single process so logits can be compared in memory. +// The --out-logits file is written once (from Mode A). If Mode B differs, +// the driver exits with code 1 and prints the first discrepant index. +// +// Build: requires -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON (not in ctest). +// +// API assumptions (implementation agent deliverables): +// void llama_set_capture_hidden(llama_context * ctx, bool enable) +// -- opt-in to hidden-state capture for the target model. +// ggml_tensor * llama_get_hidden_capture(llama_context * ctx) +// -- returns a pointer to the capture tensor after decode. +// tensor shape: [5 * hidden_dim, n_tokens] (F32, host-accessible). +// Returns NULL if capture was not enabled or graph not yet run. +// +// Output binary format (--out-logits): +// int32_t n_tokens (= 1, last-position logit row) +// int32_t vocab_size +// float logits[vocab_size] +// +// Output binary format (--out-capture): +// int32_t feat_dim (= 5 * hidden_dim) +// int32_t n_tokens (number of prompt tokens) +// float buf[feat_dim * n_tokens] (row-major: row i = token i's features) + +#include "llama.h" +#include "ggml.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --model PATH (Qwen3.5-27B GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs; required)\n" + " --out-logits PATH (F32 binary; required)\n" + " --out-capture PATH (F32 binary capture dump; required)\n" + " --no-capture (skip Mode A, only run Mode B regression check)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n" + "\n" + "Both capture and no-capture modes run in sequence within one invocation.\n" + "Logits from both modes are compared in memory and must be bit-equal.\n", + prog); +} + +static std::vector read_prompt_tokens(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open prompt-tokens file: " + path); + } + f.seekg(0, std::ios::end); + auto sz = f.tellg(); + f.seekg(0, std::ios::beg); + if (sz % sizeof(int32_t) != 0) { + throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path); + } + std::vector tokens(sz / sizeof(int32_t)); + f.read(reinterpret_cast(tokens.data()), sz); + return tokens; +} + +static void write_logits(const std::string & path, + const float * data, + int32_t n_tokens, + int32_t vocab_size) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open out-logits for writing: " + path); + } + f.write(reinterpret_cast(&n_tokens), sizeof(int32_t)); + f.write(reinterpret_cast(&vocab_size), sizeof(int32_t)); + f.write(reinterpret_cast(data), + (std::streamsize)((size_t)n_tokens * vocab_size * sizeof(float))); +} + +static void write_capture(const std::string & path, + const float * data, + int32_t feat_dim, + int32_t n_tokens) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open out-capture for writing: " + path); + } + f.write(reinterpret_cast(&feat_dim), sizeof(int32_t)); + f.write(reinterpret_cast(&n_tokens), sizeof(int32_t)); + f.write(reinterpret_cast(data), + (std::streamsize)((size_t)feat_dim * n_tokens * sizeof(float))); +} + +// Decode a prompt batch and return the logits for the last position. +// The returned vector is a copy (safe across re-use of the context). +static std::vector decode_chain(llama_context * ctx, + const std::vector & prompt, + int32_t vocab_size) { + const int32_t n_tokens = (int32_t)prompt.size(); + llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1); + + for (int32_t i = 0; i < n_tokens; ++i) { + batch.token[i] = (llama_token)prompt[i]; + batch.pos[i] = (llama_pos)i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + // Only request logits for the last token to match Phase 1 chain mode. + batch.logits[i] = (i == n_tokens - 1) ? 1 : 0; + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("llama_decode failed"); + } + + // Copy last-token logits before freeing batch. + const float * row = llama_get_logits_ith(ctx, n_tokens - 1); + std::vector logits(row, row + vocab_size); + + llama_batch_free(batch); + return logits; +} + +// Clear context memory (KV cache + recurrent state) between the two decode runs. +static void clear_kv(llama_context * ctx) { + llama_memory_clear(llama_get_memory(ctx), /*data=*/true); +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt_tokens_path; + std::string out_logits_path; + std::string out_capture_path; + bool no_capture = false; + int32_t n_gpu_layers = 99; + int32_t n_ctx = 4096; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--model" && i + 1 < argc) { + model_path = argv[++i]; + } else if (arg == "--prompt-tokens" && i + 1 < argc) { + prompt_tokens_path = argv[++i]; + } else if (arg == "--out-logits" && i + 1 < argc) { + out_logits_path = argv[++i]; + } else if (arg == "--out-capture" && i + 1 < argc) { + out_capture_path = argv[++i]; + } else if (arg == "--no-capture") { + no_capture = true; + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "--n-ctx" && i + 1 < argc) { + n_ctx = std::atoi(argv[++i]); + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (model_path.empty()) { fprintf(stderr, "--model is required\n"); return 1; } + if (prompt_tokens_path.empty()) { fprintf(stderr, "--prompt-tokens is required\n"); return 1; } + if (out_logits_path.empty()) { fprintf(stderr, "--out-logits is required\n"); return 1; } + if (out_capture_path.empty() && !no_capture) { + fprintf(stderr, "--out-capture is required (or pass --no-capture to skip Mode A)\n"); + return 1; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { + LOG_ERR("failed to load model: %s\n", model_path.c_str()); + llama_backend_free(); + return 1; + } + + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx; + cparams.n_batch = (uint32_t)n_ctx; + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + LOG_ERR("failed to create context\n"); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + const int32_t hidden_dim = llama_model_n_embd(model); + const int32_t feat_dim = 5 * hidden_dim; + + std::vector prompt = read_prompt_tokens(prompt_tokens_path); + const int32_t n_prompt = (int32_t)prompt.size(); + + std::vector logits_capture; + std::vector logits_nocapture; + + // ================================================================== + // Mode A: capture enabled + // ================================================================== + if (!no_capture) { + LOG_INF("--- Mode A: capture enabled ---\n"); + llama_set_capture_hidden(ctx, true); + + logits_capture = decode_chain(ctx, prompt, vocab_size); + + // Write logits (last token only). + write_logits(out_logits_path, logits_capture.data(), 1, vocab_size); + LOG_INF("Mode A: logits written to %s\n", out_logits_path.c_str()); + + // Read hidden capture data (host-side, populated via ggml_backend_tensor_get_async). + int64_t ne0 = 0, ne1 = 0; + const float * cap_data = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + if (!cap_data) { + throw std::runtime_error( + "llama_get_hidden_capture_data returned NULL after capture decode; " + "check that llama_set_capture_hidden is wired in the graph builder"); + } + + // Validate shape. qwen35.cpp allocates [n_embd, 5*n_tokens] (slots stacked + // along ne[1]); accept both layouts so the assertion stays portable. + const bool layout_stacked_ne1 = + ne0 == (int64_t)hidden_dim && ne1 == (int64_t)5 * n_prompt; + const bool layout_stacked_ne0 = + ne0 == (int64_t)feat_dim && ne1 == (int64_t)n_prompt; + if (!layout_stacked_ne1 && !layout_stacked_ne0) { + throw std::runtime_error( + "hidden capture tensor shape mismatch: got [" + + std::to_string(ne0) + ", " + std::to_string(ne1) + + "], expected [" + std::to_string(hidden_dim) + ", " + + std::to_string(5 * n_prompt) + "] or [" + + std::to_string(feat_dim) + ", " + std::to_string(n_prompt) + "]"); + } + LOG_INF("capture shape: [%lld, %lld] — OK (%s)\n", + (long long)ne0, (long long)ne1, + layout_stacked_ne1 ? "stacked along ne[1]" : "stacked along ne[0]"); + + // Validate: no NaN/Inf and not all-zero. + const size_t cap_n = (size_t) ne0 * (size_t) ne1; + bool any_nonzero = false; + for (size_t k = 0; k < cap_n; ++k) { + float v = cap_data[k]; + if (!std::isfinite(v)) { + throw std::runtime_error( + "hidden capture contains non-finite value at index " + + std::to_string(k)); + } + if (v != 0.0f) { + any_nonzero = true; + } + } + if (!any_nonzero) { + throw std::runtime_error( + "hidden capture is all-zero; capture hook is likely not wired"); + } + LOG_INF("capture: no NaN/Inf, at least one non-zero value — OK\n"); + + write_capture(out_capture_path, cap_data, feat_dim, n_prompt); + LOG_INF("Mode A: capture written to %s\n", out_capture_path.c_str()); + + clear_kv(ctx); + } + + // ================================================================== + // Mode B: capture disabled — must produce bit-equal logits + // ================================================================== + LOG_INF("--- Mode B: capture disabled ---\n"); + llama_set_capture_hidden(ctx, false); + + logits_nocapture = decode_chain(ctx, prompt, vocab_size); + + if (!no_capture) { + // Compare bit-for-bit against Mode A. + bool mismatch = false; + for (int32_t v = 0; v < vocab_size; ++v) { + if (logits_capture[v] != logits_nocapture[v]) { + fprintf(stderr, + "FAIL: logit mismatch at vocab index %d: " + "capture=%.8e no-capture=%.8e\n", + v, logits_capture[v], logits_nocapture[v]); + mismatch = true; + break; // report first discrepancy only + } + } + if (mismatch) { + throw std::runtime_error( + "Mode A and Mode B logits are not bit-equal; " + "hidden capture hook may be altering the compute graph"); + } + LOG_INF("Mode B: logits bit-equal to Mode A — OK\n"); + } else { + // no-capture-only run: write logits so the caller can compare + // against a Phase 1 golden dump externally. + write_logits(out_logits_path, logits_nocapture.data(), 1, vocab_size); + LOG_INF("Mode B only: logits written to %s\n", out_logits_path.c_str()); + } + + LOG_INF("all assertions passed\n"); + ret = 0; + + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-qwen35-root-vs-chain.cpp b/tests/test-qwen35-root-vs-chain.cpp new file mode 100644 index 000000000000..26466c754668 --- /dev/null +++ b/tests/test-qwen35-root-vs-chain.cpp @@ -0,0 +1,450 @@ +// test-qwen35-root-vs-chain.cpp +// +// DDTree diagnostic: verify that a single tree-mode forward at the root node +// (parent_id = -1) is equivalent to a chain forward of the same token at the +// same position. +// +// Two passes inside the same process (model loaded once): +// pass A (chain): chain prefill tokens[0 .. N-1], record logits at index N-1 +// pass B (tree-root): chain prefill tokens[0 .. N-2], then a single +// tree-mode batch with one node {token = tokens[N-1], parent_id = -1, +// pos = N-1}, record logits at index 0 +// +// If the tree kernel + tree input wiring are correct, A and B should match +// within numerical tolerance for that one position. +// +// Build: -DLLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN=ON +// ./build-server/bin/test-qwen35-root-vs-chain \ +// --model PATH --prompt-tokens tokens.bin --out-summary diff.txt + +#include "llama.h" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --model PATH (GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs)\n" + " --prompt-text STR (alternative to --prompt-tokens; tokenized in-process)\n" + " --prompt-text-file PATH (alternative to --prompt-text; UTF-8 text file)\n" + " --out-summary PATH (text summary; required)\n" + " --n-siblings N (extra sibling nodes at depth 1; default 0)\n" + " --n-spec-steps N (1 or 2; default 1; 2 chains step1 -> compact/rollback -> step2)\n" + " --skip-rollback (diagnostic: compact accepted root but do not rollback SSM)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n", + prog); +} + +static std::vector read_prompt_tokens(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open prompt-tokens file: " + path); + } + f.seekg(0, std::ios::end); + auto size = f.tellg(); + f.seekg(0, std::ios::beg); + if (size % sizeof(int32_t) != 0) { + throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path); + } + std::vector tokens(size / sizeof(int32_t)); + f.read(reinterpret_cast(tokens.data()), size); + return tokens; +} + +static std::vector run_chain_capture_last(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create chain context"); + } + + const int32_t n_tokens = (int32_t)tokens.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_tokens; ++i) { + batch.token[i] = (llama_token)tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == n_tokens - 1) ? 1 : 0; + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("chain llama_decode failed"); + } + + const float * row = llama_get_logits_ith(ctx, n_tokens - 1); + std::vector out(vocab_size); + memcpy(out.data(), row, (size_t)vocab_size * sizeof(float)); + + llama_batch_free(batch); + llama_free(ctx); + return out; +} + +// Build a tree batch of (1 + n_siblings) nodes at the given root pos. +// Sibling tokens come from the prompt history (cyclic) so they are distinct. +static llama_batch build_tree_batch(const std::vector & tokens, + int32_t root_pos, + llama_token root_token, + int n_siblings) { + const int n_nodes = 1 + n_siblings; + llama_batch tb = llama_batch_init_tree(n_nodes, /*embd=*/0, /*n_seq_max=*/1); + tb.token[0] = root_token; + tb.pos[0] = root_pos; + tb.n_seq_id[0] = 1; + tb.seq_id[0][0] = 0; + tb.parent_id[0] = -1; + tb.logits[0] = 1; + for (int i = 1; i < n_nodes; ++i) { + const int n = (int)tokens.size(); + const int src = ((root_pos - i) % n + n) % n; + tb.token[i] = (llama_token)tokens[src]; + tb.pos[i] = root_pos + 1; + tb.n_seq_id[i] = 1; + tb.seq_id[i][0] = 0; + tb.parent_id[i] = 0; + tb.logits[i] = 0; + } + tb.n_tokens = n_nodes; + return tb; +} + +// Two-spec-step variant: chain prefill [0..N-3], step1 (root=tokens[N-2]), accept +// root only -> compact+rollback, step2 (root=tokens[N-1]); return logits at step2 root. +static std::vector run_chain_then_tree_two_step(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + int n_siblings, + bool skip_rollback) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create two-step context"); + } + const int32_t N = (int32_t)tokens.size(); + const int32_t n_prefix = N - 2; + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_sz = llama_vocab_n_tokens(vocab); + + // chain prefill + if (n_prefix > 0) { + llama_batch b = llama_batch_init(n_prefix, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prefix; ++i) { + b.token[i] = (llama_token)tokens[i]; b.pos[i] = i; + b.n_seq_id[i] = 1; b.seq_id[i][0] = 0; b.logits[i] = 0; + } + b.n_tokens = n_prefix; + if (llama_decode(ctx, b) != 0) { + llama_batch_free(b); llama_free(ctx); + throw std::runtime_error("two-step: prefill failed"); + } + llama_batch_free(b); + } + + // spec step 1: root @ pos n_prefix (token = tokens[n_prefix]) + { + llama_batch t = build_tree_batch(tokens, n_prefix, + (llama_token)tokens[n_prefix], n_siblings); + if (llama_decode(ctx, t) != 0) { + llama_batch_free(t); llama_free(ctx); + throw std::runtime_error("two-step: spec step 1 failed"); + } + llama_batch_free(t); + } + // accept root only -> compact tree + SSM rollback + int32_t accepted_dfs[1] = {0}; + llama_kv_cache_seq_compact_tree(ctx, /*seq_id=*/0, accepted_dfs, + /*n_accepted=*/1, /*commit_n=*/1, + /*spine_start=*/n_prefix); + if (!skip_rollback) { + llama_dflash_rollback_ssm_to_dfs(ctx, /*seq_id=*/0, /*accepted_dfs_node=*/0); + } + + // spec step 2: root @ pos n_prefix+1 (token = tokens[n_prefix+1]) + std::vector out(vocab_sz); + { + llama_batch t = build_tree_batch(tokens, n_prefix + 1, + (llama_token)tokens[n_prefix + 1], n_siblings); + if (llama_decode(ctx, t) != 0) { + llama_batch_free(t); llama_free(ctx); + throw std::runtime_error("two-step: spec step 2 failed"); + } + memcpy(out.data(), llama_get_logits_ith(ctx, 0), + (size_t)vocab_sz * sizeof(float)); + llama_batch_free(t); + } + + llama_free(ctx); + return out; +} + +static std::vector run_chain_then_tree_root(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + int n_siblings) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create tree-root context"); + } + + const int32_t n_tokens = (int32_t)tokens.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + const int32_t n_prefix = n_tokens - 1; + + // chain prefill tokens[0 .. n_prefix - 1] + if (n_prefix > 0) { + llama_batch batch = llama_batch_init(n_prefix, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prefix; ++i) { + batch.token[i] = (llama_token)tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 0; + } + batch.n_tokens = n_prefix; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("chain prefill llama_decode failed"); + } + llama_batch_free(batch); + } + + // tree batch: root at index 0 + n_siblings nodes at depth 1 (parent = root) + const int n_nodes = 1 + n_siblings; + llama_batch tbatch = llama_batch_init_tree(/*n_tokens=*/n_nodes, /*embd=*/0, /*n_seq_max=*/1); + // root + tbatch.token[0] = (llama_token)tokens[n_prefix]; + tbatch.pos[0] = n_prefix; + tbatch.n_seq_id[0] = 1; + tbatch.seq_id[0][0] = 0; + tbatch.parent_id[0] = -1; + tbatch.logits[0] = 1; + // siblings: depth 1, parent = root, token taken from prompt history (cyclic) + // so the tree batch has *distinct* tokens like a real spec verify ubatch. + for (int i = 1; i < n_nodes; ++i) { + const int src = (n_prefix - 1 - i + (int)tokens.size()) % (int)tokens.size(); + tbatch.token[i] = (llama_token)tokens[(src < 0 ? src + (int)tokens.size() : src)]; + tbatch.pos[i] = n_prefix + 1; + tbatch.n_seq_id[i] = 1; + tbatch.seq_id[i][0] = 0; + tbatch.parent_id[i] = 0; + tbatch.logits[i] = 0; + } + tbatch.n_tokens = n_nodes; + + if (llama_decode(ctx, tbatch) != 0) { + llama_batch_free(tbatch); + llama_free(ctx); + throw std::runtime_error("tree-root llama_decode failed"); + } + + const float * row = llama_get_logits_ith(ctx, 0); + std::vector out(vocab_size); + memcpy(out.data(), row, (size_t)vocab_size * sizeof(float)); + + llama_batch_free(tbatch); + llama_free(ctx); + return out; +} + +struct DiffStats { + double max_abs_diff; + double mean_abs_diff; + int argmax_a; + int argmax_b; + std::vector top5_a; + std::vector top5_b; +}; + +static std::vector top_k_indices(const std::vector & v, int k) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), + [&](int a, int b) { return v[a] > v[b]; }); + idx.resize(k); + return idx; +} + +static DiffStats diff_logits(const std::vector & a, const std::vector & b) { + DiffStats s = {}; + if (a.size() != b.size() || a.empty()) { + throw std::runtime_error("logits size mismatch"); + } + double sum_abs = 0.0; + double max_abs = 0.0; + for (size_t i = 0; i < a.size(); ++i) { + const double d = std::fabs((double)a[i] - (double)b[i]); + sum_abs += d; + if (d > max_abs) max_abs = d; + } + s.max_abs_diff = max_abs; + s.mean_abs_diff = sum_abs / (double)a.size(); + s.argmax_a = (int)(std::max_element(a.begin(), a.end()) - a.begin()); + s.argmax_b = (int)(std::max_element(b.begin(), b.end()) - b.begin()); + s.top5_a = top_k_indices(a, 5); + s.top5_b = top_k_indices(b, 5); + return s; +} + +static std::string read_text_file(const std::string & path) { + std::ifstream f(path); + if (!f) { + throw std::runtime_error("cannot open prompt-text-file: " + path); + } + std::string s((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + return s; +} + +static std::vector tokenize_text(llama_model * model, const std::string & text) { + const auto * vocab = llama_model_get_vocab(model); + int32_t n = -llama_tokenize(vocab, text.data(), (int32_t)text.size(), + nullptr, 0, /*add_special=*/true, /*parse_special=*/false); + if (n <= 0) { + throw std::runtime_error("llama_tokenize sizing failed"); + } + std::vector tmp(n); + int32_t got = llama_tokenize(vocab, text.data(), (int32_t)text.size(), + tmp.data(), n, true, false); + if (got != n) { + throw std::runtime_error("llama_tokenize result mismatch"); + } + std::vector out(tmp.begin(), tmp.end()); + return out; +} + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt_tokens_path; + std::string prompt_text; + std::string prompt_text_file; + std::string out_summary; + int32_t n_gpu_layers = 99; + int32_t n_ctx = 4096; + int32_t n_siblings = 0; + int32_t n_spec_steps = 1; + bool skip_rollback = false; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--model" && i + 1 < argc) model_path = argv[++i]; + else if (arg == "--prompt-tokens" && i + 1 < argc) prompt_tokens_path = argv[++i]; + else if (arg == "--prompt-text" && i + 1 < argc) prompt_text = argv[++i]; + else if (arg == "--prompt-text-file" && i + 1 < argc) prompt_text_file = argv[++i]; + else if (arg == "--out-summary" && i + 1 < argc) out_summary = argv[++i]; + else if (arg == "--n-siblings" && i + 1 < argc) n_siblings = std::atoi(argv[++i]); + else if (arg == "--n-spec-steps" && i + 1 < argc) n_spec_steps = std::atoi(argv[++i]); + else if (arg == "--skip-rollback") skip_rollback = true; + else if (arg == "--n-gpu-layers" && i + 1 < argc) n_gpu_layers = std::atoi(argv[++i]); + else if (arg == "--n-ctx" && i + 1 < argc) n_ctx = std::atoi(argv[++i]); + else if (arg == "-h" || arg == "--help") { usage(argv[0]); return 0; } + else { fprintf(stderr, "unknown argument: %s\n", arg.c_str()); usage(argv[0]); return 1; } + } + + int input_modes = (!prompt_tokens_path.empty()) + (!prompt_text.empty()) + (!prompt_text_file.empty()); + if (model_path.empty() || out_summary.empty() || input_modes != 1) { + fprintf(stderr, "must provide exactly one of --prompt-tokens, --prompt-text, --prompt-text-file\n"); + usage(argv[0]); return 1; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { + LOG_ERR("failed to load model: %s\n", model_path.c_str()); + llama_backend_free(); + return 1; + } + + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx; + cparams.n_batch = (uint32_t)n_ctx; + + int rc = 1; + try { + std::vector tokens; + if (!prompt_tokens_path.empty()) { + tokens = read_prompt_tokens(prompt_tokens_path); + } else { + std::string text = !prompt_text.empty() + ? prompt_text + : read_text_file(prompt_text_file); + tokens = tokenize_text(model, text); + LOG_INF("tokenized %zu tokens from text\n", tokens.size()); + } + if (tokens.size() < 2) { + throw std::runtime_error("need at least 2 tokens"); + } + + LOG_INF("loaded %zu tokens; running chain pass...\n", tokens.size()); + std::vector A = run_chain_capture_last(model, cparams, tokens); + + LOG_INF("running chain-prefill + tree-root pass (n_siblings=%d, n_spec_steps=%d)...\n", + n_siblings, n_spec_steps); + std::vector B; + if (n_spec_steps == 1) { + B = run_chain_then_tree_root(model, cparams, tokens, n_siblings); + } else if (n_spec_steps == 2) { + B = run_chain_then_tree_two_step(model, cparams, tokens, n_siblings, skip_rollback); + } else { + throw std::runtime_error("--n-spec-steps must be 1 or 2"); + } + + DiffStats s = diff_logits(A, B); + + std::ofstream f(out_summary); + f << "n_tokens=" << tokens.size() << "\n"; + f << "n_siblings=" << n_siblings << "\n"; + f << "n_spec_steps=" << n_spec_steps << "\n"; + f << "skip_rollback=" << (skip_rollback ? 1 : 0) << "\n"; + f << "vocab_size=" << A.size() << "\n"; + f << "max_abs_diff=" << s.max_abs_diff << "\n"; + f << "mean_abs_diff=" << s.mean_abs_diff << "\n"; + f << "argmax_chain=" << s.argmax_a << "\n"; + f << "argmax_tree_root=" << s.argmax_b << "\n"; + f << "top5_chain="; + for (int x : s.top5_a) f << x << " "; + f << "\ntop5_tree_root="; + for (int x : s.top5_b) f << x << " "; + f << "\n"; + + fprintf(stderr, "max_abs_diff = %.6g\n", s.max_abs_diff); + fprintf(stderr, "mean_abs_diff = %.6g\n", s.mean_abs_diff); + fprintf(stderr, "argmax: chain=%d tree_root=%d %s\n", + s.argmax_a, s.argmax_b, + s.argmax_a == s.argmax_b ? "MATCH" : "DIFF"); + + rc = 0; + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + rc = 1; + } + + llama_model_free(model); + llama_backend_free(); + return rc; +} diff --git a/tests/test-qwen35-tree-rollback.cpp b/tests/test-qwen35-tree-rollback.cpp new file mode 100644 index 000000000000..014771eb5426 --- /dev/null +++ b/tests/test-qwen35-tree-rollback.cpp @@ -0,0 +1,299 @@ +// test-qwen35-tree-rollback.cpp +// +// Phase 2 acceptance test for DDTree snapshot/restore symmetry (Test 2.A). +// Loads a Qwen3.5-27B GGUF, decodes a prompt chain, takes a recurrent-state +// snapshot, runs N decode steps, restores the snapshot, runs the same N steps +// again from the same starting token, and dumps both runs' final logits. +// The two logit files must be bit-equal (--abs-tol 0 with compare_logits.py). +// +// Build: requires -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON (not in ctest). +// +// API assumptions (implementation agent deliverables): +// typedef int32_t llama_mem_snapshot_id; +// llama_mem_snapshot_id llama_seq_snapshot(llama_context *, llama_seq_id); +// bool llama_seq_restore (llama_context *, llama_mem_snapshot_id); +// void llama_seq_release (llama_context *, llama_mem_snapshot_id); +// +// Output binary format (--out-logits-pre / --out-logits-post): +// int32_t n_tokens (= 1, the single last-step logit row) +// int32_t vocab_size +// float logits[vocab_size] + +#include "llama.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --model PATH (Qwen3.5-27B GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs; required)\n" + " --gen N (chain decode steps per run; default 8)\n" + " --out-logits-pre PATH (logit dump from first run; required)\n" + " --out-logits-post PATH (logit dump from second run after restore; required)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n", + prog); +} + +static std::vector read_prompt_tokens(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open prompt-tokens file: " + path); + } + f.seekg(0, std::ios::end); + auto sz = f.tellg(); + f.seekg(0, std::ios::beg); + if (sz % sizeof(int32_t) != 0) { + throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path); + } + std::vector tokens(sz / sizeof(int32_t)); + f.read(reinterpret_cast(tokens.data()), sz); + return tokens; +} + +static void write_logits(const std::string & path, + const float * data, + int32_t n_tokens, + int32_t vocab_size) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open out-logits for writing: " + path); + } + f.write(reinterpret_cast(&n_tokens), sizeof(int32_t)); + f.write(reinterpret_cast(&vocab_size), sizeof(int32_t)); + f.write(reinterpret_cast(data), (std::streamsize)(n_tokens * vocab_size * sizeof(float))); +} + +// Return the argmax token id from a logits row. +static llama_token argmax(const float * logits, int32_t vocab_size) { + return (llama_token)(std::max_element(logits, logits + vocab_size) - logits); +} + +// Decode a single token at the given position and return the logits pointer. +// The returned pointer is valid until the next llama_decode call. +static const float * decode_single(llama_context * ctx, + llama_token tok, + llama_pos pos, + int32_t vocab_size) { + llama_batch batch = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1); + batch.token[0] = tok; + batch.pos[0] = pos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + batch.n_tokens = 1; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("llama_decode failed for single token"); + } + + const float * row = llama_get_logits_ith(ctx, 0); + // Copy before freeing the batch (logits buffer owned by context, not batch) + llama_batch_free(batch); + (void)vocab_size; // size used by caller + return row; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt_tokens_path; + std::string out_logits_pre_path; + std::string out_logits_post_path; + int32_t gen = 8; + int32_t n_gpu_layers = 99; + int32_t n_ctx = 4096; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--model" && i + 1 < argc) { + model_path = argv[++i]; + } else if (arg == "--prompt-tokens" && i + 1 < argc) { + prompt_tokens_path = argv[++i]; + } else if (arg == "--gen" && i + 1 < argc) { + gen = std::atoi(argv[++i]); + } else if (arg == "--out-logits-pre" && i + 1 < argc) { + out_logits_pre_path = argv[++i]; + } else if (arg == "--out-logits-post" && i + 1 < argc) { + out_logits_post_path = argv[++i]; + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "--n-ctx" && i + 1 < argc) { + n_ctx = std::atoi(argv[++i]); + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (model_path.empty()) { + fprintf(stderr, "--model is required\n"); + return 1; + } + if (prompt_tokens_path.empty()) { + fprintf(stderr, "--prompt-tokens is required\n"); + return 1; + } + if (out_logits_pre_path.empty()) { + fprintf(stderr, "--out-logits-pre is required\n"); + return 1; + } + if (out_logits_post_path.empty()) { + fprintf(stderr, "--out-logits-post is required\n"); + return 1; + } + if (gen < 1) { + fprintf(stderr, "--gen must be >= 1\n"); + return 1; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { + LOG_ERR("failed to load model: %s\n", model_path.c_str()); + llama_backend_free(); + return 1; + } + + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx; + cparams.n_batch = (uint32_t)n_ctx; + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + LOG_ERR("failed to create context\n"); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + // ------------------------------------------------------------------ + // Step 1: decode the prompt as a single chain batch to prime state + // ------------------------------------------------------------------ + std::vector prompt = read_prompt_tokens(prompt_tokens_path); + const int32_t n_prompt = (int32_t)prompt.size(); + + { + llama_batch batch = llama_batch_init(n_prompt, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prompt; ++i) { + batch.token[i] = (llama_token)prompt[i]; + batch.pos[i] = (llama_pos)i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == n_prompt - 1) ? 1 : 0; + } + batch.n_tokens = n_prompt; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("prompt decode failed"); + } + llama_batch_free(batch); + } + LOG_INF("prompt decoded (%d tokens)\n", n_prompt); + + // ------------------------------------------------------------------ + // Step 2: snapshot the recurrent state BEFORE the decode loop + // The first token for both runs is the argmax of the prompt's last + // position logits, captured now so both runs start identically. + // ------------------------------------------------------------------ + const float * prompt_logits = llama_get_logits_ith(ctx, n_prompt - 1); + llama_token tok_first = argmax(prompt_logits, vocab_size); + LOG_INF("first token after prompt: %d\n", (int)tok_first); + + llama_mem_snapshot_id snap = llama_seq_snapshot(ctx, /*seq_id=*/0); + if (snap < 0) { + throw std::runtime_error("llama_seq_snapshot returned negative id"); + } + LOG_INF("snapshot id: %d\n", (int)snap); + + // ------------------------------------------------------------------ + // Step 3: run K decode steps (run 1), save logits of last step + // ------------------------------------------------------------------ + std::vector last_logits_pre((size_t)vocab_size); + { + llama_token cur = tok_first; + llama_pos pos = (llama_pos)n_prompt; + for (int step = 0; step < gen; ++step) { + const float * row = decode_single(ctx, cur, pos, vocab_size); + if (step == gen - 1) { + memcpy(last_logits_pre.data(), row, vocab_size * sizeof(float)); + } + cur = argmax(row, vocab_size); + ++pos; + } + } + write_logits(out_logits_pre_path, last_logits_pre.data(), 1, vocab_size); + LOG_INF("run 1 complete, logits written to %s\n", out_logits_pre_path.c_str()); + + // ------------------------------------------------------------------ + // Step 4: restore snapshot and run K steps again with same first token + // ------------------------------------------------------------------ + if (!llama_seq_restore(ctx, snap)) { + throw std::runtime_error("llama_seq_restore failed"); + } + LOG_INF("snapshot restored\n"); + + std::vector last_logits_post((size_t)vocab_size); + { + llama_token cur = tok_first; // same first token as run 1 + llama_pos pos = (llama_pos)n_prompt; + for (int step = 0; step < gen; ++step) { + const float * row = decode_single(ctx, cur, pos, vocab_size); + if (step == gen - 1) { + memcpy(last_logits_post.data(), row, vocab_size * sizeof(float)); + } + cur = argmax(row, vocab_size); + ++pos; + } + } + write_logits(out_logits_post_path, last_logits_post.data(), 1, vocab_size); + LOG_INF("run 2 complete, logits written to %s\n", out_logits_post_path.c_str()); + + // ------------------------------------------------------------------ + // Step 5: release snapshot + // ------------------------------------------------------------------ + llama_seq_release(ctx, snap); + LOG_INF("snapshot released\n"); + + ret = 0; + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-qwen35-tree.cpp b/tests/test-qwen35-tree.cpp new file mode 100644 index 000000000000..2516577c2847 --- /dev/null +++ b/tests/test-qwen35-tree.cpp @@ -0,0 +1,311 @@ +// test-qwen35-tree.cpp +// +// Phase 1 acceptance test for DDTree tree-mode forward pass. +// Loads a Qwen3.5-27B GGUF, runs either a plain chain forward or a tree forward, +// and dumps raw F32 logits for offline comparison. +// +// Build: requires -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON (not added to ctest by default). +// +// API assumptions (implementation agent deliverables): +// - llama_batch.parent_id : int32_t *, NULL in chain mode; -1 = root, else flat parent index +// - llama_batch_init_tree(n_tokens, embd, n_seq_max) : like llama_batch_init but also +// allocates parent_id array of size n_tokens +// - llama_batch_free() frees parent_id when non-NULL +// - llama_decode() reads batch.parent_id and dispatches tree forward when non-NULL +// +// Output binary format (--out-logits): +// int32_t n_tokens +// int32_t vocab_size +// float logits[n_tokens * vocab_size] (row-major, little-endian) + +#include "llama.h" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +// nlohmann/json is vendored at vendor/nlohmann/json.hpp +#include + +using json = nlohmann::json; + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --mode {chain,tree} (required)\n" + " --model PATH (GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs; required)\n" + " --tree-fixture PATH (JSON; required in tree mode)\n" + " --out-logits PATH (binary F32 output; required)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n", + prog); +} + +static std::vector read_prompt_tokens(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open prompt-tokens file: " + path); + } + f.seekg(0, std::ios::end); + auto size = f.tellg(); + f.seekg(0, std::ios::beg); + if (size % sizeof(int32_t) != 0) { + throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path); + } + std::vector tokens(size / sizeof(int32_t)); + f.read(reinterpret_cast(tokens.data()), size); + return tokens; +} + +struct TreeNode { + int32_t flat_idx; + int32_t token_id; + int32_t parent_idx; // -1 = root + int32_t depth; +}; + +static std::vector parse_tree_fixture(const std::string & path, int32_t & committed_offset) { + std::ifstream f(path); + if (!f) { + throw std::runtime_error("cannot open tree-fixture: " + path); + } + json j; + f >> j; + + committed_offset = j.value("committed_offset", 0); + + std::vector nodes; + for (const auto & n : j["nodes"]) { + TreeNode node; + node.flat_idx = n["flat_idx"].get(); + node.token_id = n["token_id"].get(); + node.parent_idx = n["parent_idx"].get(); + node.depth = n["depth"].get(); + nodes.push_back(node); + } + return nodes; +} + +static void write_logits(const std::string & path, + const std::vector & data, + int32_t n_tokens, + int32_t vocab_size) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open out-logits for writing: " + path); + } + f.write(reinterpret_cast(&n_tokens), sizeof(int32_t)); + f.write(reinterpret_cast(&vocab_size), sizeof(int32_t)); + f.write(reinterpret_cast(data.data()), (std::streamsize)(data.size() * sizeof(float))); +} + +// --------------------------------------------------------------------------- +// chain mode: plain llama_batch forward +// --------------------------------------------------------------------------- + +static int run_chain(llama_model * model, + llama_context * ctx, + const std::vector & prompt, + const std::string & out_path) { + const int32_t n_tokens = (int32_t)prompt.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1); + + for (int32_t i = 0; i < n_tokens; ++i) { + batch.token[i] = (llama_token)prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; // request logits for every position + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode failed\n", __func__); + llama_batch_free(batch); + return 1; + } + + std::vector logits_out((size_t)n_tokens * vocab_size); + for (int32_t i = 0; i < n_tokens; ++i) { + const float * row = llama_get_logits_ith(ctx, i); + memcpy(&logits_out[(size_t)i * vocab_size], row, vocab_size * sizeof(float)); + } + + llama_batch_free(batch); + write_logits(out_path, logits_out, n_tokens, vocab_size); + LOG_INF("chain: wrote %d x %d logits to %s\n", n_tokens, vocab_size, out_path.c_str()); + return 0; +} + +// --------------------------------------------------------------------------- +// tree mode: tree-batch forward +// --------------------------------------------------------------------------- +// +// Position assignment (Phase 1, 1D only): +// pos[i] = committed_offset + node.depth +// +// M-RoPE 4-axis positions are deferred to Phase 3 (UNKNOWN-3 in the roadmap). +// When that work lands, pos[] will need to be a 4-tuple per token and +// llama_batch will need a corresponding multi-axis pos field. + +static int run_tree(llama_model * model, + llama_context * ctx, + const std::string & fixture_path, + const std::string & out_path) { + int32_t committed_offset = 0; + std::vector nodes = parse_tree_fixture(fixture_path, committed_offset); + + const int32_t n_tokens = (int32_t)nodes.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + // llama_batch_init_tree is the new API added by the implementation agent. + // It behaves like llama_batch_init but additionally allocates batch.parent_id. + llama_batch batch = llama_batch_init_tree(n_tokens, /*embd=*/0, /*n_seq_max=*/1); + + for (int32_t i = 0; i < n_tokens; ++i) { + const TreeNode & node = nodes[i]; + batch.token[i] = (llama_token)node.token_id; + batch.pos[i] = (llama_pos)(committed_offset + node.depth); + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.parent_id[i] = node.parent_idx; // -1 = root + batch.logits[i] = 1; + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode (tree) failed\n", __func__); + llama_batch_free(batch); + return 1; + } + + std::vector logits_out((size_t)n_tokens * vocab_size); + for (int32_t i = 0; i < n_tokens; ++i) { + const float * row = llama_get_logits_ith(ctx, i); + memcpy(&logits_out[(size_t)i * vocab_size], row, vocab_size * sizeof(float)); + } + + llama_batch_free(batch); + write_logits(out_path, logits_out, n_tokens, vocab_size); + LOG_INF("tree: wrote %d x %d logits to %s\n", n_tokens, vocab_size, out_path.c_str()); + return 0; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string mode; + std::string model_path; + std::string prompt_tokens_path; + std::string tree_fixture_path; + std::string out_logits_path; + int32_t n_gpu_layers = 99; + int32_t n_ctx = 4096; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--mode" && i + 1 < argc) { + mode = argv[++i]; + } else if (arg == "--model" && i + 1 < argc) { + model_path = argv[++i]; + } else if (arg == "--prompt-tokens" && i + 1 < argc) { + prompt_tokens_path = argv[++i]; + } else if (arg == "--tree-fixture" && i + 1 < argc) { + tree_fixture_path = argv[++i]; + } else if (arg == "--out-logits" && i + 1 < argc) { + out_logits_path = argv[++i]; + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "--n-ctx" && i + 1 < argc) { + n_ctx = std::atoi(argv[++i]); + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (mode != "chain" && mode != "tree") { + fprintf(stderr, "--mode must be 'chain' or 'tree'\n"); + usage(argv[0]); + return 1; + } + if (model_path.empty()) { + fprintf(stderr, "--model is required\n"); + return 1; + } + if (prompt_tokens_path.empty() && mode == "chain") { + fprintf(stderr, "--prompt-tokens is required in chain mode\n"); + return 1; + } + if (tree_fixture_path.empty() && mode == "tree") { + fprintf(stderr, "--tree-fixture is required in tree mode\n"); + return 1; + } + if (out_logits_path.empty()) { + fprintf(stderr, "--out-logits is required\n"); + return 1; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { + LOG_ERR("failed to load model: %s\n", model_path.c_str()); + llama_backend_free(); + return 1; + } + + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx; + cparams.n_batch = (uint32_t)n_ctx; + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + LOG_ERR("failed to create context\n"); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + if (mode == "chain") { + std::vector prompt = read_prompt_tokens(prompt_tokens_path); + ret = run_chain(model, ctx, prompt, out_logits_path); + } else { + ret = run_tree(model, ctx, tree_fixture_path, out_logits_path); + } + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-speculative-draft-backend.cpp b/tests/test-speculative-draft-backend.cpp new file mode 100644 index 000000000000..2a7b049d4e19 --- /dev/null +++ b/tests/test-speculative-draft-backend.cpp @@ -0,0 +1,99 @@ +#include "speculative-draft-backend.h" + +#include +#include +#include +#include +#include + +static void require(bool ok, const char * expr, int line) { + if (!ok) { + std::fprintf(stderr, "test-speculative-draft-backend:%d: check failed: %s\n", line, expr); + std::abort(); + } +} + +#define REQUIRE(expr) require((expr), #expr, __LINE__) + +static void test_top_k_width() { + llama_ddtree_params p; + p.block_size = 16; + p.budget = 40; + p.top_k = 0; + REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 8); + + p.budget = 8; + REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 1); + + p.top_k = 4; + REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 4); +} + +static void test_pack_target_feat_no_wrap() { + const int64_t fc = 3; + const int64_t cap = 4; + std::vector ring((size_t) fc * cap); + for (int64_t col = 0; col < cap; ++col) { + for (int64_t row = 0; row < fc; ++row) { + ring[(size_t) col * fc + row] = (float) (10 * col + row); + } + } + + llama_speculative_draft_target_feat_view view{ ring.data(), 3, cap, fc }; + std::vector out; + int64_t ctx_len = 0; + REQUIRE(llama_speculative_draft_pack_target_feat(view, out, ctx_len)); + REQUIRE(ctx_len == 3); + REQUIRE(out.size() == 9); + + for (int64_t col = 0; col < ctx_len; ++col) { + for (int64_t row = 0; row < fc; ++row) { + REQUIRE(out[(size_t) col * fc + row] == ring[(size_t) col * fc + row]); + } + } +} + +static void test_pack_target_feat_wrap() { + const int64_t fc = 3; + const int64_t cap = 4; + std::vector ring((size_t) fc * cap); + + // logical columns 2, 3, 4, 5 live in ring slots 2, 3, 0, 1. + const int64_t logical_by_slot[4] = { 4, 5, 2, 3 }; + for (int64_t slot = 0; slot < cap; ++slot) { + const int64_t logical = logical_by_slot[slot]; + for (int64_t row = 0; row < fc; ++row) { + ring[(size_t) slot * fc + row] = (float) (100 * logical + row); + } + } + + llama_speculative_draft_target_feat_view view{ ring.data(), 6, cap, fc }; + std::vector out; + int64_t ctx_len = 0; + REQUIRE(llama_speculative_draft_pack_target_feat(view, out, ctx_len)); + REQUIRE(ctx_len == cap); + REQUIRE(out.size() == (size_t) fc * cap); + + for (int64_t col = 0; col < ctx_len; ++col) { + const int64_t logical = 2 + col; + for (int64_t row = 0; row < fc; ++row) { + REQUIRE(out[(size_t) col * fc + row] == (float) (100 * logical + row)); + } + } +} + +static void test_pack_target_feat_empty() { + std::vector out{ 1.0f }; + int64_t ctx_len = 123; + llama_speculative_draft_target_feat_view view{}; + REQUIRE(!llama_speculative_draft_pack_target_feat(view, out, ctx_len)); + REQUIRE(ctx_len == 0); +} + +int main() { + test_top_k_width(); + test_pack_target_feat_no_wrap(); + test_pack_target_feat_wrap(); + test_pack_target_feat_empty(); + return 0; +} diff --git a/tests/test-speculative-tree-e2e.cpp b/tests/test-speculative-tree-e2e.cpp new file mode 100644 index 000000000000..41a222764899 --- /dev/null +++ b/tests/test-speculative-tree-e2e.cpp @@ -0,0 +1,759 @@ +// test-speculative-tree-e2e.cpp +// +// Phase 4 end-to-end acceptance test (Test 4.A) for DDTree speculative decoding. +// +// Two decode runs are performed back-to-back: +// +// Run 1 (chain reference): +// Load target model. Decode the prompt as a plain chain batch. Then loop +// greedy-argmax N times to collect N reference tokens. Write to --out-chain. +// +// Run 2 (spec decode): +// Reload a fresh target context with capture_hidden=true. Load draft model. +// Decode prompt as chain batch (primes hidden capture). Init the DDTree +// speculative driver. Loop spec steps until N tokens are collected. +// Write (first N) to --out-spec. +// +// Acceptance criterion (--temp 0 / greedy): +// The first min(chain_n, spec_n) tokens MUST be bit-equal. +// DDTree is lossless speculative decoding: each accepted draft token has +// been verified as matching target-argmax at that position. +// +// Build: requires -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON (not in ctest). +// +// API assumptions (implementation agent deliverables): +// -- From Phase 3 gap: +// void llama_set_target_feat_raw(llama_context * ctx, +// const float * data, +// int64_t n_embd_fc, +// int64_t ctx_len); +// +// -- Phase 4 driver (common/speculative-tree-driver.h): +// struct llama_speculative_tree_driver; +// llama_speculative_tree_driver * llama_speculative_tree_driver_init( +// llama_context * target_ctx, +// llama_context * draft_ctx, +// const llama_ddtree_params & params); +// void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d); +// std::vector llama_speculative_tree_driver_step( +// llama_speculative_tree_driver * d, +// llama_token root_token, +// llama_pos committed_pos); +// +// -- Phase 3 (already landed): +// void llama_set_capture_hidden(llama_context * ctx, bool enable); +// +// Output binary format (--out-chain / --out-spec): +// int32_t n_tokens (number of generated tokens written) +// int32_t tokens[n_tokens] (little-endian int32, one per generated token) +// +// Note: --temp 0 is required for the bit-equal trajectory guarantee. +// Non-zero temperature introduces stochastic sampling and invalidates the +// comparison. + +#include "llama.h" +#include "common.h" +#include "log.h" +#include "speculative-tree.h" +#include "speculative-tree-driver.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// constants +// --------------------------------------------------------------------------- + +// Qwen3.5 EOS token id. Accept this token but stop further generation. +static constexpr llama_token QWEN35_EOS = 248045; + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --target-model PATH (Qwen3.5-27B GGUF; required)\n" + " --draft-model PATH (dflash-draft GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs; required unless --prompt-text)\n" + " --prompt-text PATH (raw rendered prompt text; tokenized with target vocab)\n" + " --prompt-add-special (with --prompt-text, request tokenizer special BOS/EOS insertion)\n" + " --no-prompt-parse-special\n" + " (with --prompt-text, do not parse <|...|> as special tokens)\n" + " --gen N (tokens to generate; default 32)\n" + " --out-spec PATH (spec-decode output tokens, int32 LE; required)\n" + " --out-chain PATH (chain-decode reference tokens, int32 LE; required)\n" + " --ddtree-budget N (DDTree node budget; default 22)\n" + " --ddtree-no-chain-seed (disable chain-seed heuristic; default: on)\n" + " --require-ddtree (fail unless multi-node DDTree verify ran)\n" + " --require-replay (fail unless snapshot+replay fallback ran)\n" + " --require-full-prompt-ingest\n" + " (fail unless DDTree ingested every prompt token capture)\n" + " --temp F (sampling temperature; default 0.0 = greedy)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n" + " --n-batch N (logical prompt batch; default min(n_ctx, 2048))\n" + " --n-ubatch N (physical prompt batch; default 512)\n" + " --no-flash-attn (disable Flash Attention)\n" + "\n" + "Pass --temp 0 (greedy) to enable token-trajectory bit-equal assertion.\n", + prog); +} + +static std::vector read_int32_file(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open file: " + path); + } + f.seekg(0, std::ios::end); + auto sz = f.tellg(); + f.seekg(0, std::ios::beg); + if (sz % sizeof(int32_t) != 0) { + throw std::runtime_error("file size not a multiple of 4: " + path); + } + std::vector buf(sz / sizeof(int32_t)); + f.read(reinterpret_cast(buf.data()), sz); + return buf; +} + +static std::string read_text_file(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open file: " + path); + } + return std::string((std::istreambuf_iterator(f)), std::istreambuf_iterator()); +} + +static std::vector tokenize_text( + const llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + int32_t n = -llama_tokenize(vocab, text.data(), (int32_t)text.size(), + nullptr, 0, add_special, parse_special); + if (n <= 0) { + throw std::runtime_error("llama_tokenize sizing failed"); + } + std::vector tmp(n); + int32_t got = llama_tokenize(vocab, text.data(), (int32_t)text.size(), + tmp.data(), n, add_special, parse_special); + if (got != n) { + throw std::runtime_error("llama_tokenize result mismatch"); + } + return std::vector(tmp.begin(), tmp.end()); +} + +static void write_token_file(const std::string & path, + const std::vector & tokens) { + std::ofstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open for writing: " + path); + } + int32_t n = (int32_t)tokens.size(); + f.write(reinterpret_cast(&n), sizeof(int32_t)); + f.write(reinterpret_cast(tokens.data()), + (std::streamsize)(tokens.size() * sizeof(llama_token))); +} + +// Decode prompt as plain chain batches, return last-token logits (copy). +// The optional per_chunk callback runs after every llama_decode() and is used +// by the DDTree run to ingest exactly the hidden capture columns produced by +// that physical prompt chunk. +static std::vector decode_chain_prompt(llama_context * ctx, + const std::vector & prompt, + int32_t vocab_size, + int32_t prompt_chunk, + const std::function & per_chunk = {}) { + const int32_t n = (int32_t)prompt.size(); + if (prompt_chunk <= 0) { + throw std::runtime_error("prompt_chunk must be > 0"); + } + + std::vector logits; + for (int32_t off = 0; off < n; off += prompt_chunk) { + const int32_t n_cur = std::min(prompt_chunk, n - off); + llama_batch batch = llama_batch_init(n_cur, /*embd=*/0, /*n_seq_max=*/1); + + for (int32_t i = 0; i < n_cur; ++i) { + const int32_t pos = off + i; + batch.token[i] = (llama_token)prompt[pos]; + batch.pos[i] = (llama_pos)pos; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (pos == n - 1) ? 1 : 0; + } + batch.n_tokens = n_cur; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("llama_decode failed on prompt chunk"); + } + + if (per_chunk) { + per_chunk(n_cur); + } + + if (off + n_cur == n) { + const float * row = llama_get_logits_ith(ctx, n_cur - 1); + if (!row) { + llama_batch_free(batch); + throw std::runtime_error("prompt final logits unavailable"); + } + logits.assign(row, row + vocab_size); + } + + llama_batch_free(batch); + } + + if (logits.empty()) { + throw std::runtime_error("prompt decode produced no logits"); + } + return logits; +} + +// Greedy argmax over a logit row. +static llama_token argmax(const float * logits, int32_t vocab_size) { + llama_token best = 0; + float best_val = logits[0]; + for (int32_t v = 1; v < vocab_size; ++v) { + if (logits[v] > best_val) { + best_val = logits[v]; + best = v; + } + } + return best; +} + +// Decode a single token at position pos, return logits for that position. +static std::vector decode_single(llama_context * ctx, + llama_token tok, + llama_pos pos, + int32_t vocab_size) { + llama_batch batch = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1); + batch.token[0] = tok; + batch.pos[0] = pos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + batch.n_tokens = 1; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("llama_decode failed on single token"); + } + + const float * row = llama_get_logits_ith(ctx, 0); + std::vector logits(row, row + vocab_size); + llama_batch_free(batch); + return logits; +} + +// --------------------------------------------------------------------------- +// Run 1: chain reference decode +// --------------------------------------------------------------------------- + +static std::vector run_chain( + llama_model * model, + const llama_context_params & cparams, + const std::vector & prompt, + int32_t gen, + int32_t vocab_size, + int32_t prompt_chunk) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("chain: failed to create target context"); + } + + std::vector out; + out.reserve(gen); + + // Decode prompt; logits for last prompt token give the first generated token. + const auto prompt_t0 = std::chrono::steady_clock::now(); + std::vector logits = decode_chain_prompt(ctx, prompt, vocab_size, prompt_chunk); + const auto prompt_t1 = std::chrono::steady_clock::now(); + + llama_pos pos = (llama_pos)prompt.size(); // next decode position + + double decode_ms = 0.0; + int32_t decode_steps = 0; + for (int32_t i = 0; i < gen; ++i) { + llama_token tok = argmax(logits.data(), vocab_size); + out.push_back(tok); + if (tok == QWEN35_EOS) { + LOG_INF("chain: EOS at step %d\n", i); + break; + } + const auto decode_t0 = std::chrono::steady_clock::now(); + logits = decode_single(ctx, tok, pos, vocab_size); + decode_ms += std::chrono::duration(std::chrono::steady_clock::now() - decode_t0).count(); + decode_steps++; + pos++; + } + + llama_free(ctx); + LOG_INF("chain: generated %d tokens\n", (int)out.size()); + LOG_INF("chain timing detail: prompt=%.2f ms decode_steps=%d decode_avg=%.2f ms decode_total=%.2f ms\n", + std::chrono::duration(prompt_t1 - prompt_t0).count(), + (int)decode_steps, + decode_steps > 0 ? decode_ms / (double)decode_steps : 0.0, + decode_ms); + return out; +} + +// --------------------------------------------------------------------------- +// Run 2: speculative decode +// --------------------------------------------------------------------------- + +static std::vector run_spec( + llama_model * target_model, + llama_model * draft_model, + const llama_context_params & target_cparams, + const llama_context_params & draft_cparams, + const std::vector & prompt, + int32_t gen, + int32_t vocab_size, + const llama_ddtree_params & ddparams, + int32_t prompt_chunk, + llama_speculative_tree_driver_stats * out_stats) { + + // Target context with hidden capture enabled (required by the driver). + llama_context * target_ctx = llama_init_from_model(target_model, target_cparams); + if (!target_ctx) { + throw std::runtime_error("spec: failed to create target context"); + } + llama_set_capture_hidden(target_ctx, true); + + llama_context * draft_ctx = llama_init_from_model(draft_model, draft_cparams); + if (!draft_ctx) { + llama_free(target_ctx); + throw std::runtime_error("spec: failed to create draft context"); + } + + // Init spec driver. + llama_speculative_tree_driver * driver = + llama_speculative_tree_driver_init(target_ctx, draft_ctx, ddparams); + if (!driver) { + llama_free(draft_ctx); + llama_free(target_ctx); + throw std::runtime_error("spec: llama_speculative_tree_driver_init returned NULL"); + } + + // Prime hidden capture in physical prompt chunks and ingest each chunk + // immediately. A single logical 16k decode only leaves the last ubatch in + // the capture tensor, which is not a valid DDTree/DFlash prompt state. + std::vector prompt_logits = + decode_chain_prompt(target_ctx, prompt, vocab_size, prompt_chunk, + [&](int32_t n_cur) { + llama_speculative_tree_driver_ingest_prompt_capture(driver, n_cur); + }); + + // Root token = argmax of last prompt position. + llama_token root_token = argmax(prompt_logits.data(), vocab_size); + llama_pos committed_pos = (llama_pos)prompt.size(); + + std::vector out; + out.reserve(gen); + + bool hit_eos = false; + while ((int32_t)out.size() < gen && !hit_eos) { + std::vector accepted = + llama_speculative_tree_driver_step(driver, root_token, committed_pos); + + if (accepted.empty()) { + // Driver signals terminal condition (e.g. EOS from target). + LOG_INF("spec: driver returned empty accepted list at out_n=%d\n", + (int)out.size()); + break; + } + + // Driver returns [committed_tokens..., bonus]. The bonus is the next + // step's root_token and is NOT yet in the KV cache, so it's not part + // of the committed output and doesn't advance committed_pos. + const int32_t n_committed = (int32_t)accepted.size() - 1; + for (int32_t i = 0; i < n_committed; ++i) { + llama_token t = accepted[i]; + out.push_back(t); + if (t == QWEN35_EOS) { + hit_eos = true; + break; + } + if ((int32_t)out.size() >= gen) { + break; + } + } + + root_token = accepted.back(); // bonus, fed as next step's tree[0] + committed_pos += (llama_pos)n_committed; + } + + if (out_stats != nullptr) { + *out_stats = llama_speculative_tree_driver_get_stats(driver); + } + + llama_speculative_tree_driver_free(driver); + llama_free(draft_ctx); + llama_free(target_ctx); + + LOG_INF("spec: generated %d tokens\n", (int)out.size()); + return out; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string target_model_path; + std::string draft_model_path; + std::string prompt_tokens_path; + std::string prompt_text_path; + std::string out_spec_path; + std::string out_chain_path; + int32_t gen = 32; + int32_t n_gpu_layers = 99; + int32_t n_gpu_layers_draft = -1; + int32_t n_ctx = 4096; + int32_t n_batch_arg = 0; + int32_t n_ubatch_arg = 512; + float temp = 0.0f; + std::string kv_type_str = "f16"; // "f16", "q8_0", or "q4_0" + bool require_ddtree = false; + bool require_replay = false; + bool require_full_prompt_ingest = false; + bool prompt_add_special = false; + bool prompt_parse_special = true; + bool no_flash_attn = false; + + llama_ddtree_params ddparams; // defaults: budget=22, chain_seed=true + // temp is set separately below after arg parsing + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--target-model" && i + 1 < argc) { + target_model_path = argv[++i]; + } else if (arg == "--draft-model" && i + 1 < argc) { + draft_model_path = argv[++i]; + } else if (arg == "--prompt-tokens" && i + 1 < argc) { + prompt_tokens_path = argv[++i]; + } else if (arg == "--prompt-text" && i + 1 < argc) { + prompt_text_path = argv[++i]; + } else if (arg == "--prompt-add-special") { + prompt_add_special = true; + } else if (arg == "--no-prompt-parse-special") { + prompt_parse_special = false; + } else if (arg == "--gen" && i + 1 < argc) { + gen = std::atoi(argv[++i]); + } else if (arg == "--out-spec" && i + 1 < argc) { + out_spec_path = argv[++i]; + } else if (arg == "--out-chain" && i + 1 < argc) { + out_chain_path = argv[++i]; + } else if (arg == "--ddtree-budget" && i + 1 < argc) { + ddparams.budget = std::atoi(argv[++i]); + } else if (arg == "--ddtree-top-k" && i + 1 < argc) { + ddparams.top_k = std::atoi(argv[++i]); + } else if (arg == "--ddtree-no-chain-seed") { + ddparams.chain_seed = false; + } else if (arg == "--require-ddtree") { + require_ddtree = true; + } else if (arg == "--require-replay") { + require_replay = true; + } else if (arg == "--require-full-prompt-ingest") { + require_full_prompt_ingest = true; + } else if (arg == "--temp" && i + 1 < argc) { + temp = std::stof(argv[++i]); + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "--draft-gpu-layers" && i + 1 < argc) { + n_gpu_layers_draft = std::atoi(argv[++i]); + } else if (arg == "--n-ctx" && i + 1 < argc) { + n_ctx = std::atoi(argv[++i]); + } else if (arg == "--n-batch" && i + 1 < argc) { + n_batch_arg = std::atoi(argv[++i]); + } else if (arg == "--n-ubatch" && i + 1 < argc) { + n_ubatch_arg = std::atoi(argv[++i]); + } else if (arg == "--no-flash-attn") { + no_flash_attn = true; + } else if (arg == "--kv-type" && i + 1 < argc) { + kv_type_str = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (target_model_path.empty()) { fprintf(stderr, "--target-model is required\n"); return 1; } + if (draft_model_path.empty()) { fprintf(stderr, "--draft-model is required\n"); return 1; } + if (prompt_tokens_path.empty() && prompt_text_path.empty()) { + fprintf(stderr, "one of --prompt-tokens or --prompt-text is required\n"); + return 1; + } + if (!prompt_tokens_path.empty() && !prompt_text_path.empty()) { + fprintf(stderr, "use only one of --prompt-tokens or --prompt-text\n"); + return 1; + } + if (out_spec_path.empty()) { fprintf(stderr, "--out-spec is required\n"); return 1; } + if (out_chain_path.empty()) { fprintf(stderr, "--out-chain is required\n"); return 1; } + if (gen <= 0) { fprintf(stderr, "--gen must be > 0\n"); return 1; } + + ddparams.temp = temp; + + const bool greedy = (temp == 0.0f); + if (!greedy) { + fprintf(stderr, + "warning: --temp %.4f is non-zero; token-trajectory bit-equal assertion " + "is DISABLED (stochastic sampling makes sequences non-deterministic)\n", + (double)temp); + } + + llama_backend_init(); + + int ret = 1; + + llama_model * target_model = nullptr; + llama_model * draft_model = nullptr; + + try { + // Load target model. + { + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + target_model = llama_model_load_from_file(target_model_path.c_str(), mparams); + if (!target_model) { + throw std::runtime_error("failed to load target model: " + target_model_path); + } + } + + // Load draft model. + { + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers_draft >= 0 ? n_gpu_layers_draft : n_gpu_layers; + mparams.target_model = target_model; + draft_model = llama_model_load_from_file(draft_model_path.c_str(), mparams); + if (!draft_model) { + throw std::runtime_error("failed to load draft model: " + draft_model_path); + } + } + + const auto * vocab = llama_model_get_vocab(target_model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + std::vector prompt; + if (!prompt_tokens_path.empty()) { + prompt = read_int32_file(prompt_tokens_path); + } else { + const std::string prompt_text = read_text_file(prompt_text_path); + prompt = tokenize_text(vocab, prompt_text, prompt_add_special, prompt_parse_special); + } + if (prompt.empty()) { + throw std::runtime_error("prompt is empty after loading/tokenization"); + } + LOG_INF("prompt: %d tokens\n", (int)prompt.size()); + + // Context params shared by both target contexts (chain and spec runs). + const uint32_t n_batch = (uint32_t)(n_batch_arg > 0 ? n_batch_arg : std::min(n_ctx, 2048)); + const uint32_t n_ubatch = (uint32_t)(n_ubatch_arg > 0 ? n_ubatch_arg : 512); + ggml_type kv_type = GGML_TYPE_F16; + if (kv_type_str == "f16") kv_type = GGML_TYPE_F16; + else if (kv_type_str == "q8_0") kv_type = GGML_TYPE_Q8_0; + else if (kv_type_str == "q4_0") kv_type = GGML_TYPE_Q4_0; + else { fprintf(stderr, "unknown --kv-type: %s\n", kv_type_str.c_str()); return 1; } + auto target_cparams = llama_context_default_params(); + target_cparams.n_ctx = (uint32_t)n_ctx; + target_cparams.n_batch = n_batch; + target_cparams.n_ubatch = std::min(n_batch, n_ubatch); + target_cparams.type_k = kv_type; + target_cparams.type_v = kv_type; + if (no_flash_attn) { + target_cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } + const int32_t prompt_chunk = (int32_t)target_cparams.n_ubatch; + + // Draft context: dflash-draft doesn't keep a prompt KV cache; it consumes + // KV slots only for spec block decode (pos = committed_pos+i). A short + // ctx sized to prompt+gen+budget margin is sufficient and avoids the + // compute-buffer blow-up that target n_ctx would otherwise impose. + const uint32_t draft_n_ctx = (uint32_t)std::min( + (int32_t)4096, + std::max((int32_t)prompt.size() + gen + ddparams.budget + 64, (int32_t)1024)); + auto draft_cparams = llama_context_default_params(); + draft_cparams.n_ctx = draft_n_ctx; + draft_cparams.n_batch = std::min(draft_n_ctx, (uint32_t)64); + draft_cparams.n_ubatch = std::min(draft_cparams.n_batch, n_ubatch); + if (no_flash_attn) { + draft_cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } + + // --------------------------------------------------------------- + // Run 1: chain reference + // --------------------------------------------------------------- + LOG_INF("=== Run 1: chain reference decode ===\n"); + const auto chain_t0 = std::chrono::steady_clock::now(); + std::vector chain_tokens = + run_chain(target_model, target_cparams, prompt, gen, vocab_size, prompt_chunk); + const auto chain_t1 = std::chrono::steady_clock::now(); + + write_token_file(out_chain_path, chain_tokens); + LOG_INF("chain: wrote %d tokens to %s\n", + (int)chain_tokens.size(), out_chain_path.c_str()); + LOG_INF("chain timing: %.3f sec\n", + std::chrono::duration(chain_t1 - chain_t0).count()); + + // --------------------------------------------------------------- + // Run 2: speculative decode + // --------------------------------------------------------------- + LOG_INF("=== Run 2: speculative decode ===\n"); + const auto spec_t0 = std::chrono::steady_clock::now(); + llama_speculative_tree_driver_stats spec_stats; + std::vector spec_tokens = + run_spec(target_model, draft_model, + target_cparams, draft_cparams, + prompt, gen, vocab_size, ddparams, prompt_chunk, &spec_stats); + const auto spec_t1 = std::chrono::steady_clock::now(); + + LOG_INF("spec stats: steps=%lld tree_verifies=%lld tree_nodes_total=%lld max_tree_nodes=%d dfs_last=%lld snapshot_replays=%lld fast_batched_replays=%lld fast_batched_cb=%lld fast_rollback=%lld committed=%lld max_commit=%d batched_committed=%lld batched_max_commit=%d batched_exact_same=%lld batched_exact_diff=%lld batched_longer=%lld batched_shorter=%lld prompt_ingests=%lld prompt_tokens=%lld tree_tokens=%lld replay_tokens=%lld capture_clamps=%lld\n", + (long long)spec_stats.n_steps, + (long long)spec_stats.n_tree_verifies, + (long long)spec_stats.n_tree_nodes_total, + (int)spec_stats.max_tree_nodes, + (long long)spec_stats.n_dfs_last_commits, + (long long)spec_stats.n_snapshot_replays, + (long long)spec_stats.n_fast_batched_replays, + (long long)spec_stats.n_fast_batched_callback_steps, + (long long)spec_stats.n_fast_rollback_steps, + (long long)spec_stats.n_committed_tokens, + (int)spec_stats.max_committed_tokens_per_step, + (long long)spec_stats.n_batched_posterior_committed_tokens, + (int)spec_stats.max_batched_posterior_committed_tokens_per_step, + (long long)spec_stats.n_batched_exact_same, + (long long)spec_stats.n_batched_exact_diff, + (long long)spec_stats.n_batched_exact_longer, + (long long)spec_stats.n_batched_exact_shorter, + (long long)spec_stats.n_prompt_ingest_calls, + (long long)spec_stats.n_prompt_ingested_tokens, + (long long)spec_stats.n_tree_ingested_tokens, + (long long)spec_stats.n_replay_ingested_tokens, + (long long)spec_stats.n_capture_clamps); + if (spec_stats.n_steps > 0) { + LOG_INF("spec acceptance: exact_avg_commit_per_step=%.3f batched_avg_commit_per_step=%.3f\n", + (double)spec_stats.n_committed_tokens / (double)spec_stats.n_steps, + (double)spec_stats.n_batched_posterior_committed_tokens / (double)spec_stats.n_steps); + const double inv_steps = 1.0 / (double)spec_stats.n_steps; + LOG_INF("spec timing avg: step=%.2f ms pack=%.2f draft=%.2f topk=%.2f build=%.2f snap=%.2f target_tree=%.2f posterior=%.2f accept=%.2f compact=%.2f rollback=%.2f ingest=%.2f tree_ingest=%.2f replay_ingest=%.2f replay=%.2f exact=%.2f\n", + spec_stats.t_step_ms * inv_steps, + spec_stats.t_target_feat_pack_ms * inv_steps, + spec_stats.t_draft_decode_ms * inv_steps, + spec_stats.t_topk_ms * inv_steps, + spec_stats.t_build_tree_ms * inv_steps, + spec_stats.t_snapshot_ms * inv_steps, + spec_stats.t_target_tree_decode_ms * inv_steps, + spec_stats.t_posterior_scan_ms * inv_steps, + spec_stats.t_accept_path_ms * inv_steps, + spec_stats.t_kv_compact_ms * inv_steps, + spec_stats.t_ssm_rollback_ms * inv_steps, + spec_stats.t_ingest_capture_ms * inv_steps, + spec_stats.t_tree_ingest_ms * inv_steps, + spec_stats.t_replay_ingest_ms * inv_steps, + spec_stats.t_replay_ms * inv_steps, + spec_stats.t_exact_validate_ms * inv_steps); + LOG_INF("spec timing total: prompt_ingest=%.2f ms tree_ingest=%.2f ms replay_ingest=%.2f ms\n", + spec_stats.t_prompt_ingest_ms, + spec_stats.t_tree_ingest_ms, + spec_stats.t_replay_ingest_ms); + } + LOG_INF("spec timing: %.3f sec\n", + std::chrono::duration(spec_t1 - spec_t0).count()); + + if (require_ddtree && (spec_stats.n_tree_verifies <= 0 || spec_stats.max_tree_nodes <= 1)) { + throw std::runtime_error("--require-ddtree failed: no multi-node DDTree verify observed"); + } + if (require_replay && spec_stats.n_snapshot_replays <= 0) { + throw std::runtime_error("--require-replay failed: snapshot+replay fallback was not exercised"); + } + if (require_full_prompt_ingest && + (spec_stats.n_capture_clamps != 0 || + spec_stats.n_prompt_ingested_tokens != (int64_t)prompt.size())) { + throw std::runtime_error("--require-full-prompt-ingest failed: prompt hidden capture was incomplete"); + } + + // Truncate to gen if the driver produced more tokens than requested. + if ((int32_t)spec_tokens.size() > gen) { + spec_tokens.resize(gen); + } + + write_token_file(out_spec_path, spec_tokens); + LOG_INF("spec: wrote %d tokens to %s\n", + (int)spec_tokens.size(), out_spec_path.c_str()); + + // --------------------------------------------------------------- + // Compare trajectories + // --------------------------------------------------------------- + const int32_t chain_n = (int32_t)chain_tokens.size(); + const int32_t spec_n = (int32_t)spec_tokens.size(); + const int32_t cmp_n = std::min(chain_n, spec_n); + + int32_t first_divergence = -1; + int32_t match_count = 0; + for (int32_t k = 0; k < cmp_n; ++k) { + if (chain_tokens[k] == spec_tokens[k]) { + match_count++; + } else if (first_divergence < 0) { + first_divergence = k; + break; + } + } + + if (first_divergence < 0 && match_count == cmp_n) { + // All positions matched. + printf("chain_n=%d spec_n=%d first_divergence=none bytes_match=%d/%d\n", + chain_n, spec_n, match_count, cmp_n); + } else { + printf("chain_n=%d spec_n=%d first_divergence=%d bytes_match=%d/%d\n", + chain_n, spec_n, first_divergence, match_count, cmp_n); + } + + if (greedy) { + if (first_divergence >= 0) { + fprintf(stderr, + "FAIL: token-trajectory divergence at position %d " + "(greedy decoding MUST produce bit-equal sequences)\n" + " chain[%d] = %d\n" + " spec[%d] = %d\n", + first_divergence, + first_divergence, (int)chain_tokens[first_divergence], + first_divergence, (int)spec_tokens[first_divergence]); + ret = 1; + } else { + LOG_INF("PASS: all %d token positions are bit-equal\n", cmp_n); + ret = 0; + } + } else { + // Non-greedy: no hard assertion, just report. + LOG_INF("non-greedy mode: token-trajectory comparison is informational only\n"); + ret = 0; + } + + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + if (draft_model) { llama_model_free(draft_model); } + if (target_model) { llama_model_free(target_model); } + llama_backend_free(); + return ret; +} diff --git a/tests/test-speculative-tree.cpp b/tests/test-speculative-tree.cpp new file mode 100644 index 000000000000..61644447cc15 --- /dev/null +++ b/tests/test-speculative-tree.cpp @@ -0,0 +1,336 @@ +// Standalone unit tests for speculative-tree.{h,cpp}. +// No model or GPU required. Uses hand-computed fixtures. +// Exits non-zero on first failure. + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include "speculative-tree.h" + +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Test 1: build_ddtree small case +// +// L=2, K=2, budget=4 (total nodes including root), chain_seed=true, root=10. +// +// top_log_probs[L*K]: +// position 0 (depth 1): [-0.1, -1.5] tokens [20, 21] +// position 1 (depth 2): [-0.2, -1.6] tokens [30, 31] +// +// Chain seeding (chain_depth = min(2, budget-1=3) = 2): +// d=1: insert tok=20 as node 1 (parent=0, depth=1), cum_logw=-0.1 +// push sibling: logw=-1.5, parent=0, depth=1, rank=1, tok=21 +// d=2: insert tok=30 as node 2 (parent=1, depth=2), cum_logw=-0.3 +// push sibling: logw=-0.3-(-0.2)+(-1.6)=-1.7, parent=1, depth=2, rank=1, tok=31 +// +// Heap after chain: {logw=-1.5,tok=21} and {logw=-1.7,tok=31} +// Pop best: logw=-1.5 → tok=21 inserted as node 3 (parent=0, depth=1). Done. +// +// Expected tree nodes: +// [0] root(10), parent=-1, depth=0 +// [1] tok=20, parent=0, depth=1 +// [2] tok=30, parent=1, depth=2 +// [3] tok=21, parent=0, depth=1 +// --------------------------------------------------------------------------- +static void test_build_ddtree_small() { + const int L = 2, K = 2; + const float top_log_probs[] = { + -0.1f, -1.5f, // position 0 + -0.2f, -1.6f, // position 1 + }; + const int32_t top_token_ids[] = { + 20, 21, // position 0 + 30, 31, // position 1 + }; + + llama_ddtree_params p; + p.budget = 4; // total node cap including root + p.chain_seed = true; + p.temp = 1.0f; + + const llama_ddtree tree = build_ddtree( + top_log_probs, top_token_ids, L, K, /*root_token*/ 10, p); + + assert(tree.nodes.size() == 4); + + // Node 0: root + assert(tree.nodes[0].token_id == 10); + assert(tree.nodes[0].parent_idx == -1); + assert(tree.nodes[0].depth == 0); + + // Node 1: tok=20, depth-1 chain top-1 + assert(tree.nodes[1].token_id == 20); + assert(tree.nodes[1].parent_idx == 0); + assert(tree.nodes[1].depth == 1); + + // Node 2: tok=30, depth-2 chain top-1 + assert(tree.nodes[2].token_id == 30); + assert(tree.nodes[2].parent_idx == 1); + assert(tree.nodes[2].depth == 2); + + // Node 3: tok=21, best heap candidate (sibling of 20 at depth 1) + assert(tree.nodes[3].token_id == 21); + assert(tree.nodes[3].parent_idx == 0); + assert(tree.nodes[3].depth == 1); + + // Visibility: 4x4 mask, row i has 1 at all ancestors of i (inclusive). + // node 0 ancestors: {0} + // node 1 ancestors: {0, 1} + // node 2 ancestors: {0, 1, 2} + // node 3 ancestors: {0, 3} + const int N = 4; + assert(tree.visibility.size() == (size_t)(N * N)); + // row 0 + assert(tree.visibility[0*N+0] == 1); + assert(tree.visibility[0*N+1] == 0); + assert(tree.visibility[0*N+2] == 0); + assert(tree.visibility[0*N+3] == 0); + // row 1 + assert(tree.visibility[1*N+0] == 1); + assert(tree.visibility[1*N+1] == 1); + assert(tree.visibility[1*N+2] == 0); + assert(tree.visibility[1*N+3] == 0); + // row 2 + assert(tree.visibility[2*N+0] == 1); + assert(tree.visibility[2*N+1] == 1); + assert(tree.visibility[2*N+2] == 1); + assert(tree.visibility[2*N+3] == 0); + // row 3 + assert(tree.visibility[3*N+0] == 1); + assert(tree.visibility[3*N+1] == 0); + assert(tree.visibility[3*N+2] == 0); + assert(tree.visibility[3*N+3] == 1); + + printf("PASS: test_build_ddtree_small\n"); +} + +// --------------------------------------------------------------------------- +// Test 2: follow_verified_tree +// +// Uses the same 4-node tree from test 1. +// Tree structure: +// node 0: root(10), no parent +// node 1: tok=20, parent=0 +// node 2: tok=30, parent=1 +// node 3: tok=21, parent=0 +// +// child_maps derived from parent_idx: +// node 0 children: {20→1, 21→3} +// node 1 children: {30→2} +// node 2 children: {} +// node 3 children: {} +// +// Semantic: posterior[i] is the target model's argmax prediction at node i. +// The walk starts at node 0; at each step it looks for a child whose +// token_id matches posterior[current]. If found, advance; otherwise stop. +// accepted = [visited indices]; next_token = posterior[deepest accepted]. +// +// Case A: posterior = [20, 30, 99, 99] +// Start at 0. posterior[0]=20 → child 1 exists. Move to 1. +// posterior[1]=30 → child 2 exists. Move to 2. +// posterior[2]=99 → no child. Stop. +// accepted=[0,1,2], next_token=99. +// +// Case B: posterior = [21, 99, 99, 99] +// Start at 0. posterior[0]=21 → child 3 exists. Move to 3. +// posterior[3]=99 → no child. Stop. +// accepted=[0,3], next_token=99. +// +// Case C: posterior = [5, 99, 99, 99] +// Start at 0. posterior[0]=5 → no child. Stop immediately. +// accepted=[0], next_token=5. +// --------------------------------------------------------------------------- +static void test_follow_verified_tree() { + // Build the same 4-node tree via build_ddtree. + const int L = 2, K = 2; + const float top_log_probs[] = { -0.1f, -1.5f, -0.2f, -1.6f }; + const int32_t top_token_ids[] = { 20, 21, 30, 31 }; + + llama_ddtree_params p; + p.budget = 4; + p.chain_seed = true; + p.temp = 1.0f; + + const llama_ddtree tree = build_ddtree( + top_log_probs, top_token_ids, L, K, 10, p); + + std::vector accepted; + llama_token next_tok = -1; + + // Case A: greedy chain match + { + const int32_t posterior[] = { 20, 30, 99, 99 }; + follow_verified_tree(tree, posterior, accepted, next_tok); + assert(accepted.size() == 3); + assert(accepted[0] == 0); + assert(accepted[1] == 1); + assert(accepted[2] == 2); + assert(next_tok == 99); + } + + // Case B: branch match (tok=21 path) + { + const int32_t posterior[] = { 21, 99, 99, 99 }; + follow_verified_tree(tree, posterior, accepted, next_tok); + assert(accepted.size() == 2); + assert(accepted[0] == 0); + assert(accepted[1] == 3); + assert(next_tok == 99); + } + + // Case C: no match at root level — only root accepted + { + const int32_t posterior[] = { 5, 99, 99, 99 }; + follow_verified_tree(tree, posterior, accepted, next_tok); + assert(accepted.size() == 1); + assert(accepted[0] == 0); + assert(next_tok == 5); + } + + printf("PASS: test_follow_verified_tree\n"); +} + +// --------------------------------------------------------------------------- +// Test 3: build_tree_visibility — hand-constructed 5-node tree +// +// Manually define a tree: +// node 0: root, parent=-1 +// node 1: child of 0, parent=0 +// node 2: child of 1, parent=1 +// node 3: child of 0, parent=0 +// node 4: child of 3, parent=3 +// +// Expected visibility (5x5): +// row 0: {0} → [1,0,0,0,0] +// row 1: {0,1} → [1,1,0,0,0] +// row 2: {0,1,2} → [1,1,1,0,0] +// row 3: {0,3} → [1,0,0,1,0] +// row 4: {0,3,4} → [1,0,0,1,1] +// --------------------------------------------------------------------------- +static void test_build_tree_visibility() { + std::vector nodes = { + { 10, -1, 0 }, // 0: root + { 20, 0, 1 }, // 1: child of 0 + { 30, 1, 2 }, // 2: child of 1 + { 40, 0, 1 }, // 3: child of 0 + { 50, 3, 2 }, // 4: child of 3 + }; + + const int N = (int)nodes.size(); + std::vector vis(N * N, 0); + build_tree_visibility(nodes, vis.data()); + + // Row 0 + assert(vis[0*N+0] == 1); assert(vis[0*N+1] == 0); + assert(vis[0*N+2] == 0); assert(vis[0*N+3] == 0); assert(vis[0*N+4] == 0); + // Row 1 + assert(vis[1*N+0] == 1); assert(vis[1*N+1] == 1); + assert(vis[1*N+2] == 0); assert(vis[1*N+3] == 0); assert(vis[1*N+4] == 0); + // Row 2 + assert(vis[2*N+0] == 1); assert(vis[2*N+1] == 1); + assert(vis[2*N+2] == 1); assert(vis[2*N+3] == 0); assert(vis[2*N+4] == 0); + // Row 3 + assert(vis[3*N+0] == 1); assert(vis[3*N+1] == 0); + assert(vis[3*N+2] == 0); assert(vis[3*N+3] == 1); assert(vis[3*N+4] == 0); + // Row 4 + assert(vis[4*N+0] == 1); assert(vis[4*N+1] == 0); + assert(vis[4*N+2] == 0); assert(vis[4*N+3] == 1); assert(vis[4*N+4] == 1); + + printf("PASS: test_build_tree_visibility\n"); +} + +// --------------------------------------------------------------------------- +// Test 4: extract_top_k_logprobs +// +// Feed a [3, 8] logits matrix with known values at temp=1.0. +// K=3. Verify output ordering (descending log-prob) and values to 1e-5. +// +// Row 0: logits = [0,1,2,3,4,5,6,7] (argmax = id=7) +// Row 1: logits = [7,6,5,4,3,2,1,0] (argmax = id=0) +// Row 2: logits = [0,0,0,0,10,0,0,0] (argmax = id=4, dominant) +// +// For row 0: log_z = logsumexp([0,1,2,3,4,5,6,7]) +// log_z = 7 + log(sum of exp(k-7) for k=0..7) = 7 + log(exp(-7)+...+exp(0)) +// Top 3 by logit: ids [7,6,5], logprobs = [7-log_z, 6-log_z, 5-log_z] +// +// We compute expected values in the test itself using std::log and std::exp. +// --------------------------------------------------------------------------- +static float logsumexp_vec(const float * v, int n) { + float mx = v[0]; + for (int i = 1; i < n; i++) if (v[i] > mx) mx = v[i]; + float s = 0.0f; + for (int i = 0; i < n; i++) s += std::exp(v[i] - mx); + return mx + std::log(s); +} + +static void test_extract_top_k_logprobs() { + const int L = 3, V = 8, K = 3; + + const float logits[L * V] = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, // row 0 + 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, // row 1 + 0.0f, 0.0f, 0.0f, 0.0f, 10.0f, 0.0f, 0.0f, 0.0f, // row 2 + }; + + std::vector out_lp(L * K); + std::vector out_id(L * K); + + extract_top_k_logprobs(logits, L, V, K, 1.0f, + out_lp.data(), out_id.data()); + + // Check row 0: top-3 tokens by logit are ids 7, 6, 5. + assert(out_id[0*K+0] == 7); + assert(out_id[0*K+1] == 6); + assert(out_id[0*K+2] == 5); + // Check log-prob values against manual logsumexp. + { + const float log_z = logsumexp_vec(logits + 0*V, V); + assert(std::fabs(out_lp[0*K+0] - (7.0f - log_z)) < 1e-5f); + assert(std::fabs(out_lp[0*K+1] - (6.0f - log_z)) < 1e-5f); + assert(std::fabs(out_lp[0*K+2] - (5.0f - log_z)) < 1e-5f); + } + + // Check row 1: top-3 tokens are ids 0, 1, 2. + assert(out_id[1*K+0] == 0); + assert(out_id[1*K+1] == 1); + assert(out_id[1*K+2] == 2); + { + const float log_z = logsumexp_vec(logits + 1*V, V); + assert(std::fabs(out_lp[1*K+0] - (7.0f - log_z)) < 1e-5f); + assert(std::fabs(out_lp[1*K+1] - (6.0f - log_z)) < 1e-5f); + assert(std::fabs(out_lp[1*K+2] - (5.0f - log_z)) < 1e-5f); + } + + // Check row 2: id=4 dominates with logit=10. + assert(out_id[2*K+0] == 4); + { + const float log_z = logsumexp_vec(logits + 2*V, V); + assert(std::fabs(out_lp[2*K+0] - (10.0f - log_z)) < 1e-5f); + } + + // Verify descending order within each row. + for (int row = 0; row < L; row++) { + for (int k = 0; k < K - 1; k++) { + assert(out_lp[row*K+k] >= out_lp[row*K+k+1]); + } + } + + printf("PASS: test_extract_top_k_logprobs\n"); +} + +// --------------------------------------------------------------------------- + +int main() { + test_build_ddtree_small(); + test_follow_verified_tree(); + test_build_tree_visibility(); + test_extract_top_k_logprobs(); + printf("All tests passed.\n"); + return 0; +} diff --git a/tools/server/README.md b/tools/server/README.md index b30309bf3b06..ee5aaa6b5e0b 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1825,3 +1825,39 @@ You can use html formatting if needed. ``` + +## Speculative decoding: DDTree (dflash) + +DDTree is a tree-structured speculative decoding method that uses a compact dflash-draft companion model to propose multiple token paths in parallel, then verifies them against the target model in a single forward pass. It typically yields 5-10x accepted tokens per target forward pass compared to 1 for autoregressive decoding. + +**Required flags:** + +- `-m ` — target model (Qwen3.5-27B Q4_K_M or similar) +- `-md ` — dflash-draft GGUF (arch `LLM_ARCH_DFLASH_DRAFT`) +- `--speculative-mode ddtree` + +**Optional flags:** + +- `--ddtree-budget N` — tree node budget per spec step (default: 22) +- `--ddtree-temp F` — temperature for draft log-prob extraction (default: 1.0) +- `--ddtree-no-chain-seed` — disable greedy chain seed for the tree heap + +**Constraints (Phase 5):** + +- `--parallel 1` only — multi-slot DDTree is out of scope for Phase 5 +- Target must be Qwen3.5-27B; draft must be the matching `dflash-draft` GGUF +- Greedy verification only — DDTree's accept decision is argmax-based; temperature and top-p affect only draft log-prob extraction, not acceptance +- Known limitation: SSM conv state may diverge after ~17 tokens per spec step boundary; full bit-equal awaits `ggml_ssm_conv_tree_persist` op (follow-up) + +**Example command:** + +```bash +llama-server \ + -m models/Qwen3.5-27B-Q4_K_M.gguf \ + -md models/draft/model.gguf \ + --speculative-mode ddtree --ddtree-budget 22 \ + -ctk tq3_0 -ctv tq3_0 \ + --port 8002 -ngl 99 -c 16384 +``` + +Compatible with `--api-key`, `--chat-template`, `--jinja`, and all standard server flags. diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e134b3cfb260..9881981b8afa 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -9,6 +9,7 @@ #include "log.h" #include "sampling.h" #include "speculative.h" +#include "speculative-tree-driver.h" #include "mtmd.h" #include "mtmd-helper.h" @@ -58,6 +59,11 @@ struct server_slot { common_speculative * spec = nullptr; + // DDTree speculative decoding state (Phase 5); null when ddtree_mode is off + llama_speculative_tree_driver * spec_driver = nullptr; + llama_token ddtree_root_tok = LLAMA_TOKEN_NULL; // bonus token from prev step / first sampled + llama_pos ddtree_committed_pos = 0; // KV positions committed so far + // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 std::unique_ptr task; @@ -187,6 +193,14 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + // free DDTree driver if one was created for this request + if (spec_driver) { + llama_speculative_tree_driver_free(spec_driver); + spec_driver = nullptr; + } + ddtree_root_tok = LLAMA_TOKEN_NULL; + ddtree_committed_pos = 0; + task_prev = std::move(task); task.reset(); @@ -399,6 +413,56 @@ struct server_slot { ); } + if (spec_driver) { + const llama_speculative_tree_driver_stats st = + llama_speculative_tree_driver_get_stats(spec_driver); + if (st.n_steps > 0) { + SLT_CNT(*this, + "ddtree stats: steps=%lld exact_avg_commit=%0.3f batched_avg_commit=%0.3f exact_max=%d batched_max=%d snapshot_replays=%lld fast_batched_replays=%lld fast_batched_cb=%lld fast_rollback=%lld batched_exact_diff=%lld batched_longer=%lld batched_shorter=%lld capture_clamps=%lld\n", + (long long)st.n_steps, + (double)st.n_committed_tokens / (double)st.n_steps, + (double)st.n_batched_posterior_committed_tokens / (double)st.n_steps, + (int)st.max_committed_tokens_per_step, + (int)st.max_batched_posterior_committed_tokens_per_step, + (long long)st.n_snapshot_replays, + (long long)st.n_fast_batched_replays, + (long long)st.n_fast_batched_callback_steps, + (long long)st.n_fast_rollback_steps, + (long long)st.n_batched_exact_diff, + (long long)st.n_batched_exact_longer, + (long long)st.n_batched_exact_shorter, + (long long)st.n_capture_clamps); + const double inv_steps = 1.0 / (double)st.n_steps; + SLT_CNT(*this, + "ddtree timing avg: step=%0.2f ms pack=%0.2f draft=%0.2f topk=%0.2f build=%0.2f snap=%0.2f target_tree=%0.2f posterior=%0.2f accept=%0.2f compact=%0.2f rollback=%0.2f ingest=%0.2f tree_ingest=%0.2f replay_ingest=%0.2f replay=%0.2f exact=%0.2f exact_decode=%0.2f exact_sample=%0.2f exact_advance=%0.2f exact_nodes=%0.2f\n", + st.t_step_ms * inv_steps, + st.t_target_feat_pack_ms * inv_steps, + st.t_draft_decode_ms * inv_steps, + st.t_topk_ms * inv_steps, + st.t_build_tree_ms * inv_steps, + st.t_snapshot_ms * inv_steps, + st.t_target_tree_decode_ms * inv_steps, + st.t_posterior_scan_ms * inv_steps, + st.t_accept_path_ms * inv_steps, + st.t_kv_compact_ms * inv_steps, + st.t_ssm_rollback_ms * inv_steps, + st.t_ingest_capture_ms * inv_steps, + st.t_tree_ingest_ms * inv_steps, + st.t_replay_ingest_ms * inv_steps, + st.t_replay_ms * inv_steps, + st.t_exact_validate_ms * inv_steps, + st.t_exact_decode_ms * inv_steps, + st.t_exact_sample_ms * inv_steps, + st.t_exact_advance_ms * inv_steps, + (double)st.n_exact_validate_nodes * inv_steps); + SLT_CNT(*this, + "ddtree timing total: prompt_ingest=%0.2f ms tree_ingest=%0.2f ms replay_ingest=%0.2f ms\n", + st.t_prompt_ingest_ms, + st.t_tree_ingest_ms, + st.t_replay_ingest_ms); + } + } + common_speculative_print_stats(spec); } @@ -563,6 +627,11 @@ struct server_context_impl { llama_model_ptr model_dft; + // DDTree draft context — separate from the chain-mode draft since it needs + // different n_ctx / n_batch sizing (small, fixed to draft block_size). + // Null when ddtree_mode is off. + llama_context * ctx_ddtree_dft = nullptr; + bool add_bos_token = true; int32_t n_ctx; // total context for all clients / slots @@ -600,6 +669,16 @@ struct server_context_impl { for (server_slot & slot : slots) { common_speculative_free(slot.spec); slot.spec = nullptr; + + if (slot.spec_driver) { + llama_speculative_tree_driver_free(slot.spec_driver); + slot.spec_driver = nullptr; + } + } + + if (ctx_ddtree_dft) { + llama_free(ctx_ddtree_dft); + ctx_ddtree_dft = nullptr; } llama_batch_free(batch); @@ -682,6 +761,9 @@ struct server_context_impl { params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; auto mparams_dft = common_model_params_to_llama(params_dft); + if (params_base.speculative.ddtree_mode) { + mparams_dft.target_model = model; + } model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); if (model_dft == nullptr) { @@ -691,6 +773,32 @@ struct server_context_impl { params_base.speculative.model_dft = model_dft.get(); params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + + // DDTree mode: create a dedicated draft context with the sizing the + // dflash-draft model expects (small n_ctx, small n_batch = block_size). + if (params_base.speculative.ddtree_mode) { + // Phase 5: single-slot only — enforce this up front. + if (params_base.n_parallel > 1) { + fprintf(stderr, "DDTree mode supports only --parallel 1 in Phase 5\n"); + return false; + } + + llama_context_params cparams_ddft = llama_context_default_params(); + cparams_ddft.n_ctx = 2048 + 16; // DRAFT_CTX_MAX + block_size + cparams_ddft.n_batch = 16; // one block per decode + + ctx_ddtree_dft = llama_init_from_model(model_dft.get(), cparams_ddft); + if (!ctx_ddtree_dft) { + SRV_ERR("%s", "failed to create DDTree draft context\n"); + return false; + } + SRV_INF("%s", "DDTree draft context initialized\n"); + + // Enable hidden capture on the target context so the driver can + // read intermediate layer features for tree scoring. + llama_set_capture_hidden(ctx, true); + SRV_INF("%s", "DDTree: hidden capture enabled on target context\n"); + } } std::string & mmproj_path = params_base.mmproj.path; @@ -1214,6 +1322,14 @@ struct server_context_impl { slot.task = std::make_unique(std::move(task)); + if (params_base.speculative.ddtree_mode && !slot.task->is_child()) { + const std::string prompt_text = slot.task->tokens.detokenize(ctx, true); + SLT_INF(slot, "DDTree request prompt: tokens = %d, chars = %zu\n", + slot.task->n_tokens(), prompt_text.size()); + SLT_INF(slot, "DDTree request prompt begin\n%s\nDDTree request prompt end\n", + prompt_text.c_str()); + } + slot.state = slot.task->is_child() ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt : SLOT_STATE_STARTED; @@ -2091,6 +2207,15 @@ struct server_context_impl { continue; } + // DDTree slots drive their own batch submissions internally via the driver. + // Skip the normal token-addition and batch-decode path for them. + if (params_base.speculative.ddtree_mode && slot.spec_driver) { + if (!slot_batched) { + slot_batched = &slot; + } + continue; + } + // check if we can batch this slot with the previous one if (!slot_batched) { slot_batched = &slot; @@ -2155,9 +2280,14 @@ struct server_context_impl { } } - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); + // process in chunks of params.n_batch. In DDTree mode hidden capture is + // only retained for the physical ubatch produced by llama_decode(), so + // prompt prefill must be submitted in n_ubatch-sized chunks to keep the + // driver's target feature ring complete. int32_t n_ubatch = llama_n_ubatch(ctx); + const int32_t n_batch_default = llama_n_batch(ctx); + const int32_t n_batch_prompt = params_base.speculative.ddtree_mode ? n_ubatch : n_batch_default; + int32_t n_batch = n_batch_prompt; float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2333,6 +2463,19 @@ struct server_context_impl { SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); } + + if (params_base.speculative.ddtree_mode && n_past > 0) { + const int32_t ddtree_rebuild_nt = llama_speculative_tree_driver_context_window(); + const int32_t n_rebuild = std::min(ddtree_rebuild_nt, slot.task->n_tokens()); + const int32_t n_past_max = std::max(0, slot.task->n_tokens() - n_rebuild); + + if (n_past > n_past_max) { + SLT_WRN(slot, + "DDTree prompt cache reuse capped from %d to %d to rebuild the last %d target-feature tokens\n", + n_past, n_past_max, n_rebuild); + n_past = n_past_max; + } + } } else { // if we don't cache the prompt, we have to remove all previous tokens n_past = 0; @@ -2590,6 +2733,12 @@ struct server_context_impl { break; } } + if (params_base.speculative.ddtree_mode) { + const int32_t ddtree_rebuild_nt = llama_speculative_tree_driver_context_window(); + if (slot.task->n_tokens() == slot.prompt.n_tokens() + ddtree_rebuild_nt) { + should_break = true; + } + } if (should_break) { break; } @@ -2598,6 +2747,9 @@ struct server_context_impl { // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const bool ddtree_rebuild_checkpoint = + params_base.speculative.ddtree_mode && + slot.task->n_tokens() == slot.prompt.n_tokens() + llama_speculative_tree_driver_context_window(); // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { @@ -2614,7 +2766,11 @@ struct server_context_impl { slot.init_sampler(); SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); } else { - if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) { + if (ddtree_rebuild_checkpoint) { + do_checkpoint = do_checkpoint && true; + SLT_INF(slot, "creating DDTree rebuild checkpoint before the last %d prompt tokens at position %d\n", + llama_speculative_tree_driver_context_window(), slot.prompt.n_tokens()); + } else if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) { // near the end of the prompt do_checkpoint = do_checkpoint && true; } else { @@ -2713,10 +2869,31 @@ struct server_context_impl { } if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); + // DDTree slots don't put tokens in the main batch (the driver handles its + // own tree-mode decodes after the main loop). When ddtree_mode is on, the + // main batch can legitimately be empty for several consecutive ticks while + // slots transition through DONE_PROMPT → GENERATING or wait for the next + // request — don't treat that as a hung scheduler. + bool ddtree_active = false; + if (params_base.speculative.ddtree_mode) { + for (const auto & slot : slots) { + if (slot.spec_driver != nullptr || + slot.state == SLOT_STATE_PROCESSING_PROMPT || + slot.state == SLOT_STATE_DONE_PROMPT || + slot.state == SLOT_STATE_STARTED || + slot.state == SLOT_STATE_GENERATING) { + ddtree_active = true; + break; + } + } + } + + if (!ddtree_active) { + SRV_WRN("%s", "no tokens to decode\n"); - if (++n_empty_consecutive > 3) { - GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); + if (++n_empty_consecutive > 3) { + GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); + } } } else { n_empty_consecutive = 0; @@ -2795,7 +2972,36 @@ struct server_context_impl { i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = n_batch_prompt; + + // DDTree: incrementally ingest the just-decoded ubatch's hidden capture + // into each prompt-processing slot's ring buffer. The capture buffer is + // overwritten on every llama_decode, so we MUST consume it before the + // next inner-loop iteration. Phase 5 is single-slot, so the entire + // batch_view belongs to one slot. + if (params_base.speculative.ddtree_mode && ctx_ddtree_dft) { + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_PROCESSING_PROMPT && + slot.state != SLOT_STATE_DONE_PROMPT) { + continue; + } + if (slot.spec_driver == nullptr) { + llama_ddtree_params dp; + dp.budget = params_base.speculative.ddtree_budget; + dp.temp = params_base.speculative.ddtree_temp; + dp.chain_seed = params_base.speculative.ddtree_chain_seed; + dp.block_size = 16; + slot.spec_driver = llama_speculative_tree_driver_init(ctx, ctx_ddtree_dft, dp); + if (!slot.spec_driver) { + SLT_ERR(slot, "%s", "failed to allocate DDTree driver during prompt processing\n"); + continue; + } + } + // Append n_tokens columns from this decode's capture buffer to the ring. + llama_speculative_tree_driver_ingest_prompt_capture( + slot.spec_driver, (int32_t)n_tokens); + } + } // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2853,6 +3059,46 @@ struct server_context_impl { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; + if (params_base.speculative.ddtree_mode && ctx_ddtree_dft) { + // DDTree: the driver was lazy-allocated and the ring was filled + // incrementally during prompt prefill (one ingest per inner-loop + // ubatch decode). If something went wrong upstream we may not + // have a driver here — fall back to EOS. + if (!slot.spec_driver) { + SLT_ERR(slot, "%s", "DDTree driver missing at GENERATING transition\n"); + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + continue; + } + + // Greedy-sample the first generated token from the last prompt logit. + const int tok_idx = slot.i_batch - i; + const float * logits = llama_get_logits_ith(ctx, tok_idx); + const int n_vocab = llama_vocab_n_tokens(vocab); + llama_token first_tok = 0; + float best = logits[0]; + for (int v = 1; v < n_vocab; ++v) { + if (logits[v] > best) { best = logits[v]; first_tok = (llama_token)v; } + } + + slot.ddtree_root_tok = first_tok; + slot.ddtree_committed_pos = (llama_pos)slot.prompt.tokens.size(); + slot.i_batch = -1; + // slot.reset() doesn't touch has_next_token; if the previous + // request ended on a stop condition the flag is still false, + // and the DDTree gen block would skip this slot forever. + slot.has_next_token = true; + + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + continue; // will be handled in the DDTree generation loop below + } + if (slot.can_speculate()) { common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); } @@ -2860,6 +3106,11 @@ struct server_context_impl { continue; // continue loop of slots } + // DDTree slots run their driver step outside this loop (after llama_decode) + if (params_base.speculative.ddtree_mode && slot.spec_driver) { + continue; + } + if (slot.i_batch_dft.size() > 0) { continue; // sample using speculative decoding } @@ -2962,6 +3213,120 @@ struct server_context_impl { } } + // DDTree generation: run one spec-decode step per slot, after the main llama_decode. + // The driver calls llama_decode on ctx internally (tree-mode batch), so it must run + // outside the main decode loop. + if (params_base.speculative.ddtree_mode) { + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING || !slot.spec_driver) { + continue; + } + if (slot.ddtree_root_tok == LLAMA_TOKEN_NULL || !slot.has_next_token) { + continue; + } + + // Grammar-aware verify: clone the slot sampler and let the + // driver pick each chain step via the cloned sampler+grammar. + // The clone gets root_token accepted up front so the first + // sample at row 0 sees the grammar state "after root". + // LLAMA_DDTREE_NO_GRAMMAR_VERIFY=1 disables this and falls back + // to internal argmax (diagnostic; baseline for accept-rate + // comparison). + static const bool s_no_grammar_verify = []{ + const char * e = getenv("LLAMA_DDTREE_NO_GRAMMAR_VERIFY"); + return e && e[0] == '1'; + }(); + struct ddtree_verify_state { + common_sampler * smpl; + llama_context * ctx; + }; + ddtree_verify_state vstate{ + /*smpl=*/ (!s_no_grammar_verify && slot.smpl) ? common_sampler_clone(slot.smpl.get()) : nullptr, + /*ctx =*/ ctx, + }; + if (vstate.smpl) { + common_sampler_accept(vstate.smpl, slot.ddtree_root_tok, true); + } + llama_speculative_tree_verify_cbs vcbs{}; + vcbs.user_data = &vstate; + vcbs.sample_cb = [](void * ud, int32_t logits_row_idx) -> int32_t { + auto * s = (ddtree_verify_state *)ud; + if (!s->smpl) { + return 0; // shouldn't happen; driver falls back if cb null + } + return (int32_t)common_sampler_sample(s->smpl, s->ctx, logits_row_idx, /*grammar_first=*/true); + }; + vcbs.advance_cb = [](void * ud, llama_token tok) { + auto * s = (ddtree_verify_state *)ud; + if (s->smpl) { + common_sampler_accept(s->smpl, tok, true); + } + }; + + auto accepted = llama_speculative_tree_driver_step( + slot.spec_driver, slot.ddtree_root_tok, slot.ddtree_committed_pos, + vstate.smpl ? &vcbs : nullptr); + + if (vstate.smpl) { + common_sampler_free(vstate.smpl); + } + + if (accepted.empty()) { + SLT_ERR(slot, "%s", "DDTree driver step returned empty result, treating as EOS\n"); + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + continue; + } + + // accepted: [root_echo, draft_accepted..., bonus] + // commit everything except the bonus token. The driver used a + // grammar-aware verify (via vcbs above) so all accepted tokens + // are guaranteed to be in the sampler+grammar's allowed set. + const int n_committed = (int)accepted.size() - 1; + + const int64_t t_current = ggml_time_us(); + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + + bool slot_done = false; + for (int ai = 0; ai < n_committed; ++ai) { + const llama_token tok = accepted[ai]; + + slot.n_decoded += 1; + + completion_token_output result; + result.tok = tok; + result.text_to_send = common_token_to_piece(ctx, tok, accept_special_token(slot, tok)); + result.prob = 1.0f; + + // update sampler history so repetition penalties remain correct + common_sampler_accept(slot.smpl.get(), tok, true); + + // track position in prompt token list + slot.prompt.tokens.push_back(tok); + + if (!process_token(result, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + slot_done = true; + break; + } + } + + if (!slot_done) { + slot.ddtree_root_tok = accepted.back(); // bonus = next root + slot.ddtree_committed_pos += (llama_pos)n_committed; + slot.n_draft_total += params_base.speculative.ddtree_budget; + slot.n_draft_accepted += n_committed - 1; // root was not a draft, rest were + } + } + } + SRV_DBG("%s", "run slots completed\n"); }