diff --git a/ROADMAP.md b/ROADMAP.md index da00a4a..956e916 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -841,7 +841,7 @@ get from a Python script calling `/v1/chat/completions`. - Target vLLM ≥ 0.6.x with `--enable-disaggregated-prefill` - This is the single highest-value feature for making xPyD-acc a real diagnostic tool vs. a glorified diff script -## M88: Framework-Level Inference Hooks +## M88: Framework-Level Inference Hooks ✅ - Go beyond API-level logprobs comparison — hook into the inference engine to capture intermediate states - Provide a hook interface that can be injected into vLLM / SGLang inference loops: - Post-prefill hook: capture hidden states, attention weights, KV cache state diff --git a/docs/iterations/current.md b/docs/iterations/current.md index b7530ff..627d48d 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -52,4 +52,5 @@ shell for exploratory comparison of two endpoints. | M84 | 2026-04-06 | Endpoint Response Time Regression Detection | ✅ merged | Both approved | | M85 | 2026-04-06 | Offline Mode — File-Based Comparison | ✅ merged | Both approved | | M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ✅ merged | Both approved | -| M88 | 2026-04-06 | Framework-Level Inference Hooks | ⏳ pending review | — | +| M88 | 2026-04-06 | Framework-Level Inference Hooks | ✅ merged | Both approved | +| M89 | 2026-04-06 | PD Topology-Aware Testing | ⏳ pending review | — | diff --git a/src/xpyd_acc/cli/__init__.py b/src/xpyd_acc/cli/__init__.py index 958b9e1..15c6bb8 100644 --- a/src/xpyd_acc/cli/__init__.py +++ b/src/xpyd_acc/cli/__init__.py @@ -23,6 +23,7 @@ handle_heatmap, handle_root_cause, handle_token_diff, + handle_topology_scan, handle_trace, ) from .batch import _run_batch_compare @@ -143,6 +144,7 @@ def main(argv: list[str] | None = None) -> None: "repl": lambda: _run_repl(args), "latency-regression": lambda: _run_latency_regression(args), "compare-files": lambda: _run_file_compare(args), + "topology-scan": 
def handle_topology_scan(args: argparse.Namespace) -> None:
    """Run the ``topology-scan`` subcommand.

    With ``args.mock`` set, a fixed topology of two prefill and two decode
    nodes is scanned using a stub test function that reports
    ``args.samples`` samples and zero divergences for every pair. The
    formatted report is printed, and when ``args.json`` is set the report
    is also written to that path.

    Without ``args.mock`` a notice is printed to stderr and the command
    exits, since this handler only implements the mock path.

    Raises:
        SystemExit: with code 1 when ``args.mock`` is false.
    """
    from xpyd_acc.topology import (
        NodePairResult,
        TopologyNode,
        format_topology,
        scan_topology,
    )

    # Guard clause: only the mock path is handled below.
    if not args.mock:
        print(
            "Live topology scanning requires a running xPyD-proxy.\n"
            "Use --mock for testing.",
            file=sys.stderr,
        )
        raise SystemExit(1)

    def always_clean(prefill: TopologyNode, decode: TopologyNode) -> NodePairResult:
        # Stub pair test: every pair is reported with no divergent samples.
        return NodePairResult(
            prefill_node=prefill.node_id,
            decode_node=decode.node_id,
            samples_tested=args.samples,
            divergent_count=0,
        )

    # Mock topology for testing
    mock_prefill = [
        TopologyNode("p1", "http://prefill-1:8000", "prefill"),
        TopologyNode("p2", "http://prefill-2:8000", "prefill"),
    ]
    mock_decode = [
        TopologyNode("d1", "http://decode-1:8000", "decode"),
        TopologyNode("d2", "http://decode-2:8000", "decode"),
    ]

    report = scan_topology(mock_prefill, mock_decode, always_clean, proxy_url=args.proxy)
    print(format_topology(report))

    json_path = getattr(args, "json", None)
    if json_path:
        report.to_json(json_path)
        print(f"Topology report exported to {json_path}")
def _register_topology_scan(sub):
    """Register the ``topology-scan`` subcommand on *sub*.

    Args:
        sub: the sub-parsers action returned by
            ``ArgumentParser.add_subparsers()``.

    Options added: ``--proxy`` (required), ``--prompt`` (default
    ``"Hello"``), ``--samples`` (positive integer, default 10), ``--json``
    (default ``None``) and the ``--mock`` flag (default ``False``).
    """
    # Local import keeps this block self-contained regardless of how the
    # surrounding module imports argparse.
    import argparse

    def positive_int(text: str) -> int:
        # A pair tested with zero (or a negative number of) samples has a
        # 0% divergence rate and would be reported as "clean", so reject
        # such values at parse time instead of producing a misleading scan.
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid integer value: {text!r}"
            ) from None
        if value < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {value}"
            )
        return value

    ts = sub.add_parser(
        "topology-scan",
        help="Discover and test all prefill/decode node pairs via xPyD-proxy",
    )
    ts.add_argument("--proxy", required=True, help="xPyD-proxy URL")
    ts.add_argument("--prompt", default="Hello", help="Test prompt (default: Hello)")
    ts.add_argument(
        "--samples", type=positive_int, default=10,
        help="Number of samples per node pair (default: 10)",
    )
    ts.add_argument("--json", default=None, help="Export topology report as JSON")
    ts.add_argument(
        "--mock", action="store_true", default=False,
        help="Use mock topology for testing (no real endpoints)",
    )
"""PD topology-aware testing: node model, topology parsing, pair scanning, reporting."""

from __future__ import annotations

import json
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path

# Divergence-rate cut-offs used by ``NodePairResult.verdict``.
_LOW_RATE_LIMIT = 0.05
_MODERATE_RATE_LIMIT = 0.2

# Keys lifted into dedicated ``TopologyNode`` fields; every other key of a
# parsed entry ends up in ``TopologyNode.metadata``.
_RESERVED_KEYS = frozenset({"id", "node_id", "url", "role", "model"})

# Icons shown per verdict in the text report; any other verdict gets "❌".
_VERDICT_ICONS = {"clean": "✅", "low": "⚠️"}


@dataclass
class TopologyNode:
    """A single node in the PD topology."""

    node_id: str
    url: str
    role: str  # "prefill" or "decode"
    model: str | None = None
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the node as a plain dict (``metadata`` is not copied)."""
        return {
            "node_id": self.node_id,
            "url": self.url,
            "role": self.role,
            "model": self.model,
            "metadata": self.metadata,
        }


@dataclass
class NodePairResult:
    """Result of testing a specific prefill/decode node pair."""

    prefill_node: str
    decode_node: str
    samples_tested: int
    divergent_count: int
    avg_logprob_gap: float | None = None
    first_divergence_indices: list[int] = field(default_factory=list)

    @property
    def divergence_rate(self) -> float:
        """Fraction of tested samples that diverged (0.0 when nothing was tested)."""
        # ``<=`` rather than ``==``: a non-positive sample count cannot yield
        # a meaningful ratio and must not divide by zero or go negative.
        if self.samples_tested <= 0:
            return 0.0
        return self.divergent_count / self.samples_tested

    @property
    def verdict(self) -> str:
        """Classify the rate as ``clean``, ``low``, ``moderate`` or ``high``.

        ``clean`` is exactly 0; ``low`` is below 5%; ``moderate`` is below
        20%; everything else (including exactly 20%) is ``high``.
        """
        rate = self.divergence_rate
        if rate == 0.0:
            return "clean"
        if rate < _LOW_RATE_LIMIT:
            return "low"
        if rate < _MODERATE_RATE_LIMIT:
            return "moderate"
        return "high"

    def to_dict(self) -> dict:
        """Return a JSON-serialisable dict including the derived rate and verdict."""
        gap = self.avg_logprob_gap
        return {
            "prefill_node": self.prefill_node,
            "decode_node": self.decode_node,
            "samples_tested": self.samples_tested,
            "divergent_count": self.divergent_count,
            "divergence_rate": round(self.divergence_rate, 4),
            "avg_logprob_gap": round(gap, 6) if gap is not None else None,
            "first_divergence_indices": self.first_divergence_indices,
            "verdict": self.verdict,
        }


@dataclass
class TopologyReport:
    """Full topology scan result."""

    proxy_url: str
    prefill_nodes: list[TopologyNode]
    decode_nodes: list[TopologyNode]
    pair_results: list[NodePairResult]
    total_pairs: int

    @property
    def clean_pairs(self) -> int:
        """Number of pair results whose verdict is ``clean``."""
        return sum(1 for r in self.pair_results if r.verdict == "clean")

    @property
    def problematic_pairs(self) -> list[NodePairResult]:
        """Pair results whose verdict is anything other than ``clean``."""
        return [r for r in self.pair_results if r.verdict != "clean"]

    def to_dict(self) -> dict:
        """Return the whole report as a JSON-serialisable dict."""
        return {
            "proxy_url": self.proxy_url,
            "prefill_nodes": [n.to_dict() for n in self.prefill_nodes],
            "decode_nodes": [n.to_dict() for n in self.decode_nodes],
            "total_pairs": self.total_pairs,
            "clean_pairs": self.clean_pairs,
            "problematic_pairs_count": len(self.problematic_pairs),
            "pair_results": [r.to_dict() for r in self.pair_results],
        }

    def to_json(self, path: str | Path) -> None:
        """Write the report to *path* as indented JSON with a trailing newline."""
        # Explicit encoding so the output does not depend on the platform's
        # default locale.
        Path(path).write_text(json.dumps(self.to_dict(), indent=2) + "\n", encoding="utf-8")


def _first_present(item: dict, *keys: str) -> object:
    """Return the first value among *keys* that is neither ``None`` nor ``""``."""
    for key in keys:
        value = item.get(key)
        if value is not None and value != "":
            return value
    return ""


def parse_topology(
    data: dict, proxy_url: str = "",
) -> tuple[list[TopologyNode], list[TopologyNode]]:
    """Parse a topology payload into ``(prefill_nodes, decode_nodes)``.

    Reads the list under ``data["instances"]`` or, if that is missing or
    empty, ``data["nodes"]``. Each entry is expected to be a dict with
    ``id`` or ``node_id``, ``url`` and ``role``; ``model`` is optional and
    all remaining keys are kept as ``metadata``.

    Entries that are not dicts, and entries whose role is neither
    ``prefill`` nor ``decode`` (case-insensitive, surrounding whitespace
    ignored), are skipped. A ``null`` role or url is treated as empty
    rather than raising.

    Args:
        data: decoded topology payload.
        proxy_url: accepted for interface compatibility; not used here.
    """
    nodes_data = data.get("instances") or data.get("nodes") or []
    prefill: list[TopologyNode] = []
    decode: list[TopologyNode] = []
    by_role = {"prefill": prefill, "decode": decode}

    for item in nodes_data:
        if not isinstance(item, dict):
            continue  # malformed entry: nothing to extract

        # ``or ""`` maps an explicit null to the empty string before
        # normalising, so a missing role cannot raise AttributeError.
        role = str(item.get("role") or "").strip().lower()
        bucket = by_role.get(role)
        if bucket is None:
            continue

        bucket.append(
            TopologyNode(
                # Explicit None/"" checks keep falsy-but-valid ids such as 0.
                node_id=str(_first_present(item, "id", "node_id")),
                url=str(item.get("url") or ""),
                role=role,
                model=item.get("model"),
                metadata={k: v for k, v in item.items() if k not in _RESERVED_KEYS},
            )
        )

    return prefill, decode


def build_pair_matrix(
    prefill_nodes: list[TopologyNode],
    decode_nodes: list[TopologyNode],
) -> list[tuple[TopologyNode, TopologyNode]]:
    """Generate all prefill × decode node pairs, prefill-major."""
    return [(p, d) for p in prefill_nodes for d in decode_nodes]


def scan_topology(
    prefill_nodes: list[TopologyNode],
    decode_nodes: list[TopologyNode],
    test_fn: Callable[[TopologyNode, TopologyNode], NodePairResult],
    proxy_url: str = "",
) -> TopologyReport:
    """Scan all node pairs using the provided test function.

    Args:
        prefill_nodes: discovered prefill nodes
        decode_nodes: discovered decode nodes
        test_fn: callable(prefill_node, decode_node) -> NodePairResult
        proxy_url: original proxy URL for reporting

    Returns:
        A ``TopologyReport`` holding one result per pair, in the order
        produced by ``build_pair_matrix``. Exceptions raised by *test_fn*
        propagate to the caller.
    """
    pairs = build_pair_matrix(prefill_nodes, decode_nodes)
    results = [test_fn(p_node, d_node) for p_node, d_node in pairs]

    return TopologyReport(
        proxy_url=proxy_url,
        prefill_nodes=prefill_nodes,
        decode_nodes=decode_nodes,
        pair_results=results,
        total_pairs=len(pairs),
    )


def format_topology(report: TopologyReport) -> str:
    """Format topology scan result as terminal-friendly text.

    The output has a three-line summary, then one table row per pair, and
    finally a list of the non-clean pairs (omitted when there are none).
    With no pair results, a "No node pairs to test." line replaces the table.
    """
    problematic = report.problematic_pairs  # computed once, used twice
    lines: list[str] = [
        f"Topology Scan: {report.proxy_url}",
        f"Nodes: {len(report.prefill_nodes)} prefill, "
        f"{len(report.decode_nodes)} decode",
        f"Pairs tested: {report.total_pairs} | "
        f"Clean: {report.clean_pairs} | "
        f"Problematic: {len(problematic)}",
        "",
    ]

    if not report.pair_results:
        lines.append("No node pairs to test.")
        return "\n".join(lines)

    # Node pair matrix
    hdr = f"{'Prefill → Decode':<30} {'Samples':>8} {'Divergent':>10} {'Rate':>8} {'Verdict':>10}"
    lines.append(hdr)
    lines.append("-" * len(hdr))

    for r in report.pair_results:
        pair_label = f"{r.prefill_node} → {r.decode_node}"
        rate_str = f"{r.divergence_rate:.1%}"
        icon = _VERDICT_ICONS.get(r.verdict, "❌")
        lines.append(
            f"{pair_label:<30} {r.samples_tested:>8} {r.divergent_count:>10} "
            f"{rate_str:>8} {icon} {r.verdict:>7}"
        )

    if problematic:
        lines.append("")
        lines.append("⚠ Problematic pairs:")
        for r in problematic:
            gap_str = f", avg gap={r.avg_logprob_gap:.4f}" if r.avg_logprob_gap is not None else ""
            lines.append(
                f"  {r.prefill_node} → {r.decode_node}: "
                f"{r.divergence_rate:.1%} divergence "
                f"({r.divergent_count}/{r.samples_tested}){gap_str}"
            )

    return "\n".join(lines)
"""Tests for PD topology-aware testing."""

from __future__ import annotations

import json
import subprocess
import sys
import tempfile
from pathlib import Path

import pytest

from xpyd_acc.topology import (
    NodePairResult,
    TopologyNode,
    TopologyReport,
    build_pair_matrix,
    format_topology,
    parse_topology,
    scan_topology,
)

# --- TopologyNode ---


class TestTopologyNode:
    """Serialisation and defaults of ``TopologyNode``."""

    def test_to_dict(self):
        node = TopologyNode(node_id="p1", url="http://p1:8000", role="prefill", model="llama-7b")
        d = node.to_dict()
        assert d["node_id"] == "p1"
        assert d["role"] == "prefill"
        assert d["model"] == "llama-7b"

    def test_default_metadata(self):
        node = TopologyNode(node_id="d1", url="http://d1:8000", role="decode")
        assert node.metadata == {}
        assert node.model is None


# --- NodePairResult ---


class TestNodePairResult:
    """Divergence-rate arithmetic and verdict thresholds."""

    def test_divergence_rate_zero(self):
        r = NodePairResult(
            prefill_node="p1", decode_node="d1",
            samples_tested=10, divergent_count=0,
        )
        assert r.divergence_rate == 0.0
        assert r.verdict == "clean"

    def test_divergence_rate_low(self):
        r = NodePairResult(
            prefill_node="p1", decode_node="d1",
            samples_tested=100, divergent_count=3,
        )
        assert r.divergence_rate == pytest.approx(0.03)
        assert r.verdict == "low"

    def test_divergence_rate_moderate(self):
        r = NodePairResult(
            prefill_node="p1", decode_node="d1",
            samples_tested=100, divergent_count=10,
        )
        assert r.divergence_rate == pytest.approx(0.1)
        assert r.verdict == "moderate"

    def test_divergence_rate_high(self):
        r = NodePairResult(
            prefill_node="p1", decode_node="d1",
            samples_tested=10, divergent_count=5,
        )
        assert r.divergence_rate == pytest.approx(0.5)
        assert r.verdict == "high"

    def test_zero_samples(self):
        r = NodePairResult(prefill_node="p1", decode_node="d1", samples_tested=0, divergent_count=0)
        assert r.divergence_rate == 0.0
        # A rate of exactly 0 maps to the "clean" verdict.
        assert r.verdict == "clean"

    def test_to_dict(self):
        r = NodePairResult(
            prefill_node="p1", decode_node="d1",
            samples_tested=20, divergent_count=4,
            avg_logprob_gap=0.123456, first_divergence_indices=[3, 7],
        )
        d = r.to_dict()
        assert d["divergence_rate"] == 0.2
        assert d["avg_logprob_gap"] == 0.123456
        assert d["verdict"] == "high"
        assert d["first_divergence_indices"] == [3, 7]


# --- parse_topology ---


class TestParseTopology:
    """Parsing of proxy topology payloads."""

    def test_parse_instances(self):
        data = {
            "instances": [
                {"id": "p1", "url": "http://p1:8000", "role": "prefill", "model": "llama"},
                {"id": "d1", "url": "http://d1:8000", "role": "decode"},
                {"id": "d2", "url": "http://d2:8000", "role": "decode"},
            ]
        }
        prefill, decode = parse_topology(data)
        assert len(prefill) == 1
        assert len(decode) == 2
        assert prefill[0].node_id == "p1"
        assert prefill[0].model == "llama"

    def test_parse_nodes_key(self):
        data = {
            "nodes": [
                {"node_id": "px", "url": "http://px:8000", "role": "prefill"},
            ]
        }
        prefill, _ = parse_topology(data)
        assert len(prefill) == 1
        assert prefill[0].node_id == "px"

    def test_empty(self):
        prefill, decode = parse_topology({})
        assert prefill == []
        assert decode == []

    def test_extra_metadata(self):
        data = {
            "instances": [
                {"id": "p1", "url": "http://p1:8000", "role": "prefill", "gpu": "A100"},
            ]
        }
        prefill, _ = parse_topology(data)
        assert prefill[0].metadata == {"gpu": "A100"}


# --- build_pair_matrix ---


class TestBuildPairMatrix:
    """Cartesian pairing of prefill and decode nodes."""

    def test_cartesian_product(self):
        p = [TopologyNode("p1", "http://p1", "prefill"), TopologyNode("p2", "http://p2", "prefill")]
        d = [TopologyNode("d1", "http://d1", "decode"), TopologyNode("d2", "http://d2", "decode")]
        pairs = build_pair_matrix(p, d)
        assert len(pairs) == 4

    def test_empty(self):
        assert build_pair_matrix([], []) == []


# --- scan_topology ---


class TestScanTopology:
    """Scanning with a stub pair-test function."""

    def _make_test_fn(self, divergent_pairs: set[tuple[str, str]] | None = None):
        # Build a stub that marks only the listed (prefill, decode) ids as divergent.
        divergent_pairs = divergent_pairs or set()

        def test_fn(p_node: TopologyNode, d_node: TopologyNode) -> NodePairResult:
            is_div = (p_node.node_id, d_node.node_id) in divergent_pairs
            return NodePairResult(
                prefill_node=p_node.node_id,
                decode_node=d_node.node_id,
                samples_tested=10,
                divergent_count=5 if is_div else 0,
                avg_logprob_gap=0.05 if is_div else None,
            )

        return test_fn

    def test_all_clean(self):
        p = [TopologyNode("p1", "http://p1", "prefill")]
        d = [TopologyNode("d1", "http://d1", "decode")]
        report = scan_topology(p, d, self._make_test_fn(), proxy_url="http://proxy")
        assert report.total_pairs == 1
        assert report.clean_pairs == 1
        assert len(report.problematic_pairs) == 0

    def test_one_bad_pair(self):
        p = [TopologyNode("p1", "http://p1", "prefill")]
        d = [TopologyNode("d1", "http://d1", "decode"), TopologyNode("d2", "http://d2", "decode")]
        report = scan_topology(
            p, d,
            self._make_test_fn(divergent_pairs={("p1", "d2")}),
            proxy_url="http://proxy",
        )
        assert report.total_pairs == 2
        assert report.clean_pairs == 1
        assert len(report.problematic_pairs) == 1
        assert report.problematic_pairs[0].decode_node == "d2"


# --- TopologyReport ---


class TestTopologyReport:
    """Dict and JSON export of ``TopologyReport``."""

    def test_to_dict(self):
        report = TopologyReport(
            proxy_url="http://proxy",
            prefill_nodes=[TopologyNode("p1", "http://p1", "prefill")],
            decode_nodes=[TopologyNode("d1", "http://d1", "decode")],
            pair_results=[
                NodePairResult("p1", "d1", samples_tested=10, divergent_count=0),
            ],
            total_pairs=1,
        )
        d = report.to_dict()
        assert d["total_pairs"] == 1
        assert d["clean_pairs"] == 1
        assert d["problematic_pairs_count"] == 0

    def test_to_json(self):
        report = TopologyReport(
            proxy_url="http://proxy",
            prefill_nodes=[],
            decode_nodes=[],
            pair_results=[],
            total_pairs=0,
        )
        # The directory is removed on exit even if an assertion fails, so no
        # temp file is left behind by a failing run.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir) / "report.json"
            report.to_json(path)
            data = json.loads(path.read_text())
        assert data["proxy_url"] == "http://proxy"


# --- format_topology ---


class TestFormatTopology:
    """Plain-text rendering of a report."""

    def test_empty(self):
        report = TopologyReport("http://proxy", [], [], [], 0)
        text = format_topology(report)
        assert "No node pairs" in text

    def test_with_results(self):
        report = TopologyReport(
            proxy_url="http://proxy",
            prefill_nodes=[TopologyNode("p1", "http://p1", "prefill")],
            decode_nodes=[TopologyNode("d1", "http://d1", "decode")],
            pair_results=[
                NodePairResult(
                    "p1", "d1", samples_tested=10,
                    divergent_count=3, avg_logprob_gap=0.05,
                ),
            ],
            total_pairs=1,
        )
        text = format_topology(report)
        assert "p1" in text
        assert "d1" in text
        assert "Problematic" in text

    def test_clean_no_warning(self):
        report = TopologyReport(
            proxy_url="http://proxy",
            prefill_nodes=[TopologyNode("p1", "http://p1", "prefill")],
            decode_nodes=[TopologyNode("d1", "http://d1", "decode")],
            pair_results=[
                NodePairResult("p1", "d1", samples_tested=10, divergent_count=0),
            ],
            total_pairs=1,
        )
        text = format_topology(report)
        assert "Problematic pairs" not in text


# --- CLI integration ---


class TestTopologyCLI:
    """Smoke tests that the new CLI flags are registered."""

    def test_topology_scan_help(self):
        # check=False: the return code is asserted explicitly below.
        result = subprocess.run(
            [sys.executable, "-m", "xpyd_acc.cli", "topology-scan", "--help"],
            capture_output=True, text=True, check=False,
        )
        assert result.returncode == 0
        assert "--proxy" in result.stdout

    def test_compare_prefill_decode_help(self):
        """Verify --prefill-node and --decode-node flags exist on compare-logprobs."""
        result = subprocess.run(
            [sys.executable, "-m", "xpyd_acc.cli", "compare-logprobs", "--help"],
            capture_output=True, text=True, check=False,
        )
        assert result.returncode == 0
        assert "--prefill-node" in result.stdout
        assert "--decode-node" in result.stdout