Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ below shift focus to **framework-level deep diagnostics** — capabilities that
beyond API-level black-box comparison and provide the kind of insight you can't
get from a Python script calling `/v1/chat/completions`.

## M87: Automatic KV Cache Export from vLLM
## M87: Automatic KV Cache Export from vLLM
- The biggest gap in the current toolchain: `check-kv` requires pre-existing `.npz` dumps, but extracting KV cache from a running vLLM instance is the hardest part of the workflow
- Provide a vLLM plugin / monkey-patch that intercepts the KV cache at configurable points:
- After prefill completes (before KV transfer)
Expand All @@ -854,7 +854,7 @@ get from a Python script calling `/v1/chat/completions`.
- Initial target: vLLM (most common PD disaggregation framework)
- Stretch: SGLang, TensorRT-LLM

## M89: PD Topology-Aware Testing
## M89: PD Topology-Aware Testing
- Current tools treat the endpoint as a black box — send request, get response
- In real PD deployments behind xPyD-proxy, there are multiple prefill and decode nodes
- Topology-aware mode:
Expand All @@ -878,3 +878,16 @@ get from a Python script calling `/v1/chat/completions`.
- `xpyd-acc diagnose --hw-profile a100-bf16-tp4` uses the matching baseline for comparison
- Transforms raw numbers into actionable verdicts: "expected hardware variance" vs "likely software bug"
- Community-contributed baselines: users can submit anonymized precision profiles to build the database

## M91: Smart Retry for Divergent Samples
- After initial batch comparison, automatically retry divergent samples with deterministic settings (temperature=0, seed=42)
- Classifies each divergence as `deterministic` (still reproduces under greedy decoding, so likely a real bug) or `stochastic` (disappears under greedy decoding, so likely sampling noise)
- `batch-compare --smart-retry` flag triggers automatic rerun of divergent samples
- `xpyd-acc smart-retry --report <path> --baseline <url> --target <url>` standalone command
- `SmartRetryResult` dataclass: original_divergent, deterministic_count, stochastic_count, results per sample
- Stochastic divergences downgraded in severity (likely sampling noise, not a bug)
- JSON export with per-sample retry details
- Rich terminal output with deterministic vs stochastic breakdown
- Integrates with existing `--fail-threshold`: only deterministic divergences count toward threshold
- `smart_retry.py` module: `run_smart_retry()`, `SmartRetryResult`, `format_smart_retry()`
- Tests covering retry logic, classification, integration, JSON export, CLI
3 changes: 2 additions & 1 deletion bot/iterations/current.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ shell for exploratory comparison of two endpoints.
| M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ✅ merged | Both approved |
| M88 | 2026-04-06 | Framework-Level Inference Hooks | ✅ merged | Both approved |
| M89 | 2026-04-06 | PD Topology-Aware Testing | ✅ merged | Both approved |
| M90 | 2026-04-06 | Hardware Precision Baseline Library | ⏳ pending review | — |
| M90 | 2026-04-06 | Hardware Precision Baseline Library | ✅ merged | Both approved |
| M91 | 2026-04-06 | Smart Retry for Divergent Samples | ⏳ pending review | — |
2 changes: 2 additions & 0 deletions src/xpyd_acc/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_run_latency_regression,
_run_length_bias,
_run_sensitivity,
_run_smart_retry,
_run_watch,
handle_baseline_db,
handle_capture_kv,
Expand Down Expand Up @@ -147,6 +148,7 @@ def main(argv: list[str] | None = None) -> None:
"compare-files": lambda: _run_file_compare(args),
"topology-scan": lambda: handle_topology_scan(args),
"baseline-db": lambda: handle_baseline_db(args),
"smart-retry": lambda: _run_smart_retry(args),
}

if args.command in _early:
Expand Down
33 changes: 33 additions & 0 deletions src/xpyd_acc/cli/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,36 @@ def handle_baseline_db(args: argparse.Namespace) -> None:
file=sys.stderr,
)
raise SystemExit(1)


def _run_smart_retry(args: argparse.Namespace) -> None:
    """Run the ``smart-retry`` subcommand from parsed CLI arguments.

    Loads the batch report at ``args.report``, passes it to
    ``run_smart_retry`` together with the endpoint and HTTP options taken
    from ``args``, prints the formatted summary, and optionally writes the
    JSON form of the result to ``args.json_path``.

    Raises:
        SystemExit: With code 1 when ``result.deterministic_count`` is
            greater than zero.
    """
    from pathlib import Path

    from xpyd_acc.batch_compare import load_report
    from xpyd_acc.smart_retry import format_smart_retry, run_smart_retry

    original_report = load_report(args.report)

    # The coroutine is created first and then driven to completion on a
    # fresh event loop by asyncio.run().
    retry_coro = run_smart_retry(
        original_report,
        args.baseline,
        args.target,
        model=args.model,
        max_tokens=args.max_tokens,
        api_key=args.api_key,
        retries=args.retries,
        retry_delay=args.retry_delay,
        timeout=args.timeout,
        skip_validation=args.skip_validation,
    )
    outcome = asyncio.run(retry_coro)

    summary_text = format_smart_retry(outcome)
    print(summary_text)

    export_target = args.json_path
    if export_target:
        serialized = outcome.to_json() + "\n"
        Path(export_target).write_text(serialized)
        print(f"\nResults exported to {args.json_path}")

    # A non-zero exit status signals that at least one divergence was
    # classified as deterministic.
    if outcome.deterministic_count > 0:
        raise SystemExit(1)
16 changes: 16 additions & 0 deletions src/xpyd_acc/cli/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def register_all(sub: argparse._SubParsersAction) -> None:
_register_trace(sub)
_register_topology_scan(sub)
_register_baseline_db(sub)
_register_smart_retry(sub)
def _register_compare(sub):
lp = sub.add_parser("compare-logprobs", help="Compare logprobs between two endpoints")
lp.add_argument("--baseline", required=True, help="Baseline endpoint URL")
Expand Down Expand Up @@ -799,3 +800,18 @@ def _register_baseline_db(sub):
help="Observed cosine similarity",
)
classify_p.add_argument("--json", default=None, help="Export classification as JSON")


def _register_smart_retry(sub):
    """Attach the ``smart-retry`` subcommand and its options to *sub*.

    Args:
        sub: The sub-parsers action on which ``add_parser`` is called.
    """
    retry_parser = sub.add_parser(
        "smart-retry",
        help="Retry divergent samples with greedy settings",
    )

    # (flag, add_argument keyword arguments), kept in registration order so
    # the generated --help output lists the options in the same sequence.
    option_table = (
        ("--report", {"required": True, "help": "Path to batch report JSON"}),
        ("--baseline", {"required": True, "help": "Baseline endpoint URL"}),
        ("--target", {"required": True, "help": "Target endpoint URL"}),
        ("--model", {"default": "default", "help": "Model identifier"}),
        ("--max-tokens", {"type": int, "default": 64, "help": "Max tokens per request"}),
        ("--api-key", {"default": "no-key", "help": "API key"}),
        ("--retries", {"type": int, "default": 3, "help": "HTTP retry count"}),
        ("--retry-delay", {"type": float, "default": 1.0, "help": "Base retry delay"}),
        ("--timeout", {"type": float, "default": 120.0, "help": "HTTP timeout seconds"}),
        (
            "--json",
            {"default": None, "dest": "json_path", "help": "Export results as JSON"},
        ),
        (
            "--skip-validation",
            {"action": "store_true", "help": "Skip response validation"},
        ),
    )
    for flag, settings in option_table:
        retry_parser.add_argument(flag, **settings)
213 changes: 213 additions & 0 deletions src/xpyd_acc/smart_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Smart retry for divergent samples — classify as deterministic vs stochastic."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from typing import Any, Callable

from .batch_compare import (
BatchReport,
DatasetSample,
SampleResult,
run_batch,
)
from .log import get_logger

logger = get_logger(__name__)


@dataclass
class SampleRetryResult:
    """Outcome of re-running one divergent sample with greedy settings."""

    # Identifier of the sample that was retried.
    sample_id: str
    # Classification the sample carried in the original report.
    original_classification: str
    # Whether the retry produced an exact match.
    retry_match: bool
    # Either "deterministic" or "stochastic".
    retry_classification: str

    def to_dict(self) -> dict[str, Any]:
        """Return the fields of this record as a plain dictionary."""
        as_mapping: dict[str, Any] = asdict(self)
        return as_mapping


@dataclass
class SmartRetryResult:
    """Summary of a smart-retry pass over every divergent sample."""

    # Number of divergent samples found in the original report.
    original_divergent: int
    # How many of them were classified "deterministic" after the retry.
    deterministic_count: int
    # How many of them were classified "stochastic" after the retry.
    stochastic_count: int
    # The two counts above expressed as fractions of original_divergent.
    deterministic_rate: float
    stochastic_rate: float
    # One entry per retried sample; empty by default.
    per_sample: list[SampleRetryResult] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the summary as a dictionary, expanding each sample entry."""
        scalar_fields = (
            "original_divergent",
            "deterministic_count",
            "stochastic_count",
            "deterministic_rate",
            "stochastic_rate",
        )
        summary: dict[str, Any] = {
            name: getattr(self, name) for name in scalar_fields
        }
        # Each entry serializes itself; this method only collects the results.
        summary["per_sample"] = [entry.to_dict() for entry in self.per_sample]
        return summary

    def to_json(self) -> str:
        """Return the dictionary form rendered as two-space-indented JSON."""
        payload = self.to_dict()
        return json.dumps(payload, indent=2)


@dataclass
class _GreedyParams:
    """Minimal sampling-parameters stand-in that requests greedy decoding.

    An instance is handed to ``run_batch`` as ``sampling_params``; it exposes
    only the three attributes below.
    """

    temperature: float = 0.0
    top_p: float | None = None
    seed: int = 42


async def run_smart_retry(
    report: BatchReport,
    baseline_url: str,
    target_url: str,
    *,
    model: str = "default",
    max_tokens: int = 64,
    api_key: str = "no-key",
    retries: int = 3,
    retry_delay: float = 1.0,
    timeout: float = 120.0,
    skip_validation: bool = False,
    custom_headers: dict[str, str] | None = None,
    on_progress: Callable[[int, int], None] | None = None,
    concurrency: int = 5,
    seed: int = 42,
) -> SmartRetryResult:
    """Retry divergent samples with deterministic settings (temperature=0, fixed seed).

    Classifies each divergence as:
    - ``deterministic``: still diverges under greedy decoding → likely a real bug
    - ``stochastic``: matches under greedy decoding → likely sampling noise

    Args:
        report: The original batch report containing divergent results.
        baseline_url: Baseline endpoint URL.
        target_url: Target endpoint URL.
        model: Model identifier.
        max_tokens: Maximum tokens per request.
        api_key: API key for authentication.
        retries: Number of HTTP retries.
        retry_delay: Base delay between retries.
        timeout: HTTP request timeout in seconds.
        skip_validation: Skip response schema validation.
        custom_headers: Optional custom HTTP headers.
        on_progress: Progress callback (completed, total).
        concurrency: Value forwarded to ``run_batch`` as ``concurrency``.
            Defaults to 5, the value that was previously hard-coded.
        seed: Seed placed in the greedy sampling parameters. Defaults to 42,
            the value that was previously hard-coded.

    Returns:
        SmartRetryResult with per-sample classification. When the report
        contains no divergent results, all counts and rates are zero and
        no requests are issued.
    """
    divergent = [r for r in report.results if r.is_divergent()]
    if not divergent:
        return SmartRetryResult(
            original_divergent=0,
            deterministic_count=0,
            stochastic_count=0,
            deterministic_rate=0.0,
            stochastic_rate=0.0,
        )

    # Rebuild the dataset from the divergent results only.
    samples = [
        DatasetSample(id=r.sample_id, prompt=r.prompt)
        for r in divergent
    ]

    logger.info(
        "Smart retry: re-running %d divergent samples with greedy settings",
        len(samples),
    )

    retry_report = await run_batch(
        samples,
        baseline_url,
        target_url,
        model=model,
        max_tokens=max_tokens,
        api_key=api_key,
        retries=retries,
        retry_delay=retry_delay,
        timeout=timeout,
        skip_validation=skip_validation,
        custom_headers=custom_headers,
        on_progress=on_progress,
        sampling_params=_GreedyParams(seed=seed),
        concurrency=concurrency,
    )

    # Index retry results by sample id for lookup in the loop below.
    retry_map: dict[str, SampleResult] = {
        r.sample_id: r for r in retry_report.results
    }

    per_sample: list[SampleRetryResult] = []
    for orig in divergent:
        retry_result = retry_map.get(orig.sample_id)
        if retry_result is None:
            # Not expected, but if the retry produced nothing for this sample
            # keep the conservative "deterministic" verdict and say so rather
            # than hiding the gap.
            logger.warning(
                "Smart retry: no retry result for sample %s; "
                "treating as deterministic",
                orig.sample_id,
            )
            retry_match = False
        else:
            retry_match = retry_result.exact_match

        per_sample.append(
            SampleRetryResult(
                sample_id=orig.sample_id,
                original_classification=orig.classification,
                retry_match=retry_match,
                retry_classification=(
                    "stochastic" if retry_match else "deterministic"
                ),
            )
        )

    # `divergent` is non-empty here (see the early return), so the divisions
    # below cannot hit zero.
    total = len(divergent)
    deterministic = sum(
        1 for entry in per_sample
        if entry.retry_classification == "deterministic"
    )
    stochastic = total - deterministic

    return SmartRetryResult(
        original_divergent=total,
        deterministic_count=deterministic,
        stochastic_count=stochastic,
        deterministic_rate=deterministic / total,
        stochastic_rate=stochastic / total,
        per_sample=per_sample,
    )


def format_smart_retry(result: SmartRetryResult) -> str:
    """Render a smart-retry result as plain text for the terminal.

    The output is a summary header with counts and percentage shares,
    followed by a per-sample table when ``result.per_sample`` is non-empty.

    Args:
        result: The aggregate retry result to display.

    Returns:
        The rendered lines joined with newlines.
    """
    deterministic_share = f"{result.deterministic_rate:.1%}"
    stochastic_share = f"{result.stochastic_rate:.1%}"

    rendered = ["Smart Retry Results", "=" * 40]
    rendered.append(f"Original divergent samples: {result.original_divergent}")
    rendered.append(
        f"Deterministic (real bugs): {result.deterministic_count}"
        + f" ({deterministic_share})"
    )
    rendered.append(
        f"Stochastic (sampling noise): {result.stochastic_count}"
        + f" ({stochastic_share})"
    )
    # Blank separator line; it also leaves a trailing newline when there is
    # no per-sample table.
    rendered.append("")

    if not result.per_sample:
        return "\n".join(rendered)

    rendered += [
        "Per-Sample Breakdown:",
        f"{'Sample ID':<30} {'Original':<20} {'Retry':<15}",
        "-" * 65,
    ]
    rendered.extend(
        f"{entry.sample_id:<30} {entry.original_classification:<20} "
        f"{entry.retry_classification:<15}"
        for entry in result.per_sample
    )
    return "\n".join(rendered)
Loading
Loading