Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions dflash/scripts/prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ def __init__(self, *, daemon_stdin, await_reply, daemon_lock,

self.entries: OrderedDict[bytes, int] = OrderedDict() # hash → slot_id
self.next_slot = 0
# Cumulative hit counter, never decremented. Survives eviction —
# unlike a sum over surviving entries, which trends down under churn.
self._lifetime_hits = 0
try:
self.markers = _resolve_chat_markers(tokenizer)
except ValueError as e:
Expand Down Expand Up @@ -537,10 +540,22 @@ def lookup(self, prompt_ids: list[int]) -> tuple[int, int] | None:
best = (self.entries[key], cut)
self.entries.move_to_end(key) # mark fresh
if best is not None:
self._lifetime_hits += 1
print(f"{self.log_prefix} lookup hit slot={best[0]} prefix_len={best[1]} "
f"(of {len(prompt_ids)} total)", flush=True)
return best

def stats(self) -> dict:
    """Return a point-in-time usage report for the /props endpoint.

    No lock is taken here, so a mutation performed under daemon_lock may
    leave ``in_use`` and ``lifetime_hits`` momentarily out of step. That is
    tolerated because the result is only used for introspection.

    Returns:
        A dict with ``capacity``, ``in_use`` and ``lifetime_hits``; all
        three are zero when the cache is disabled.
    """
    # Start from the "disabled" report and overwrite it only when active.
    capacity, in_use, lifetime_hits = 0, 0, 0
    if not self.disabled:
        capacity = self.cap
        in_use = len(self.entries)
        lifetime_hits = self._lifetime_hits
    return {
        "capacity": capacity,
        "in_use": in_use,
        "lifetime_hits": lifetime_hits,
    }

def mark_all_cleared(self) -> None:
"""Drop every LRU entry after the daemon emits ``[snap] all-cleared``.

Expand Down Expand Up @@ -715,6 +730,11 @@ def init_full_cache(self, full_cap: int,
# Pending eviction: the LRU entry reserved for the next confirm.
self._full_pending_evict_key: bytes | None = None
self._full_pending_evict_path: str | None = None
# Lifetime hit counter and cached disk-usage total for /props. The
# disk total is refreshed at every full-cache mutation site so the
# introspection endpoint never has to walk the filesystem.
self._full_lifetime_hits = 0
self._full_disk_bytes_snapshot = 0

cache_dir_path = Path(cache_dir) if cache_dir else Path("/tmp/prefix")
cache_dir_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -776,6 +796,36 @@ def _full_entry_artifact_size(self, key: bytes, entry: FullCacheEntry) -> int:
continue
return total

def _recompute_full_disk_bytes_snapshot(self) -> None:
    """Rebuild the cached disk-usage total that /props reports.

    Intended to be invoked after each full-cache mutation so that reading
    the figure never requires touching the filesystem. When the full cache
    is disabled (or ``_full_disabled`` was never set) the snapshot is
    reset to zero.
    """
    total_bytes = 0
    if not getattr(self, "_full_disabled", True):
        for entry_key, entry in self.full_entries.items():
            total_bytes += self._full_entry_artifact_size(entry_key, entry)
    # Assigned only after the whole walk succeeds, so a failure part-way
    # through leaves the previous snapshot in place.
    self._full_disk_bytes_snapshot = total_bytes

def full_stats(self) -> dict:
    """Return the full-cache usage report for the /props endpoint.

    The disk figure is the cached snapshot; this method performs no
    filesystem access of its own.

    Returns:
        A dict with ``enabled``, ``capacity``, ``in_use``, ``disk_bytes``
        and ``lifetime_hits``. When the full cache is disabled (or
        ``_full_disabled`` was never set) ``enabled`` is False and every
        numeric field is zero.
    """
    report = {
        "enabled": False,
        "capacity": 0,
        "in_use": 0,
        "disk_bytes": 0,
        "lifetime_hits": 0,
    }
    if not getattr(self, "_full_disabled", True):
        # Overwrite in place so the key order stays the same as above.
        report.update(
            enabled=True,
            capacity=self._full_cap,
            in_use=len(self.full_entries),
            disk_bytes=self._full_disk_bytes_snapshot,
            lifetime_hits=self._full_lifetime_hits,
        )
    return report

@staticmethod
def _read_full_meta_int(meta: dict, key: str, *, default: int | None = None) -> int | None:
value = meta.get(key, default)
Expand Down Expand Up @@ -858,6 +908,7 @@ def _retire_full_entry(self, key: bytes, entry: FullCacheEntry,
pass
self._drop_full_metadata(key)
self._recompute_full_next_slot()
self._recompute_full_disk_bytes_snapshot()

def _enforce_full_budget(self, live_prompt_ids: list[int] | None = None) -> None:
budget = int(getattr(self, "_full_budget_bytes", 0) or 0)
Expand Down Expand Up @@ -992,6 +1043,7 @@ async def rehydrate_full_cache(self, replay_entry) -> int:
break

self._recompute_full_next_slot()
self._recompute_full_disk_bytes_snapshot()
if restored:
print(f"{self.log_prefix} full-cache restored {restored} entries "
f"from disk", flush=True)
Expand All @@ -1017,8 +1069,10 @@ def lookup_full(self, prompt_ids: list[int]) -> tuple[int, str, int] | None:
if not Path(cur_bin_path).exists():
self.full_entries.pop(key, None)
self._drop_full_metadata(key)
self._recompute_full_disk_bytes_snapshot()
return None
entry.hits += 1
self._full_lifetime_hits += 1
entry.last_used_ns = time.time_ns()
self.full_entries.move_to_end(key) # mark fresh in LRU
self._persist_full_metadata(key, entry)
Expand Down Expand Up @@ -1108,6 +1162,10 @@ def confirm_full_snap(self, slot: int, prompt_ids: list[int],
self.full_entries[key] = entry
self._persist_full_metadata(key, entry)
self._enforce_full_budget(prompt_ids)
# _enforce_full_budget may call _retire_full_entry which refreshes
# the snapshot, but it bails early when budget==0. Refresh here
# unconditionally so the new entry's bytes are always reflected.
self._recompute_full_disk_bytes_snapshot()
print(f"{self.log_prefix} full-cache committed slot={slot} "
f"cur_ids_len={cur_ids_len} key={key.hex()[:8]}", flush=True)

Expand Down
Loading
Loading