Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ get from a Python script calling `/v1/chat/completions`.
- Target vLLM ≥ 0.6.x with `--enable-disaggregated-prefill`
- This is the single highest-value feature for making xPyD-acc a real diagnostic tool vs. a glorified diff script

## M88: Framework-Level Inference Hooks
## M88: Framework-Level Inference Hooks
- Go beyond API-level logprobs comparison — hook into the inference engine to capture intermediate states
- Provide a hook interface that can be injected into vLLM / SGLang inference loops:
- Post-prefill hook: capture hidden states, attention weights, KV cache state
Expand Down
3 changes: 2 additions & 1 deletion docs/iterations/current.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ shell for exploratory comparison of two endpoints.
| M84 | 2026-04-06 | Endpoint Response Time Regression Detection | ✅ merged | Both approved |
| M85 | 2026-04-06 | Offline Mode — File-Based Comparison | ✅ merged | Both approved |
| M87 | 2026-04-06 | Automatic KV Cache Export from vLLM | ✅ merged | Both approved |
| M88 | 2026-04-06 | Framework-Level Inference Hooks | ⏳ pending review | — |
| M88 | 2026-04-06 | Framework-Level Inference Hooks | ✅ merged | Both approved |
| M89 | 2026-04-06 | PD Topology-Aware Testing | ⏳ pending review | — |
2 changes: 2 additions & 0 deletions src/xpyd_acc/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
handle_heatmap,
handle_root_cause,
handle_token_diff,
handle_topology_scan,
handle_trace,
)
from .batch import _run_batch_compare
Expand Down Expand Up @@ -143,6 +144,7 @@ def main(argv: list[str] | None = None) -> None:
"repl": lambda: _run_repl(args),
"latency-regression": lambda: _run_latency_regression(args),
"compare-files": lambda: _run_file_compare(args),
"topology-scan": lambda: handle_topology_scan(args),
}

if args.command in _early:
Expand Down
44 changes: 44 additions & 0 deletions src/xpyd_acc/cli/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,47 @@ def handle_trace(args: argparse.Namespace) -> None:

if result.overall_diverged:
raise SystemExit(1)


def handle_topology_scan(args: argparse.Namespace) -> None:
"""Handle topology-scan subcommand."""
from xpyd_acc.topology import (
NodePairResult,
TopologyNode,
format_topology,
scan_topology,
)

if args.mock:
# Mock topology for testing
prefill_nodes = [
TopologyNode("p1", "http://prefill-1:8000", "prefill"),
TopologyNode("p2", "http://prefill-2:8000", "prefill"),
]
decode_nodes = [
TopologyNode("d1", "http://decode-1:8000", "decode"),
TopologyNode("d2", "http://decode-2:8000", "decode"),
]

def mock_test(p_node: TopologyNode, d_node: TopologyNode) -> NodePairResult:
return NodePairResult(
prefill_node=p_node.node_id,
decode_node=d_node.node_id,
samples_tested=args.samples,
divergent_count=0,
)

report = scan_topology(prefill_nodes, decode_nodes, mock_test, proxy_url=args.proxy)
else:
print(
"Live topology scanning requires a running xPyD-proxy.\n"
"Use --mock for testing.",
file=sys.stderr,
)
raise SystemExit(1)

print(format_topology(report))

if getattr(args, "json", None):
report.to_json(args.json)
print(f"Topology report exported to {args.json}")
27 changes: 27 additions & 0 deletions src/xpyd_acc/cli/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def register_all(sub: argparse._SubParsersAction) -> None:
_register_capture_kv(sub)
_register_file_compare(sub)
_register_trace(sub)
_register_topology_scan(sub)
def _register_compare(sub):
lp = sub.add_parser("compare-logprobs", help="Compare logprobs between two endpoints")
lp.add_argument("--baseline", required=True, help="Baseline endpoint URL")
Expand All @@ -77,6 +78,14 @@ def _register_compare(sub):
"--kl-threshold", type=float, default=None, dest="kl_threshold",
help="KL divergence threshold for flagging positions (default: 0.1)",
)
lp.add_argument(
"--prefill-node", default=None,
help="Direct prefill node URL (topology-aware comparison)",
)
lp.add_argument(
"--decode-node", default=None,
help="Direct decode node URL (topology-aware comparison)",
)

oc = sub.add_parser("compare-output", help="Compare text outputs from two endpoints")
oc_input = oc.add_mutually_exclusive_group(required=True)
Expand Down Expand Up @@ -732,3 +741,21 @@ def _register_trace(sub):
help="Noise scale for mock target hook (default: 0.0)",
)
tr.add_argument("--json", default=None, help="Export trace result as JSON")


def _register_topology_scan(sub):
ts = sub.add_parser(
"topology-scan",
help="Discover and test all prefill/decode node pairs via xPyD-proxy",
)
ts.add_argument("--proxy", required=True, help="xPyD-proxy URL")
ts.add_argument("--prompt", default="Hello", help="Test prompt (default: Hello)")
ts.add_argument(
"--samples", type=int, default=10,
help="Number of samples per node pair (default: 10)",
)
ts.add_argument("--json", default=None, help="Export topology report as JSON")
ts.add_argument(
"--mock", action="store_true", default=False,
help="Use mock topology for testing (no real endpoints)",
)
228 changes: 228 additions & 0 deletions src/xpyd_acc/topology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""PD Topology-Aware Testing — discover and test prefill/decode node pairs.

Auto-discovers PD topology from xPyD-proxy and tests all prefill/decode node
pairs to identify which specific combination shows divergence. Critical for
production clusters where overall divergence may be caused by a single bad node.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class TopologyNode:
"""A single node in the PD topology."""

node_id: str
url: str
role: str # "prefill" or "decode"
model: str | None = None
metadata: dict = field(default_factory=dict)

def to_dict(self) -> dict:
return {
"node_id": self.node_id,
"url": self.url,
"role": self.role,
"model": self.model,
"metadata": self.metadata,
}


@dataclass
class NodePairResult:
"""Result of testing a specific prefill/decode node pair."""

prefill_node: str
decode_node: str
samples_tested: int
divergent_count: int
avg_logprob_gap: float | None = None
first_divergence_indices: list[int] = field(default_factory=list)

@property
def divergence_rate(self) -> float:
if self.samples_tested == 0:
return 0.0
return self.divergent_count / self.samples_tested

@property
def verdict(self) -> str:
rate = self.divergence_rate
if rate == 0.0:
return "clean"
if rate < 0.05:
return "low"
if rate < 0.2:
return "moderate"
return "high"

def to_dict(self) -> dict:
return {
"prefill_node": self.prefill_node,
"decode_node": self.decode_node,
"samples_tested": self.samples_tested,
"divergent_count": self.divergent_count,
"divergence_rate": round(self.divergence_rate, 4),
"avg_logprob_gap": (
round(self.avg_logprob_gap, 6) if self.avg_logprob_gap is not None else None
),
"first_divergence_indices": self.first_divergence_indices,
"verdict": self.verdict,
}


@dataclass
class TopologyReport:
"""Full topology scan result."""

proxy_url: str
prefill_nodes: list[TopologyNode]
decode_nodes: list[TopologyNode]
pair_results: list[NodePairResult]
total_pairs: int

@property
def clean_pairs(self) -> int:
return sum(1 for r in self.pair_results if r.verdict == "clean")

@property
def problematic_pairs(self) -> list[NodePairResult]:
return [r for r in self.pair_results if r.verdict != "clean"]

def to_dict(self) -> dict:
return {
"proxy_url": self.proxy_url,
"prefill_nodes": [n.to_dict() for n in self.prefill_nodes],
"decode_nodes": [n.to_dict() for n in self.decode_nodes],
"total_pairs": self.total_pairs,
"clean_pairs": self.clean_pairs,
"problematic_pairs_count": len(self.problematic_pairs),
"pair_results": [r.to_dict() for r in self.pair_results],
}

def to_json(self, path: str | Path) -> None:
Path(path).write_text(json.dumps(self.to_dict(), indent=2) + "\n")


def parse_topology(
data: dict, proxy_url: str = "",
) -> tuple[list[TopologyNode], list[TopologyNode]]:
"""Parse topology response from proxy into node lists.

Expects a dict with 'instances' or 'nodes' key containing a list
of objects with at least 'id'/'node_id', 'url', and 'role' fields.
"""
nodes_data = data.get("instances") or data.get("nodes") or []
prefill: list[TopologyNode] = []
decode: list[TopologyNode] = []

for item in nodes_data:
node_id = item.get("id") or item.get("node_id") or ""
url = item.get("url", "")
role = item.get("role", "").lower()
model = item.get("model")
meta = {k: v for k, v in item.items() if k not in ("id", "node_id", "url", "role", "model")}

node = TopologyNode(
node_id=str(node_id),
url=url,
role=role,
model=model,
metadata=meta,
)

if role == "prefill":
prefill.append(node)
elif role == "decode":
decode.append(node)

return prefill, decode


def build_pair_matrix(
prefill_nodes: list[TopologyNode],
decode_nodes: list[TopologyNode],
) -> list[tuple[TopologyNode, TopologyNode]]:
"""Generate all prefill × decode node pairs."""
return [(p, d) for p in prefill_nodes for d in decode_nodes]


def scan_topology(
prefill_nodes: list[TopologyNode],
decode_nodes: list[TopologyNode],
test_fn: callable,
proxy_url: str = "",
) -> TopologyReport:
"""Scan all node pairs using the provided test function.

Args:
prefill_nodes: discovered prefill nodes
decode_nodes: discovered decode nodes
test_fn: callable(prefill_node, decode_node) -> NodePairResult
proxy_url: original proxy URL for reporting
"""
pairs = build_pair_matrix(prefill_nodes, decode_nodes)
results: list[NodePairResult] = []

for p_node, d_node in pairs:
result = test_fn(p_node, d_node)
results.append(result)

return TopologyReport(
proxy_url=proxy_url,
prefill_nodes=prefill_nodes,
decode_nodes=decode_nodes,
pair_results=results,
total_pairs=len(pairs),
)


def format_topology(report: TopologyReport) -> str:
"""Format topology scan result as terminal-friendly text."""
lines: list[str] = []
lines.append(f"Topology Scan: {report.proxy_url}")
lines.append(
f"Nodes: {len(report.prefill_nodes)} prefill, "
f"{len(report.decode_nodes)} decode"
)
lines.append(
f"Pairs tested: {report.total_pairs} | "
f"Clean: {report.clean_pairs} | "
f"Problematic: {len(report.problematic_pairs)}"
)
lines.append("")

if not report.pair_results:
lines.append("No node pairs to test.")
return "\n".join(lines)

# Node pair matrix
hdr = f"{'Prefill → Decode':<30} {'Samples':>8} {'Divergent':>10} {'Rate':>8} {'Verdict':>10}"
lines.append(hdr)
lines.append("-" * len(hdr))

for r in report.pair_results:
pair_label = f"{r.prefill_node} → {r.decode_node}"
rate_str = f"{r.divergence_rate:.1%}"
icon = "✅" if r.verdict == "clean" else "⚠️" if r.verdict == "low" else "❌"
lines.append(
f"{pair_label:<30} {r.samples_tested:>8} {r.divergent_count:>10} "
f"{rate_str:>8} {icon} {r.verdict:>7}"
)

if report.problematic_pairs:
lines.append("")
lines.append("⚠ Problematic pairs:")
for r in report.problematic_pairs:
gap_str = f", avg gap={r.avg_logprob_gap:.4f}" if r.avg_logprob_gap is not None else ""
lines.append(
f" {r.prefill_node} → {r.decode_node}: "
f"{r.divergence_rate:.1%} divergence "
f"({r.divergent_count}/{r.samples_tested}){gap_str}"
)

return "\n".join(lines)
Loading
Loading