-
Notifications
You must be signed in to change notification settings - Fork 221
feat(pflash): ee7 early-exit drafter + anchor-transitive cascade + regime router #338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # anchor transitive scan | ||
|
|
||
| `scan_and_force_transitive` (anchor_scan.cpp) expands the query pool with | ||
| tokens from newly-forced chunks and re-runs `scan_and_force` until fixed | ||
| point or max_iters (default 3) is reached. | ||
|
|
||
| Improves multi-hop retrieval: enables discovery of intermediate context | ||
| chunks whose tokens do not appear in the original query but connect | ||
| query-to-needle via shared rare tokens. | ||
|
|
||
| Empirical result: F1=0.628 on LongBench HotpotQA at ee7 + keep=0.15 | ||
| (vs uncompressed F1=0.697). This is the ceiling for attention-score-based | ||
| prefill compression on this task; see bench/2026-05-25_longbench_hotpotqa/. | ||
|
|
||
| On by default. Disable via PFLASH_COMPRESS_ANCHOR_TRANSITIVE=0. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| # pflash compression knobs | ||
|
|
||
| All PFLASH_COMPRESS_* and DFLASH_COMPRESS_* env vars are read once per | ||
| request in `compress_cfg_from_env(n_chunks, n_keep)` in qwen3_drafter.cpp. | ||
|
|
||
| ## anchor_radius adaptive ladder | ||
|
|
||
| Prevents the 64K NIAH cliff: at long context the needle text is more likely | ||
| to straddle multiple chunks, and a fixed radius=2 window (5 chunks / ~160 | ||
| tokens) loses the back half of the needle. | ||
|
|
||
| Default ladder (override via PFLASH_COMPRESS_ANCHOR_RADIUS): | ||
|
|
||
| | n_chunks | anchor_radius | | ||
| |------------|---------------| | ||
| | < 1024 | 2 | | ||
| | 1024-2047 | 4 | | ||
| | >= 2048 | 8 | | ||
|
|
||
| ## max_anchor_hits adaptive ladder | ||
|
|
||
| Same breakpoints as anchor_radius. At long context anchors are sparser, so | ||
| more hits per query token are affordable. | ||
|
|
||
| | n_chunks | max_anchor_hits | | ||
| |------------|-----------------| | ||
| | < 1024 | 8 | | ||
| | 1024-2047 | 16 | | ||
| | >= 2048 | 32 | | ||
|
|
||
| ## anchor_transitive | ||
|
|
||
| On by default. Gated rare-token bridge expands the query pool with tokens | ||
| from newly-forced chunks and re-runs anchor scan to fixed point. | ||
| Improves multi-hop F1 on LongBench HotpotQA (empirically; F1=0.628 ceiling | ||
| at ee7+anchor-transitive on RTX 3090 — see bench/2026-05-25_longbench_hotpotqa/). | ||
| Control via PFLASH_COMPRESS_ANCHOR_TRANSITIVE=0 to disable. | ||
|
|
||
| ## head/tail chunk forcing | ||
|
|
||
| Head and tail chunks are force-included before top-K scoring fills the | ||
| remainder. The counts scale with n_keep so top-K always gets at least one | ||
| slot even when head_raw + tail_raw >= n_keep. | ||
|
|
||
| Defaults: head=8, tail=24 (override via DFLASH_COMPRESS_HEAD_CHUNKS / | ||
| DFLASH_COMPRESS_TAIL_CHUNKS). |
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,169 @@ | ||||||||
| #include "anchor_scan.h" | ||||||||
|
|
||||||||
| #include <algorithm> | ||||||||
| #include <cstdint> | ||||||||
| #include <unordered_map> | ||||||||
| #include <vector> | ||||||||
|
|
||||||||
| namespace dflash::qwen3 { | ||||||||
|
|
||||||||
| // Force chunk and its radius-neighborhood into `forced`. | ||||||||
| static void force_neighborhood(std::vector<uint8_t>& forced, int n_chunks, | ||||||||
| int chunk, int radius) { | ||||||||
| int lo = std::max(0, chunk - radius); | ||||||||
| int hi = std::min(n_chunks - 1, chunk + radius); | ||||||||
| for (int c = lo; c <= hi; ++c) forced[(size_t)c] = 1; | ||||||||
| } | ||||||||
|
|
||||||||
| void scan_and_force( | ||||||||
| const std::vector<int32_t>& ids, | ||||||||
| int body_end, | ||||||||
| const std::vector<int32_t>& query_pool, | ||||||||
| const AnchorScanCfg& cfg, | ||||||||
| std::vector<uint8_t>& forced) | ||||||||
| { | ||||||||
| const int n_chunks = (int)forced.size(); | ||||||||
| const int ngram = cfg.ngram; | ||||||||
| const int search_end = std::max(0, body_end - ngram); | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P2: Anchor scan incorrectly searches one position when Prompt for AI agents
Suggested change
|
||||||||
|
|
||||||||
| for (int qi = 0; qi + ngram <= (int)query_pool.size(); ++qi) { | ||||||||
| int hits = 0; | ||||||||
| int hit_pos[8]; | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P1: Prompt for AI agents |
||||||||
| for (int p = 0; p <= search_end && hits <= cfg.max_anchor_hits; ++p) { | ||||||||
| bool same = true; | ||||||||
| for (int k = 0; k < ngram; ++k) { | ||||||||
| if (ids[(size_t)p + k] != query_pool[(size_t)qi + k]) { | ||||||||
| same = false; | ||||||||
| break; | ||||||||
| } | ||||||||
| } | ||||||||
| if (same) { | ||||||||
| if (hits < 8) hit_pos[hits] = p; | ||||||||
| ++hits; | ||||||||
| } | ||||||||
| } | ||||||||
| if (hits > 0 && hits <= cfg.max_anchor_hits) { | ||||||||
| for (int i = 0; i < hits && i < 8; ++i) { | ||||||||
| force_neighborhood(forced, n_chunks, | ||||||||
| hit_pos[i] / cfg.chunk_size, | ||||||||
| cfg.anchor_radius); | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // Helper: count set entries in forced. | ||||||||
| static int count_set(const std::vector<uint8_t>& forced) { | ||||||||
| int n = 0; | ||||||||
| for (uint8_t v : forced) n += (v != 0); | ||||||||
| return n; | ||||||||
| } | ||||||||
|
|
||||||||
| void scan_and_force_transitive( | ||||||||
| const std::vector<int32_t>& ids, | ||||||||
| int body_end, | ||||||||
| const std::vector<int32_t>& initial_query_pool, | ||||||||
| const AnchorScanCfg& cfg, | ||||||||
| int max_iters, | ||||||||
| std::vector<uint8_t>& forced) | ||||||||
| { | ||||||||
| auto pool = initial_query_pool; | ||||||||
| const int n_chunks = (int)forced.size(); | ||||||||
|
|
||||||||
| // Precompute token frequencies in body once. | ||||||||
| std::unordered_map<int32_t, int> body_freq; | ||||||||
| body_freq.reserve((size_t)body_end); | ||||||||
| for (int j = 0; j < body_end; ++j) ++body_freq[ids[(size_t)j]]; | ||||||||
|
|
||||||||
| // Build inverted index: token -> list of body positions (for rare tokens only). | ||||||||
| std::unordered_map<int32_t, std::vector<int>> rare_positions; | ||||||||
| if (cfg.rare_token_max_freq > 0) { | ||||||||
| for (auto& kv : body_freq) { | ||||||||
| if (kv.second <= cfg.rare_token_max_freq) { | ||||||||
| rare_positions[kv.first] = {}; | ||||||||
| } | ||||||||
| } | ||||||||
| for (int p = 0; p < body_end; ++p) { | ||||||||
| auto it = rare_positions.find(ids[(size_t)p]); | ||||||||
| if (it != rare_positions.end()) it->second.push_back(p); | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // Pass-1: run the initial scan. | ||||||||
| const int count_before_pass1 = count_set(forced); | ||||||||
| scan_and_force(ids, body_end, pool, cfg, forced); | ||||||||
| const int gained_pass1 = count_set(forced) - count_before_pass1; | ||||||||
|
|
||||||||
| // Gating: if pass-1 already found many anchors, skip the cascade entirely. | ||||||||
| if (cfg.cascade_min_anchor_count > 0 && gained_pass1 >= cfg.cascade_min_anchor_count) { | ||||||||
| return; | ||||||||
| } | ||||||||
|
|
||||||||
| // Cascade loop: expand pool with newly-forced tokens and re-scan. | ||||||||
| std::vector<uint8_t> prev_forced; | ||||||||
| for (int it = 0; it < max_iters; ++it) { | ||||||||
| prev_forced = forced; | ||||||||
|
|
||||||||
| // Rare-token single-match: worklist-driven so cascades within a pass are | ||||||||
| // caught (e.g. hop3 forces hop2 which forces hop1 in one outer iteration). | ||||||||
| if (cfg.rare_token_max_freq > 0) { | ||||||||
| std::vector<int> worklist; | ||||||||
| for (int c = 0; c < n_chunks; ++c) { | ||||||||
| if (forced[c] && !prev_forced[c]) worklist.push_back(c); | ||||||||
| } | ||||||||
| // On first iteration, seed from everything forced so far (pass-1 results). | ||||||||
| if (it == 0) { | ||||||||
| worklist.clear(); | ||||||||
| for (int c = 0; c < n_chunks; ++c) { | ||||||||
| if (forced[c]) worklist.push_back(c); | ||||||||
| } | ||||||||
| } | ||||||||
| for (int wi = 0; wi < (int)worklist.size(); ++wi) { | ||||||||
| int c = worklist[wi]; | ||||||||
| int s = c * cfg.chunk_size; | ||||||||
| int e = std::min(body_end, (c + 1) * cfg.chunk_size); | ||||||||
| for (int j = s; j < e; ++j) { | ||||||||
| auto it2 = rare_positions.find(ids[(size_t)j]); | ||||||||
| if (it2 == rare_positions.end()) continue; | ||||||||
| for (int p : it2->second) { | ||||||||
| int target_c = p / cfg.chunk_size; | ||||||||
| if (!forced[(size_t)target_c]) { | ||||||||
| force_neighborhood(forced, n_chunks, | ||||||||
| target_c, cfg.anchor_radius); | ||||||||
| worklist.push_back(target_c); | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // Hard cap: if we exceeded max_forced_count, revert this iteration and stop. | ||||||||
| if (count_set(forced) > cfg.max_forced_count) { | ||||||||
| forced = prev_forced; | ||||||||
| break; | ||||||||
| } | ||||||||
|
|
||||||||
| if (forced == prev_forced) break; | ||||||||
|
|
||||||||
| // Expand pool with tokens from newly-forced chunks (feeds next 4-gram pass). | ||||||||
| for (int c = 0; c < n_chunks; ++c) { | ||||||||
| if (forced[c] && !prev_forced[c]) { | ||||||||
| int s = c * cfg.chunk_size; | ||||||||
| int e = std::min((int)ids.size(), (c + 1) * cfg.chunk_size); | ||||||||
| for (int j = s; j < e; ++j) pool.push_back(ids[j]); | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // 4-gram scan with expanded pool for next iteration. | ||||||||
| prev_forced = forced; | ||||||||
| scan_and_force(ids, body_end, pool, cfg, forced); | ||||||||
|
|
||||||||
| // Hard cap check after 4-gram expansion too. | ||||||||
| if (count_set(forced) > cfg.max_forced_count) { | ||||||||
| forced = prev_forced; | ||||||||
| break; | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| } // namespace dflash::qwen3 | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| // N-gram anchor scan: mark chunks forced by token-match between a query pool | ||
| // and the body of an ids sequence. Pure CPU, no GPU, no model required. | ||
| #pragma once | ||
|
|
||
| #include <climits> | ||
| #include <cstdint> | ||
| #include <vector> | ||
|
|
||
| namespace dflash::qwen3 { | ||
|
|
||
| struct AnchorScanCfg { | ||
| int chunk_size; | ||
| int anchor_radius; | ||
| int max_anchor_hits; | ||
| int ngram = 4; | ||
| int rare_token_max_freq = 8; // tokens appearing <= this many times in body count as rare | ||
| int cascade_min_anchor_count = 0; // skip cascade if pass-1 forced >= this many chunks (0 = always cascade) | ||
| int max_forced_count = INT_MAX; // hard cap on total forced chunks | ||
| }; | ||
|
|
||
| // Marks chunks forced by ngram-matches between query_pool and ids[0..body_end). | ||
| // `forced` is in-out; new hits are OR-merged. Idempotent. | ||
| void scan_and_force( | ||
| const std::vector<int32_t>& ids, | ||
| int body_end, | ||
| const std::vector<int32_t>& query_pool, | ||
| const AnchorScanCfg& cfg, | ||
| std::vector<uint8_t>& forced | ||
| ); | ||
|
|
||
| // Transitive variant: expands the query pool with tokens from newly-forced | ||
| // chunks and re-runs scan_and_force until a fixed point or max_iters reached. | ||
| void scan_and_force_transitive( | ||
| const std::vector<int32_t>& ids, | ||
| int body_end, | ||
| const std::vector<int32_t>& initial_query_pool, | ||
| const AnchorScanCfg& cfg, | ||
| int max_iters, | ||
| std::vector<uint8_t>& forced | ||
| ); | ||
|
|
||
| } // namespace dflash::qwen3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P3: Documentation references
bench/2026-05-25_longbench_hotpotqa/but thebench/directory does not exist anywhere in this codebase, so the reference is unresolvable.Prompt for AI agents