From f064a58b5118c0f47386aa1a66fa142d7a717947 Mon Sep 17 00:00:00 2001 From: hlin99 Date: Mon, 6 Apr 2026 17:09:58 +0800 Subject: [PATCH] feat(M88): Framework-Level Inference Hooks - InferenceHook protocol: on_prefill, on_kv_transfer, on_decode_step - HookCapture, StageComparison, TraceResult dataclasses - MockInferenceHook for testing with configurable noise - compare_captures() for field-by-field comparison (max/mean diff, cosine sim) - run_trace() orchestrates hook-based comparison across stages - format_trace() for rich terminal output with per-stage table - xpyd-acc trace CLI subcommand with --mock, --json, --hooks, --threshold - 34 tests covering protocol, captures, comparison, trace, formatting, CLI Closes #188 --- docs/iterations/current.md | 3 +- src/xpyd_acc/cli/__init__.py | 2 + src/xpyd_acc/cli/analysis.py | 55 +++++ src/xpyd_acc/cli/parsers.py | 37 +++ src/xpyd_acc/inference_hooks.py | 393 ++++++++++++++++++++++++++++++++ tests/test_inference_hooks.py | 345 ++++++++++++++++++++++++++++ 6 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 src/xpyd_acc/inference_hooks.py create mode 100644 tests/test_inference_hooks.py diff --git a/docs/iterations/current.md b/docs/iterations/current.md index a8a6f32..b7530ff 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -51,4 +51,5 @@ shell for exploratory comparison of two endpoints. | M83 | 2026-04-06 | Divergence Heatmap by Token Position | ✅ merged | Both approved | | M84 | 2026-04-06 | Endpoint Response Time Regression Detection | ✅ merged | Both approved | | M85 | 2026-04-06 | Offline Mode — File-Based Comparison | ✅ merged | Both approved | -| M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ⏳ pending review | — | +| M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ✅ merged | Both approved | +| M88 | 2026-04-06 | Framework-Level Inference Hooks | ⏳ pending review | — | diff --git a/src/xpyd_acc/cli/__init__.py b/src/xpyd_acc/cli/__init__.py index 925c342..958b9e1 100644 --- a/src/xpyd_acc/cli/__init__.py +++ b/src/xpyd_acc/cli/__init__.py @@ -23,6 +23,7 @@ handle_heatmap, handle_root_cause, handle_token_diff, + handle_trace, ) from .batch import _run_batch_compare from .benchmark import ( @@ -129,6 +130,7 @@ def main(argv: list[str] | None = None) -> None: "token-diff": lambda: handle_token_diff(args), "heatmap": lambda: handle_heatmap(args), "capture-kv": lambda: handle_capture_kv(args), + "trace": lambda: handle_trace(args), "filter": lambda: _run_filter(args), "serve": lambda: _run_serve(args), "grafana-dashboard": lambda: _run_grafana_dashboard(args), diff --git a/src/xpyd_acc/cli/analysis.py b/src/xpyd_acc/cli/analysis.py index d94f3a1..53461e6 100644 --- a/src/xpyd_acc/cli/analysis.py +++ b/src/xpyd_acc/cli/analysis.py @@ -416,3 +416,58 @@ def handle_capture_kv(args: argparse.Namespace) -> None: with open(args.json, "w") as f: _json.dump(result.to_dict(), f, indent=2) print(f" Metadata exported to {args.json}") + + +def handle_trace(args: argparse.Namespace) -> None: + """Handle the trace CLI subcommand.""" + import json as _json + + from xpyd_acc.inference_hooks import ( + HookPoint, + MockInferenceHook, + format_trace, + run_trace, + ) + + hooks = [HookPoint(h.strip()) for h in args.hooks.split(",")] + + if args.mock: + baseline_hook = MockInferenceHook( + num_layers=args.num_layers, + noise_scale=0.0, + seed=42, + ) + target_hook = MockInferenceHook( + num_layers=args.num_layers, + noise_scale=args.noise_scale, + seed=42, + ) + else: + print( + "Live inference tracing requires framework-specific hooks.\n" + "Use --mock for testing, or see docs for vLLM/SGLang integration.", + file=sys.stderr, + ) + raise SystemExit(1) + + result = run_trace( + baseline_hook=baseline_hook, + target_hook=target_hook, + prompt=args.prompt, + baseline_url=args.baseline, + target_url=args.target, + hooks=hooks, + num_layers=args.num_layers, + decode_steps=args.decode_steps, + threshold=args.threshold, + ) + + print(format_trace(result)) + + if getattr(args, "json", None): + with open(args.json, "w") as f: + _json.dump(result.to_dict(), f, indent=2) + print(f"Trace exported to {args.json}") + + if result.overall_diverged: + raise SystemExit(1) diff --git a/src/xpyd_acc/cli/parsers.py b/src/xpyd_acc/cli/parsers.py index ec8bf3e..13a95d0 100644 --- a/src/xpyd_acc/cli/parsers.py +++ b/src/xpyd_acc/cli/parsers.py @@ -54,6 +54,7 @@ def register_all(sub: argparse._SubParsersAction) -> None: _register_heatmap(sub) _register_capture_kv(sub) _register_file_compare(sub) + _register_trace(sub) def _register_compare(sub): lp = sub.add_parser("compare-logprobs", help="Compare logprobs between two endpoints") lp.add_argument("--baseline", required=True, help="Baseline endpoint URL") @@ -695,3 +696,39 @@ def _register_file_compare(sub): "--numeric-tolerance", type=float, default=None, help="Numeric tolerance for matching", ) + + +def _register_trace(sub): + tr = sub.add_parser( + "trace", + help="Trace intermediate inference states between baseline and target", + ) + tr.add_argument("--baseline", required=True, help="Baseline endpoint URL") + tr.add_argument("--target", required=True, help="Target endpoint URL") + tr.add_argument("--prompt", required=True, help="Prompt text") + tr.add_argument( + "--hooks", + default="prefill,kv_transfer,decode_step", + help="Comma-separated hooks: prefill,kv_transfer,decode_step", + ) + tr.add_argument( + "--num-layers", type=int, default=4, + help="Number of layers to trace (default: 4)", + ) + tr.add_argument( + "--decode-steps", type=int, default=1, + help="Number of decode steps to trace (default: 1)", + ) + tr.add_argument( + "--threshold", type=float, default=1e-5, + help="Divergence threshold (default: 1e-5)", + ) + tr.add_argument( + "--mock", action="store_true", default=False, + help="Use mock hooks for testing", + ) + tr.add_argument( + "--noise-scale", type=float, default=0.0, + help="Noise scale for mock target hook (default: 0.0)", + ) + tr.add_argument("--json", default=None, help="Export trace result as JSON") diff --git a/src/xpyd_acc/inference_hooks.py b/src/xpyd_acc/inference_hooks.py new file mode 100644 index 0000000..8ba4cf2 --- /dev/null +++ b/src/xpyd_acc/inference_hooks.py @@ -0,0 +1,393 @@ +"""Framework-level inference hooks for capturing intermediate states. + +Provides a hook protocol for injecting into inference engines (vLLM, SGLang) +to capture hidden states, attention weights, KV cache, and logits at each +stage of the inference pipeline. Enables root-cause analysis of PD divergence +by comparing intermediate representations between aggregated and PD modes. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Protocol, runtime_checkable + +import numpy as np + + +class HookPoint(str, Enum): + """Points in the inference pipeline where hooks can fire.""" + + PREFILL = "prefill" + KV_TRANSFER = "kv_transfer" + DECODE_STEP = "decode_step" + + +@dataclass +class HookCapture: + """Data captured at a single hook point.""" + + hook_point: HookPoint + layer: int + step: int | None = None # decode step index (None for prefill/transfer) + hidden_states: np.ndarray | None = None + attention_weights: np.ndarray | None = None + logits: np.ndarray | None = None + kv_cache: np.ndarray | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict (arrays become shape + dtype info, not values).""" + d: dict[str, Any] = { + "hook_point": self.hook_point.value, + "layer": self.layer, + "step": self.step, + "metadata": self.metadata, + } + for name in ("hidden_states", "attention_weights", "logits", "kv_cache"): + arr = getattr(self, name) + if arr is not None: + d[name] = {"shape": list(arr.shape), "dtype": str(arr.dtype)} + else: + d[name] = None + return d + + +@dataclass +class StageComparison: + """Comparison result for a single stage between baseline and target.""" + + hook_point: HookPoint + layer: int + step: int | None = None + max_abs_diff: float = 0.0 + mean_abs_diff: float = 0.0 + cosine_similarity: float = 1.0 + field_name: str = "hidden_states" # which field was compared + diverged: bool = False + threshold: float = 1e-5 + + def to_dict(self) -> dict[str, Any]: + return { + "hook_point": self.hook_point.value, + "layer": self.layer, + "step": self.step, + "max_abs_diff": self.max_abs_diff, + "mean_abs_diff": self.mean_abs_diff, + "cosine_similarity": self.cosine_similarity, + "field_name": self.field_name, + "diverged": self.diverged, + "threshold": self.threshold, + } + + +@dataclass +class TraceResult: + """Full trace result with per-stage comparisons.""" + + prompt: str + baseline_url: str + target_url: str + hooks: list[HookPoint] + comparisons: list[StageComparison] = field(default_factory=list) + baseline_captures: list[HookCapture] = field(default_factory=list) + target_captures: list[HookCapture] = field(default_factory=list) + first_divergence: StageComparison | None = None + overall_diverged: bool = False + errors: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = { + "prompt": self.prompt, + "baseline_url": self.baseline_url, + "target_url": self.target_url, + "hooks": [h.value for h in self.hooks], + "comparisons": [c.to_dict() for c in self.comparisons], + "baseline_captures": [c.to_dict() for c in self.baseline_captures], + "target_captures": [c.to_dict() for c in self.target_captures], + "first_divergence": ( + self.first_divergence.to_dict() + if self.first_divergence + else None + ), + "overall_diverged": self.overall_diverged, + "errors": self.errors, + } + return d + + +@runtime_checkable +class InferenceHook(Protocol): + """Protocol for inference hooks that capture intermediate states.""" + + def on_prefill(self, layer: int) -> HookCapture | None: + """Called after prefill for each layer. Return captured data or None.""" + ... + + def on_kv_transfer(self, layer: int) -> HookCapture | None: + """Called after KV transfer for each layer. Return captured data or None.""" + ... + + def on_decode_step(self, layer: int, step: int) -> HookCapture | None: + """Called at each decode step for each layer. Return captured data or None.""" + ... + + +class MockInferenceHook: + """Mock hook for testing — generates synthetic intermediate states.""" + + def __init__( + self, + num_layers: int = 4, + hidden_dim: int = 64, + seq_len: int = 8, + noise_scale: float = 0.0, + seed: int = 42, + ) -> None: + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.seq_len = seq_len + self.noise_scale = noise_scale + self._rng = np.random.default_rng(seed) + # Pre-generate base states for consistency + self._base_hidden = { + layer: self._rng.standard_normal((seq_len, hidden_dim)).astype(np.float32) + for layer in range(num_layers) + } + self._base_kv = { + layer: self._rng.standard_normal((2, seq_len, hidden_dim)).astype( + np.float32 + ) + for layer in range(num_layers) + } + + def _add_noise(self, arr: np.ndarray) -> np.ndarray: + if self.noise_scale == 0.0: + return arr.copy() + noise = self._rng.standard_normal(arr.shape).astype(arr.dtype) * self.noise_scale + return arr + noise + + def on_prefill(self, layer: int) -> HookCapture | None: + if layer >= self.num_layers: + return None + hs = self._add_noise(self._base_hidden[layer]) + kv = self._add_noise(self._base_kv[layer]) + return HookCapture( + hook_point=HookPoint.PREFILL, + layer=layer, + hidden_states=hs, + kv_cache=kv, + metadata={"num_layers": self.num_layers}, + ) + + def on_kv_transfer(self, layer: int) -> HookCapture | None: + if layer >= self.num_layers: + return None + kv = self._add_noise(self._base_kv[layer]) + return HookCapture( + hook_point=HookPoint.KV_TRANSFER, + layer=layer, + kv_cache=kv, + metadata={"num_layers": self.num_layers}, + ) + + def on_decode_step(self, layer: int, step: int) -> HookCapture | None: + if layer >= self.num_layers: + return None + hs = self._add_noise(self._base_hidden[layer]) + logits = self._add_noise( + self._rng.standard_normal((1, self.hidden_dim)).astype(np.float32) + ) + return HookCapture( + hook_point=HookPoint.DECODE_STEP, + layer=layer, + step=step, + hidden_states=hs, + logits=logits, + metadata={"num_layers": self.num_layers}, + ) + + +def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float: + """Compute cosine similarity between two arrays (flattened).""" + a_flat = a.flatten().astype(np.float64) + b_flat = b.flatten().astype(np.float64) + dot = float(np.dot(a_flat, b_flat)) + norm_a = float(np.linalg.norm(a_flat)) + norm_b = float(np.linalg.norm(b_flat)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def compare_captures( + baseline: HookCapture, + target: HookCapture, + threshold: float = 1e-5, +) -> list[StageComparison]: + """Compare two hook captures field by field.""" + results: list[StageComparison] = [] + for field_name in ("hidden_states", "attention_weights", "logits", "kv_cache"): + b_arr = getattr(baseline, field_name) + t_arr = getattr(target, field_name) + if b_arr is None or t_arr is None: + continue + if b_arr.shape != t_arr.shape: + results.append( + StageComparison( + hook_point=baseline.hook_point, + layer=baseline.layer, + step=baseline.step, + field_name=field_name, + max_abs_diff=float("inf"), + mean_abs_diff=float("inf"), + cosine_similarity=0.0, + diverged=True, + threshold=threshold, + ) + ) + continue + diff = np.abs(b_arr.astype(np.float64) - t_arr.astype(np.float64)) + max_diff = float(np.max(diff)) + mean_diff = float(np.mean(diff)) + cos_sim = _cosine_sim(b_arr, t_arr) + diverged = max_diff > threshold + results.append( + StageComparison( + hook_point=baseline.hook_point, + layer=baseline.layer, + step=baseline.step, + field_name=field_name, + max_abs_diff=max_diff, + mean_abs_diff=mean_diff, + cosine_similarity=cos_sim, + diverged=diverged, + threshold=threshold, + ) + ) + return results + + +def run_trace( + baseline_hook: InferenceHook, + target_hook: InferenceHook, + prompt: str, + baseline_url: str = "", + target_url: str = "", + hooks: list[HookPoint] | None = None, + num_layers: int = 4, + decode_steps: int = 1, + threshold: float = 1e-5, +) -> TraceResult: + """Run a full trace comparing baseline and target hooks. + + This is the synchronous version — real async version would call endpoints. + """ + if hooks is None: + hooks = [HookPoint.PREFILL, HookPoint.KV_TRANSFER, HookPoint.DECODE_STEP] + + result = TraceResult( + prompt=prompt, + baseline_url=baseline_url, + target_url=target_url, + hooks=hooks, + ) + + for hook_point in hooks: + for layer in range(num_layers): + b_capture: HookCapture | None = None + t_capture: HookCapture | None = None + + if hook_point == HookPoint.PREFILL: + b_capture = baseline_hook.on_prefill(layer) + t_capture = target_hook.on_prefill(layer) + elif hook_point == HookPoint.KV_TRANSFER: + b_capture = baseline_hook.on_kv_transfer(layer) + t_capture = target_hook.on_kv_transfer(layer) + elif hook_point == HookPoint.DECODE_STEP: + for step in range(decode_steps): + b_cap = baseline_hook.on_decode_step(layer, step) + t_cap = target_hook.on_decode_step(layer, step) + if b_cap: + result.baseline_captures.append(b_cap) + if t_cap: + result.target_captures.append(t_cap) + if b_cap and t_cap: + comps = compare_captures(b_cap, t_cap, threshold) + result.comparisons.extend(comps) + continue # decode handled in inner loop + + if b_capture: + result.baseline_captures.append(b_capture) + if t_capture: + result.target_captures.append(t_capture) + if b_capture and t_capture: + comps = compare_captures(b_capture, t_capture, threshold) + result.comparisons.extend(comps) + + # Find first divergence + for comp in result.comparisons: + if comp.diverged: + result.first_divergence = comp + result.overall_diverged = True + break + + return result + + +def format_trace(result: TraceResult) -> str: + """Format a trace result for terminal output.""" + lines: list[str] = [] + lines.append("=" * 60) + lines.append("Inference Trace Report") + lines.append("=" * 60) + lines.append(f"Prompt: {result.prompt[:80]}{'...' if len(result.prompt) > 80 else ''}") + if result.baseline_url: + lines.append(f"Baseline: {result.baseline_url}") + if result.target_url: + lines.append(f"Target: {result.target_url}") + lines.append(f"Hooks: {', '.join(h.value for h in result.hooks)}") + lines.append(f"Stages compared: {len(result.comparisons)}") + lines.append("") + + if result.errors: + lines.append("Errors:") + for e in result.errors: + lines.append(f" ❌ {e}") + lines.append("") + + verdict = "❌ DIVERGED" if result.overall_diverged else "✅ MATCH" + lines.append(f"Overall: {verdict}") + lines.append("") + + if result.first_divergence: + fd = result.first_divergence + lines.append("First divergence:") + lines.append(f" Stage: {fd.hook_point.value}") + lines.append(f" Layer: {fd.layer}") + if fd.step is not None: + lines.append(f" Step: {fd.step}") + lines.append(f" Field: {fd.field_name}") + lines.append(f" Max diff: {fd.max_abs_diff:.6e}") + lines.append(f" Cos sim: {fd.cosine_similarity:.6f}") + lines.append("") + + # Per-stage table + if result.comparisons: + lines.append("Stage Details:") + lines.append( + f" {'Stage':<14} {'Layer':>5} {'Step':>5} {'Field':<18} " + f"{'MaxDiff':>12} {'MeanDiff':>12} {'CosSim':>8} {'Status':>8}" + ) + lines.append(" " + "-" * 88) + for c in result.comparisons: + step_str = str(c.step) if c.step is not None else "-" + status = "❌ FAIL" if c.diverged else "✅ OK" + lines.append( + f" {c.hook_point.value:<14} {c.layer:>5} {step_str:>5} " + f"{c.field_name:<18} {c.max_abs_diff:>12.6e} " + f"{c.mean_abs_diff:>12.6e} {c.cosine_similarity:>8.4f} {status:>8}" + ) + + lines.append("=" * 60) + return "\n".join(lines) diff --git a/tests/test_inference_hooks.py b/tests/test_inference_hooks.py new file mode 100644 index 0000000..166f474 --- /dev/null +++ b/tests/test_inference_hooks.py @@ -0,0 +1,345 @@ +"""Tests for inference_hooks module.""" + +from __future__ import annotations + +import json + +import numpy as np + +from xpyd_acc.inference_hooks import ( + HookCapture, + HookPoint, + InferenceHook, + MockInferenceHook, + StageComparison, + TraceResult, + _cosine_sim, + compare_captures, + format_trace, + run_trace, +) + + +class TestHookPoint: + def test_enum_values(self): + assert HookPoint.PREFILL.value == "prefill" + assert HookPoint.KV_TRANSFER.value == "kv_transfer" + assert HookPoint.DECODE_STEP.value == "decode_step" + + +class TestHookCapture: + def test_to_dict_with_arrays(self): + cap = HookCapture( + hook_point=HookPoint.PREFILL, + layer=0, + hidden_states=np.zeros((4, 8), dtype=np.float32), + ) + d = cap.to_dict() + assert d["hook_point"] == "prefill" + assert d["layer"] == 0 + assert d["hidden_states"]["shape"] == [4, 8] + assert d["hidden_states"]["dtype"] == "float32" + assert d["logits"] is None + + def test_to_dict_no_arrays(self): + cap = HookCapture(hook_point=HookPoint.KV_TRANSFER, layer=2) + d = cap.to_dict() + assert d["hidden_states"] is None + assert d["kv_cache"] is None + + def test_step_field(self): + cap = HookCapture(hook_point=HookPoint.DECODE_STEP, layer=1, step=3) + assert cap.step == 3 + assert cap.to_dict()["step"] == 3 + + +class TestStageComparison: + def test_to_dict(self): + sc = StageComparison( + hook_point=HookPoint.PREFILL, + layer=0, + max_abs_diff=1e-6, + mean_abs_diff=5e-7, + cosine_similarity=0.9999, + field_name="hidden_states", + diverged=False, + ) + d = sc.to_dict() + assert d["hook_point"] == "prefill" + assert d["diverged"] is False + + def test_diverged_flag(self): + sc = StageComparison( + hook_point=HookPoint.PREFILL, layer=0, diverged=True + ) + assert sc.diverged is True + + +class TestTraceResult: + def test_to_dict_empty(self): + tr = TraceResult( + prompt="test", + baseline_url="http://a", + target_url="http://b", + hooks=[HookPoint.PREFILL], + ) + d = tr.to_dict() + assert d["prompt"] == "test" + assert d["hooks"] == ["prefill"] + assert d["comparisons"] == [] + assert d["first_divergence"] is None + assert d["overall_diverged"] is False + + def test_to_dict_with_comparisons(self): + sc = StageComparison( + hook_point=HookPoint.PREFILL, layer=0, diverged=True + ) + tr = TraceResult( + prompt="test", + baseline_url="", + target_url="", + hooks=[HookPoint.PREFILL], + comparisons=[sc], + first_divergence=sc, + overall_diverged=True, + ) + d = tr.to_dict() + assert d["overall_diverged"] is True + assert d["first_divergence"]["diverged"] is True + assert len(d["comparisons"]) == 1 + + +class TestMockInferenceHook: + def test_implements_protocol(self): + hook = MockInferenceHook() + assert isinstance(hook, InferenceHook) + + def test_prefill_returns_capture(self): + hook = MockInferenceHook(num_layers=4, hidden_dim=16, seq_len=4) + cap = hook.on_prefill(0) + assert cap is not None + assert cap.hook_point == HookPoint.PREFILL + assert cap.layer == 0 + assert cap.hidden_states is not None + assert cap.hidden_states.shape == (4, 16) + assert cap.kv_cache is not None + + def test_prefill_out_of_range(self): + hook = MockInferenceHook(num_layers=2) + assert hook.on_prefill(5) is None + + def test_kv_transfer(self): + hook = MockInferenceHook(num_layers=4, hidden_dim=16, seq_len=4) + cap = hook.on_kv_transfer(1) + assert cap is not None + assert cap.hook_point == HookPoint.KV_TRANSFER + assert cap.kv_cache is not None + + def test_decode_step(self): + hook = MockInferenceHook(num_layers=4, hidden_dim=16, seq_len=4) + cap = hook.on_decode_step(0, 0) + assert cap is not None + assert cap.hook_point == HookPoint.DECODE_STEP + assert cap.step == 0 + assert cap.logits is not None + + def test_noise_scale_zero_is_deterministic(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=0.0) + c1 = h1.on_prefill(0) + c2 = h2.on_prefill(0) + np.testing.assert_array_equal(c1.hidden_states, c2.hidden_states) + + def test_noise_scale_adds_variation(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=0.1) + c1 = h1.on_prefill(0) + c2 = h2.on_prefill(0) + # They share the same base but noise differs + assert not np.array_equal(c1.hidden_states, c2.hidden_states) + + +class TestCosineSim: + def test_identical(self): + a = np.array([1.0, 2.0, 3.0]) + assert abs(_cosine_sim(a, a) - 1.0) < 1e-10 + + def test_orthogonal(self): + a = np.array([1.0, 0.0]) + b = np.array([0.0, 1.0]) + assert abs(_cosine_sim(a, b)) < 1e-10 + + def test_zero_vector(self): + a = np.zeros(3) + b = np.array([1.0, 2.0, 3.0]) + assert _cosine_sim(a, b) == 0.0 + + +class TestCompareCaptures: + def test_identical_captures(self): + hs = np.ones((4, 8), dtype=np.float32) + c1 = HookCapture(hook_point=HookPoint.PREFILL, layer=0, hidden_states=hs) + c2 = HookCapture(hook_point=HookPoint.PREFILL, layer=0, hidden_states=hs.copy()) + results = compare_captures(c1, c2) + assert len(results) == 1 + assert not results[0].diverged + assert results[0].max_abs_diff == 0.0 + + def test_divergent_captures(self): + hs1 = np.zeros((4, 8), dtype=np.float32) + hs2 = np.ones((4, 8), dtype=np.float32) + c1 = HookCapture(hook_point=HookPoint.PREFILL, layer=0, hidden_states=hs1) + c2 = HookCapture(hook_point=HookPoint.PREFILL, layer=0, hidden_states=hs2) + results = compare_captures(c1, c2, threshold=0.5) + assert len(results) == 1 + assert results[0].diverged + assert results[0].max_abs_diff == 1.0 + + def test_shape_mismatch(self): + c1 = HookCapture( + hook_point=HookPoint.PREFILL, layer=0, + hidden_states=np.zeros((4, 8), dtype=np.float32), + ) + c2 = HookCapture( + hook_point=HookPoint.PREFILL, layer=0, + hidden_states=np.zeros((4, 16), dtype=np.float32), + ) + results = compare_captures(c1, c2) + assert len(results) == 1 + assert results[0].diverged + assert results[0].max_abs_diff == float("inf") + + def test_multiple_fields(self): + hs = np.ones((4, 8), dtype=np.float32) + kv = np.ones((2, 4, 8), dtype=np.float32) + c1 = HookCapture(hook_point=HookPoint.PREFILL, layer=0, hidden_states=hs, kv_cache=kv) + c2 = HookCapture( + hook_point=HookPoint.PREFILL, layer=0, + hidden_states=hs.copy(), kv_cache=kv.copy(), + ) + results = compare_captures(c1, c2) + assert len(results) == 2 # hidden_states + kv_cache + + def test_none_fields_skipped(self): + c1 = HookCapture(hook_point=HookPoint.PREFILL, layer=0) + c2 = HookCapture(hook_point=HookPoint.PREFILL, layer=0) + results = compare_captures(c1, c2) + assert len(results) == 0 + + +class TestRunTrace: + def test_no_divergence(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=0.0) + result = run_trace(h1, h2, prompt="test", num_layers=2, decode_steps=1) + assert not result.overall_diverged + assert result.first_divergence is None + assert len(result.comparisons) > 0 + + def test_with_divergence(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=1.0) + result = run_trace(h1, h2, prompt="test", num_layers=2, threshold=0.01) + assert result.overall_diverged + assert result.first_divergence is not None + + def test_selective_hooks(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=0.0) + result = run_trace( + h1, h2, prompt="test", + hooks=[HookPoint.PREFILL], + num_layers=2, + ) + # Only prefill comparisons + for c in result.comparisons: + assert c.hook_point == HookPoint.PREFILL + + def test_captures_stored(self): + h1 = MockInferenceHook(seed=42, num_layers=2) + h2 = MockInferenceHook(seed=42, num_layers=2) + result = run_trace(h1, h2, prompt="test", num_layers=2, decode_steps=1) + assert len(result.baseline_captures) > 0 + assert len(result.target_captures) > 0 + + def test_urls_stored(self): + h1 = MockInferenceHook() + h2 = MockInferenceHook() + result = run_trace( + h1, h2, prompt="test", + baseline_url="http://base", target_url="http://target", + ) + assert result.baseline_url == "http://base" + assert result.target_url == "http://target" + + +class TestFormatTrace: + def test_format_match(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=0.0) + result = run_trace(h1, h2, prompt="Hello world", num_layers=2) + output = format_trace(result) + assert "MATCH" in output + assert "Hello world" in output + assert "Stage Details:" in output + + def test_format_diverged(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0) + h2 = MockInferenceHook(seed=42, noise_scale=1.0) + result = run_trace(h1, h2, prompt="test", num_layers=2, threshold=0.01) + output = format_trace(result) + assert "DIVERGED" in output + assert "First divergence:" in output + + def test_format_with_errors(self): + result = TraceResult( + prompt="test", + baseline_url="", + target_url="", + hooks=[], + errors=["Connection failed"], + ) + output = format_trace(result) + assert "Connection failed" in output + + def test_format_long_prompt_truncated(self): + long_prompt = "x" * 200 + result = TraceResult( + prompt=long_prompt, + baseline_url="", + target_url="", + hooks=[], + ) + output = format_trace(result) + assert "..." in output + + +class TestJsonExport: + def test_trace_result_json_serializable(self): + h1 = MockInferenceHook(seed=42, noise_scale=0.0, num_layers=2) + h2 = MockInferenceHook(seed=42, noise_scale=0.1, num_layers=2) + result = run_trace(h1, h2, prompt="test", num_layers=2) + d = result.to_dict() + # Should be JSON-serializable + s = json.dumps(d) + loaded = json.loads(s) + assert loaded["prompt"] == "test" + assert isinstance(loaded["comparisons"], list) + + +class TestCLIIntegration: + def test_trace_parser_registered(self): + """Verify trace subcommand is registered in CLI parsers.""" + import argparse + + from xpyd_acc.cli.parsers import register_all + + parser = argparse.ArgumentParser() + sub = parser.add_subparsers(dest="command") + register_all(sub) + args = parser.parse_args([ + "trace", "--baseline", "http://a", "--target", "http://b", + "--prompt", "test", + ]) + assert args.command == "trace" + assert args.baseline == "http://a"