diff --git a/ROADMAP.md b/ROADMAP.md index 8834166..1fd88dc 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -828,7 +828,7 @@ below shift focus to **framework-level deep diagnostics** — capabilities that beyond API-level black-box comparison and provide the kind of insight you can't get from a Python script calling `/v1/chat/completions`. -## M87: Automatic KV Cache Export from vLLM +## M87: Automatic KV Cache Export from vLLM ✅ - The biggest gap in the current toolchain: `check-kv` requires pre-existing `.npz` dumps, but extracting KV cache from a running vLLM instance is the hardest part of the workflow - Provide a vLLM plugin / monkey-patch that intercepts the KV cache at configurable points: - After prefill completes (before KV transfer) @@ -854,7 +854,7 @@ get from a Python script calling `/v1/chat/completions`. - Initial target: vLLM (most common PD disaggregation framework) - Stretch: SGLang, TensorRT-LLM -## M89: PD Topology-Aware Testing +## M89: PD Topology-Aware Testing ✅ - Current tools treat the endpoint as a black box — send request, get response - In real PD deployments behind xPyD-proxy, there are multiple prefill and decode nodes - Topology-aware mode: @@ -878,3 +878,16 @@ get from a Python script calling `/v1/chat/completions`. 
- `xpyd-acc diagnose --hw-profile a100-bf16-tp4` uses the matching baseline for comparison - Transforms raw numbers into actionable verdicts: "expected hardware variance" vs "likely software bug" - Community-contributed baselines: users can submit anonymized precision profiles to build the database + +## M91: Smart Retry for Divergent Samples +- After initial batch comparison, automatically retry divergent samples with deterministic settings (temperature=0, seed=42) +- Classifies divergence as: `deterministic` (reproduces under greedy decoding) or `stochastic` (disappears with greedy) +- `batch-compare --smart-retry` flag triggers automatic rerun of divergent samples +- `xpyd-acc smart-retry --report <report.json> --baseline <url> --target <url>` standalone command +- `SmartRetryResult` dataclass: original_divergent, deterministic_count, stochastic_count, results per sample +- Stochastic divergences downgraded in severity (likely sampling noise, not a bug) +- JSON export with per-sample retry details +- Rich terminal output with deterministic vs stochastic breakdown +- Integrates with existing `--fail-threshold`: only deterministic divergences count toward threshold +- `smart_retry.py` module: `run_smart_retry()`, `SmartRetryResult`, `format_smart_retry()` +- Tests covering retry logic, classification, integration, JSON export, CLI diff --git a/bot/iterations/current.md b/bot/iterations/current.md index bf3710e..7beaf2b 100644 --- a/bot/iterations/current.md +++ b/bot/iterations/current.md @@ -54,4 +54,5 @@ shell for exploratory comparison of two endpoints. 
| M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ✅ merged | Both approved | | M88 | 2026-04-06 | Framework-Level Inference Hooks | ✅ merged | Both approved | | M89 | 2026-04-06 | PD Topology-Aware Testing | ✅ merged | Both approved | -| M90 | 2026-04-06 | Hardware Precision Baseline Library | ⏳ pending review | — | +| M90 | 2026-04-06 | Hardware Precision Baseline Library | ✅ merged | Both approved | +| M91 | 2026-04-06 | Smart Retry for Divergent Samples | ⏳ pending review | — | diff --git a/src/xpyd_acc/cli/__init__.py b/src/xpyd_acc/cli/__init__.py index 2057105..a0203f6 100644 --- a/src/xpyd_acc/cli/__init__.py +++ b/src/xpyd_acc/cli/__init__.py @@ -18,6 +18,7 @@ _run_latency_regression, _run_length_bias, _run_sensitivity, + _run_smart_retry, _run_watch, handle_baseline_db, handle_capture_kv, @@ -147,6 +148,7 @@ def main(argv: list[str] | None = None) -> None: "compare-files": lambda: _run_file_compare(args), "topology-scan": lambda: handle_topology_scan(args), "baseline-db": lambda: handle_baseline_db(args), + "smart-retry": lambda: _run_smart_retry(args), } if args.command in _early: diff --git a/src/xpyd_acc/cli/analysis.py b/src/xpyd_acc/cli/analysis.py index a0fc211..baede3c 100644 --- a/src/xpyd_acc/cli/analysis.py +++ b/src/xpyd_acc/cli/analysis.py @@ -596,3 +596,36 @@ def handle_baseline_db(args: argparse.Namespace) -> None: file=sys.stderr, ) raise SystemExit(1) + + +def _run_smart_retry(args: argparse.Namespace) -> None: + """Handle the smart-retry subcommand.""" + from pathlib import Path + + from xpyd_acc.batch_compare import load_report + from xpyd_acc.smart_retry import format_smart_retry, run_smart_retry + + report = load_report(args.report) + result = asyncio.run( + run_smart_retry( + report, + args.baseline, + args.target, + model=args.model, + max_tokens=args.max_tokens, + api_key=args.api_key, + retries=args.retries, + retry_delay=args.retry_delay, + timeout=args.timeout, + skip_validation=args.skip_validation, + ) + ) + + 
print(format_smart_retry(result)) + + if args.json_path: + Path(args.json_path).write_text(result.to_json() + "\n") + print(f"\nResults exported to {args.json_path}") + + if result.deterministic_count > 0: + raise SystemExit(1) diff --git a/src/xpyd_acc/cli/parsers.py b/src/xpyd_acc/cli/parsers.py index dc4414e..57bdd54 100644 --- a/src/xpyd_acc/cli/parsers.py +++ b/src/xpyd_acc/cli/parsers.py @@ -57,6 +57,7 @@ def register_all(sub: argparse._SubParsersAction) -> None: _register_trace(sub) _register_topology_scan(sub) _register_baseline_db(sub) + _register_smart_retry(sub) def _register_compare(sub): lp = sub.add_parser("compare-logprobs", help="Compare logprobs between two endpoints") lp.add_argument("--baseline", required=True, help="Baseline endpoint URL") @@ -799,3 +800,18 @@ def _register_baseline_db(sub): help="Observed cosine similarity", ) classify_p.add_argument("--json", default=None, help="Export classification as JSON") + + +def _register_smart_retry(sub): + p = sub.add_parser("smart-retry", help="Retry divergent samples with greedy settings") + p.add_argument("--report", required=True, help="Path to batch report JSON") + p.add_argument("--baseline", required=True, help="Baseline endpoint URL") + p.add_argument("--target", required=True, help="Target endpoint URL") + p.add_argument("--model", default="default", help="Model identifier") + p.add_argument("--max-tokens", type=int, default=64, help="Max tokens per request") + p.add_argument("--api-key", default="no-key", help="API key") + p.add_argument("--retries", type=int, default=3, help="HTTP retry count") + p.add_argument("--retry-delay", type=float, default=1.0, help="Base retry delay") + p.add_argument("--timeout", type=float, default=120.0, help="HTTP timeout seconds") + p.add_argument("--json", default=None, dest="json_path", help="Export results as JSON") + p.add_argument("--skip-validation", action="store_true", help="Skip response validation") diff --git a/src/xpyd_acc/smart_retry.py 
b/src/xpyd_acc/smart_retry.py new file mode 100644 index 0000000..8821118 --- /dev/null +++ b/src/xpyd_acc/smart_retry.py @@ -0,0 +1,213 @@ +"""Smart retry for divergent samples — classify as deterministic vs stochastic.""" + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from typing import Any, Callable + +from .batch_compare import ( + BatchReport, + DatasetSample, + SampleResult, + run_batch, +) +from .log import get_logger + +logger = get_logger(__name__) + + +@dataclass +class SampleRetryResult: + """Result of retrying a single divergent sample with greedy settings.""" + + sample_id: str + original_classification: str + retry_match: bool + retry_classification: str # "deterministic" or "stochastic" + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict.""" + return asdict(self) + + +@dataclass +class SmartRetryResult: + """Aggregate result of smart retry for all divergent samples.""" + + original_divergent: int + deterministic_count: int + stochastic_count: int + deterministic_rate: float + stochastic_rate: float + per_sample: list[SampleRetryResult] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict.""" + return { + "original_divergent": self.original_divergent, + "deterministic_count": self.deterministic_count, + "stochastic_count": self.stochastic_count, + "deterministic_rate": self.deterministic_rate, + "stochastic_rate": self.stochastic_rate, + "per_sample": [s.to_dict() for s in self.per_sample], + } + + def to_json(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + +async def run_smart_retry( + report: BatchReport, + baseline_url: str, + target_url: str, + *, + model: str = "default", + max_tokens: int = 64, + api_key: str = "no-key", + retries: int = 3, + retry_delay: float = 1.0, + timeout: float = 120.0, + skip_validation: bool = False, + custom_headers: dict[str, str] | None = None, + on_progress: 
Callable[[int, int], None] | None = None, +) -> SmartRetryResult: + """Retry divergent samples with deterministic settings (temperature=0, seed=42). + + Classifies each divergence as: + - ``deterministic``: still diverges under greedy decoding → likely a real bug + - ``stochastic``: matches under greedy decoding → likely sampling noise + + Args: + report: The original batch report containing divergent results. + baseline_url: Baseline endpoint URL. + target_url: Target endpoint URL. + model: Model identifier. + max_tokens: Maximum tokens per request. + api_key: API key for authentication. + retries: Number of HTTP retries. + retry_delay: Base delay between retries. + timeout: HTTP request timeout in seconds. + skip_validation: Skip response schema validation. + custom_headers: Optional custom HTTP headers. + on_progress: Progress callback (completed, total). + + Returns: + SmartRetryResult with per-sample classification. + """ + divergent = [r for r in report.results if r.is_divergent()] + if not divergent: + return SmartRetryResult( + original_divergent=0, + deterministic_count=0, + stochastic_count=0, + deterministic_rate=0.0, + stochastic_rate=0.0, + ) + + # Build dataset samples from divergent results + samples = [ + DatasetSample(id=r.sample_id, prompt=r.prompt) + for r in divergent + ] + + logger.info( + "Smart retry: re-running %d divergent samples with greedy settings", + len(samples), + ) + + # Use a simple SamplingParams-like object for greedy decoding + from dataclasses import dataclass as _dc + + @_dc + class _GreedyParams: + temperature: float = 0.0 + top_p: float | None = None + seed: int = 42 + + retry_report = await run_batch( + samples, + baseline_url, + target_url, + model=model, + max_tokens=max_tokens, + api_key=api_key, + retries=retries, + retry_delay=retry_delay, + timeout=timeout, + skip_validation=skip_validation, + custom_headers=custom_headers, + on_progress=on_progress, + sampling_params=_GreedyParams(), + concurrency=5, + ) + + # Build 
result map + retry_map: dict[str, SampleResult] = { + r.sample_id: r for r in retry_report.results + } + + per_sample: list[SampleRetryResult] = [] + deterministic = 0 + stochastic = 0 + + for orig in divergent: + retry_result = retry_map.get(orig.sample_id) + if retry_result is None: + # Should not happen, but treat as deterministic (safe default) + classification = "deterministic" + retry_match = False + else: + retry_match = retry_result.exact_match + classification = "stochastic" if retry_match else "deterministic" + + if classification == "deterministic": + deterministic += 1 + else: + stochastic += 1 + + per_sample.append( + SampleRetryResult( + sample_id=orig.sample_id, + original_classification=orig.classification, + retry_match=retry_match, + retry_classification=classification, + ) + ) + + total = len(divergent) + return SmartRetryResult( + original_divergent=total, + deterministic_count=deterministic, + stochastic_count=stochastic, + deterministic_rate=deterministic / total if total else 0.0, + stochastic_rate=stochastic / total if total else 0.0, + per_sample=per_sample, + ) + + +def format_smart_retry(result: SmartRetryResult) -> str: + """Format smart retry result for terminal display.""" + lines = [ + "Smart Retry Results", + "=" * 40, + f"Original divergent samples: {result.original_divergent}", + f"Deterministic (real bugs): {result.deterministic_count}" + f" ({result.deterministic_rate:.1%})", + f"Stochastic (sampling noise): {result.stochastic_count}" + f" ({result.stochastic_rate:.1%})", + "", + ] + + if result.per_sample: + lines.append("Per-Sample Breakdown:") + lines.append(f"{'Sample ID':<30} {'Original':<20} {'Retry':<15}") + lines.append("-" * 65) + for s in result.per_sample: + lines.append( + f"{s.sample_id:<30} {s.original_classification:<20} " + f"{s.retry_classification:<15}" + ) + + return "\n".join(lines) diff --git a/tests/test_smart_retry.py b/tests/test_smart_retry.py new file mode 100644 index 0000000..cab626e --- /dev/null +++ 
b/tests/test_smart_retry.py @@ -0,0 +1,287 @@ +"""Tests for smart_retry module.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import patch + +from xpyd_acc.batch_compare import BatchReport, SampleResult +from xpyd_acc.smart_retry import ( + SampleRetryResult, + SmartRetryResult, + format_smart_retry, + run_smart_retry, +) + + +def _make_result( + sample_id: str, + prompt: str = "test", + match: bool = True, + classification: str = "match", + divergence_index: int | None = None, + logprob_gap: float | None = None, +) -> SampleResult: + return SampleResult( + sample_id=sample_id, + prompt=prompt, + baseline_output="hello", + target_output="hello" if match else "world", + exact_match=match, + first_divergence_index=divergence_index, + baseline_logprob_at_divergence=None, + target_logprob_at_divergence=None, + logprob_gap=logprob_gap, + classification=classification, + context_length=10, + ) + + +def _make_report(results: list[SampleResult]) -> BatchReport: + divergent = [r for r in results if not r.exact_match] + total = len(results) + return BatchReport( + total_samples=total, + divergent_samples=len(divergent), + match_samples=total - len(divergent), + divergence_rate=len(divergent) / total if total else 0.0, + results=results, + ) + + +# --- SampleRetryResult tests --- + + +def test_sample_retry_result_to_dict(): + r = SampleRetryResult( + sample_id="s1", + original_classification="likely_bug", + retry_match=False, + retry_classification="deterministic", + ) + d = r.to_dict() + assert d["sample_id"] == "s1" + assert d["retry_classification"] == "deterministic" + assert d["retry_match"] is False + + +# --- SmartRetryResult tests --- + + +def test_smart_retry_result_to_dict(): + r = SmartRetryResult( + original_divergent=3, + deterministic_count=2, + stochastic_count=1, + deterministic_rate=2 / 3, + stochastic_rate=1 / 3, + per_sample=[ + SampleRetryResult("s1", "likely_bug", False, "deterministic"), + 
SampleRetryResult("s2", "likely_bug", False, "deterministic"), + SampleRetryResult("s3", "likely_uncertainty", True, "stochastic"), + ], + ) + d = r.to_dict() + assert d["original_divergent"] == 3 + assert d["deterministic_count"] == 2 + assert d["stochastic_count"] == 1 + assert len(d["per_sample"]) == 3 + + +def test_smart_retry_result_to_json(): + r = SmartRetryResult( + original_divergent=1, + deterministic_count=0, + stochastic_count=1, + deterministic_rate=0.0, + stochastic_rate=1.0, + per_sample=[ + SampleRetryResult("s1", "likely_bug", True, "stochastic"), + ], + ) + j = json.loads(r.to_json()) + assert j["stochastic_count"] == 1 + + +# --- run_smart_retry tests --- + + +def test_run_smart_retry_no_divergent(): + """No divergent samples → empty result.""" + report = _make_report([_make_result("s1", match=True)]) + result = asyncio.run( + run_smart_retry(report, "http://base", "http://target") + ) + assert result.original_divergent == 0 + assert result.deterministic_count == 0 + assert result.stochastic_count == 0 + + +@patch("xpyd_acc.smart_retry.run_batch") +def test_run_smart_retry_all_deterministic(mock_run_batch): + """All divergent samples still diverge under greedy → deterministic.""" + original = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + _make_result("s2", match=False, classification="likely_bug"), + ]) + + # Retry also diverges + retry_report = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + _make_result("s2", match=False, classification="likely_bug"), + ]) + mock_run_batch.return_value = retry_report + + result = asyncio.run( + run_smart_retry(original, "http://base", "http://target") + ) + assert result.original_divergent == 2 + assert result.deterministic_count == 2 + assert result.stochastic_count == 0 + assert result.deterministic_rate == 1.0 + + +@patch("xpyd_acc.smart_retry.run_batch") +def test_run_smart_retry_all_stochastic(mock_run_batch): + """All divergent samples 
match under greedy → stochastic.""" + original = _make_report([ + _make_result("s1", match=False, classification="likely_uncertainty"), + ]) + + retry_report = _make_report([ + _make_result("s1", match=True, classification="match"), + ]) + mock_run_batch.return_value = retry_report + + result = asyncio.run( + run_smart_retry(original, "http://base", "http://target") + ) + assert result.original_divergent == 1 + assert result.stochastic_count == 1 + assert result.deterministic_count == 0 + assert result.stochastic_rate == 1.0 + + +@patch("xpyd_acc.smart_retry.run_batch") +def test_run_smart_retry_mixed(mock_run_batch): + """Mix of deterministic and stochastic.""" + original = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + _make_result("s2", match=False, classification="likely_uncertainty"), + _make_result("s3", match=True), + ]) + + retry_report = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + _make_result("s2", match=True, classification="match"), + ]) + mock_run_batch.return_value = retry_report + + result = asyncio.run( + run_smart_retry(original, "http://base", "http://target") + ) + assert result.original_divergent == 2 + assert result.deterministic_count == 1 + assert result.stochastic_count == 1 + assert result.deterministic_rate == 0.5 + assert result.stochastic_rate == 0.5 + assert len(result.per_sample) == 2 + + +@patch("xpyd_acc.smart_retry.run_batch") +def test_run_smart_retry_greedy_params(mock_run_batch): + """Verify greedy sampling params are passed to run_batch.""" + original = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + ]) + retry_report = _make_report([ + _make_result("s1", match=False), + ]) + mock_run_batch.return_value = retry_report + + asyncio.run( + run_smart_retry(original, "http://base", "http://target") + ) + + call_kwargs = mock_run_batch.call_args[1] + sp = call_kwargs["sampling_params"] + assert sp.temperature == 0.0 + assert sp.seed 
== 42 + + +@patch("xpyd_acc.smart_retry.run_batch") +def test_run_smart_retry_progress_callback(mock_run_batch): + """Progress callback is forwarded to run_batch.""" + original = _make_report([ + _make_result("s1", match=False, classification="likely_bug"), + ]) + retry_report = _make_report([ + _make_result("s1", match=False), + ]) + mock_run_batch.return_value = retry_report + + progress_calls = [] + asyncio.run( + run_smart_retry( + original, "http://base", "http://target", + on_progress=lambda c, t: progress_calls.append((c, t)), + ) + ) + assert mock_run_batch.call_args[1]["on_progress"] is not None + + +# --- format_smart_retry tests --- + + +def test_format_smart_retry_basic(): + result = SmartRetryResult( + original_divergent=2, + deterministic_count=1, + stochastic_count=1, + deterministic_rate=0.5, + stochastic_rate=0.5, + per_sample=[ + SampleRetryResult("s1", "likely_bug", False, "deterministic"), + SampleRetryResult("s2", "likely_uncertainty", True, "stochastic"), + ], + ) + text = format_smart_retry(result) + assert "Smart Retry Results" in text + assert "Deterministic (real bugs)" in text + assert "Stochastic (sampling noise)" in text + assert "s1" in text + assert "s2" in text + assert "deterministic" in text + assert "stochastic" in text + + +def test_format_smart_retry_empty(): + result = SmartRetryResult( + original_divergent=0, + deterministic_count=0, + stochastic_count=0, + deterministic_rate=0.0, + stochastic_rate=0.0, + ) + text = format_smart_retry(result) + assert "Original divergent samples: 0" in text + + +def test_smart_retry_result_json_roundtrip(): + """JSON serialization round-trip.""" + r = SmartRetryResult( + original_divergent=2, + deterministic_count=1, + stochastic_count=1, + deterministic_rate=0.5, + stochastic_rate=0.5, + per_sample=[ + SampleRetryResult("s1", "likely_bug", False, "deterministic"), + SampleRetryResult("s2", "likely_uncertainty", True, "stochastic"), + ], + ) + j = json.loads(r.to_json()) + assert 
j["original_divergent"] == 2 + assert j["per_sample"][0]["retry_classification"] == "deterministic" + assert j["per_sample"][1]["retry_classification"] == "stochastic"