From fb86ea04b20655a7a855c9015872b6cd6942100a Mon Sep 17 00:00:00 2001 From: hlin99 Date: Mon, 6 Apr 2026 11:48:06 +0800 Subject: [PATCH] feat: Multi-Backend Comparison Report (M114) - Wire BackendComparator into CLI as 'compare-backends' subcommand - Register compare-backends parser and dispatch in _main.py - Export BackendComparator and related models from __init__.py - Auto-detect benchmark format (native, vLLM, SGLang, TensorRT-LLM) - Per-backend latency P50/P95/P99, throughput, SLA compliance - Rank backends by configurable criteria - Rich table and JSON output formats - 31 tests passing Closes #251 --- ROADMAP.md | 17 +- docs/iterations/current.md | 3 +- src/xpyd_plan/__init__.py | 24 ++ src/xpyd_plan/backend_compare.py | 363 ++++++++++++++++++++ src/xpyd_plan/cli/_compare_backends.py | 186 +++++++++++ src/xpyd_plan/cli/_main.py | 6 + tests/test_backend_compare.py | 436 +++++++++++++++++++++++++ 7 files changed, 1032 insertions(+), 3 deletions(-) create mode 100644 src/xpyd_plan/backend_compare.py create mode 100644 src/xpyd_plan/cli/_compare_backends.py create mode 100644 tests/test_backend_compare.py diff --git a/ROADMAP.md b/ROADMAP.md index 898ef4f..c00d293 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1518,9 +1518,9 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - Programmatic `import_trtllm()` and `import_trtllm_data()` API - 25+ new tests -### M113 🔄 TensorRT-LLM Benchmark Command Generator +### M113 ✅ TensorRT-LLM Benchmark Command Generator -*In progress* +*Completed — PR #250* - `TRTLLMCommandGenerator` class in `trtllm_commands.py` - `TRTLLMCommandConfig`, `TRTLLMServerCommand`, `TRTLLMBenchmarkCommand`, `TRTLLMCommandSet` Pydantic models @@ -1530,3 +1530,16 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - CLI `trtllm-commands` subcommand with table + JSON output - Programmatic `generate_trtllm_commands()` API - 29 new tests + +### M114 🔄 Multi-Backend Comparison Report + +*In progress* + +- `BackendComparator` class in `backend_compare.py` +- `BackendComparisonConfig`, `BackendMetrics`, `BackendComparisonReport`, `BackendRanking`, `SLAResult` Pydantic models +- Auto-detect input format (native, vLLM, SGLang, TensorRT-LLM) +- Per-backend latency percentiles (P50/P95/P99), throughput, SLA compliance +- Rank backends by configurable criteria (ttft_p99, tpot_p99, total_latency_p99, throughput) +- CLI `compare-backends` subcommand with `--benchmark`, `--labels`, `--formats`, `--rank-by`, table + JSON output +- Programmatic `compare_backends()` API +- ~25 new tests diff --git a/docs/iterations/current.md b/docs/iterations/current.md index 8eccdd2..d2a8853 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -65,4 +65,5 @@ The project has completed **110 milestones**, covering the full feature chain fr | 4 | 2026-04-06 | M110 SGLang Benchmark Format Importer | ✅ merged | PR #244 | | 5 | 2026-04-06 | M111 SGLang Benchmark Command Generator | ✅ merged | PR #246 | | 6 | 2026-04-06 | M112 TensorRT-LLM Benchmark Format Importer | ✅ merged | PR #248, both bots approved | -| 7 | 2026-04-06 | M113 TensorRT-LLM Benchmark Command Generator | ⏳ pending review | Issue #249 | +| 7 | 2026-04-06 | M113 TensorRT-LLM Benchmark Command Generator | ✅ merged | PR #250, both bots approved | +| 8 | 2026-04-06 | M114 Multi-Backend Comparison Report | ⏳ pending review | Issue #251 | diff --git a/src/xpyd_plan/__init__.py b/src/xpyd_plan/__init__.py index c0f8a6d..fc96eb5 100644 --- a/src/xpyd_plan/__init__.py +++ b/src/xpyd_plan/__init__.py @@ -1493,3 +1493,27 @@ "import_trtllm", "import_trtllm_data", ] + +from xpyd_plan.backend_compare import ( # noqa: E402 + BackendComparator, + BackendComparisonConfig, + BackendComparisonReport, + BackendFormat, + BackendMetrics, + BackendRanking, + RankCriteria, + SLAResult, + compare_backends, +) + +__all__ += [ + "BackendComparator", + "BackendComparisonConfig", + "BackendComparisonReport", + "BackendFormat", + "BackendMetrics", + "BackendRanking", + "RankCriteria", + "SLAResult", + "compare_backends", +] diff --git a/src/xpyd_plan/backend_compare.py b/src/xpyd_plan/backend_compare.py new file mode 100644 index 0000000..f9fbcd8 --- /dev/null +++ b/src/xpyd_plan/backend_compare.py @@ -0,0 +1,363 @@ +"""Multi-backend comparison report. + +Compare benchmark results across different serving backends (vLLM, SGLang, +TensorRT-LLM, native) to identify the best backend for a given workload +based on latency, throughput, and SLA compliance. +""" + +from __future__ import annotations + +import json +from enum import Enum +from pathlib import Path +from typing import Optional + +import numpy as np +from pydantic import BaseModel, Field + +from xpyd_plan.benchmark_models import BenchmarkData +from xpyd_plan.sglang_import import _detect_sglang_format, import_sglang_data +from xpyd_plan.trtllm_import import _detect_trtllm_format, import_trtllm_data +from xpyd_plan.vllm_import import VLLMImporter + + +class BackendFormat(str, Enum): + """Supported benchmark data formats.""" + + AUTO = "auto" + NATIVE = "native" + VLLM = "vllm" + SGLANG = "sglang" + TRTLLM = "trtllm" + + +class RankCriteria(str, Enum): + """Criteria for ranking backends.""" + + TTFT_P99 = "ttft_p99" + TPOT_P99 = "tpot_p99" + TOTAL_LATENCY_P99 = "total_latency_p99" + THROUGHPUT = "throughput" + + +class BackendMetrics(BaseModel): + """Aggregated metrics for a single backend benchmark run.""" + + backend_label: str = Field(..., description="User-provided label for this backend") + format_detected: str = Field(..., description="Detected input format") + num_prefill_instances: int = Field(..., ge=1) + num_decode_instances: int = Field(..., ge=1) + total_instances: int = Field(..., ge=2) + measured_qps: float = Field(..., gt=0) + request_count: int = Field(..., ge=1) + ttft_p50_ms: float = Field(..., ge=0) + ttft_p95_ms: float = Field(..., ge=0) + ttft_p99_ms: float = Field(..., ge=0) + tpot_p50_ms: float = Field(..., ge=0) + tpot_p95_ms: float = Field(..., ge=0) + tpot_p99_ms: float = Field(..., ge=0) + total_latency_p50_ms: float = Field(..., ge=0) + total_latency_p95_ms: float = Field(..., ge=0) + total_latency_p99_ms: float = Field(..., ge=0) + throughput_rps: float = Field(..., ge=0, description="Effective throughput (requests/sec)") + avg_prompt_tokens: float = Field(..., ge=0) + avg_output_tokens: float = Field(..., ge=0) + + +class BackendRanking(BaseModel): + """Ranking entry for a backend.""" + + backend_label: str = Field(..., description="Backend label") + rank: int = Field(..., ge=1, description="Rank (1 = best)") + criteria: str = Field(..., description="Ranking criteria used") + score: float = Field(..., description="Score value (lower is better for latency)") + recommendation: str = Field(..., description="Human-readable recommendation") + + +class SLAResult(BaseModel): + """SLA compliance result for a backend.""" + + backend_label: str + ttft_p99_pass: bool = Field(..., description="TTFT P99 meets SLA") + tpot_p99_pass: bool = Field(..., description="TPOT P99 meets SLA") + total_latency_p99_pass: bool = Field(..., description="Total latency P99 meets SLA") + meets_all: bool = Field(..., description="All SLA constraints met") + + +class BackendComparisonConfig(BaseModel): + """Configuration for backend comparison.""" + + rank_by: RankCriteria = Field( + RankCriteria.TTFT_P99, description="Criteria for ranking backends" + ) + sla_ttft_p99_ms: Optional[float] = Field( + None, ge=0, description="SLA threshold for TTFT P99 (ms)" + ) + sla_tpot_p99_ms: Optional[float] = Field( + None, ge=0, description="SLA threshold for TPOT P99 (ms)" + ) + sla_total_latency_p99_ms: Optional[float] = Field( + None, ge=0, description="SLA threshold for total latency P99 (ms)" + ) + + +class BackendComparisonReport(BaseModel): + """Complete comparison report across backends.""" + + metrics: list[BackendMetrics] = Field(default_factory=list) + rankings: list[BackendRanking] = Field(default_factory=list) + sla_results: list[SLAResult] = Field(default_factory=list) + best_backend: str = Field(..., description="Best backend by ranking criteria") + rank_criteria: str = Field(..., description="Criteria used for ranking") + + +def _detect_format(data: object) -> BackendFormat: + """Auto-detect the benchmark data format.""" + if isinstance(data, dict): + # Native format has 'metadata' and 'requests' at top level + if "metadata" in data and "requests" in data: + return BackendFormat.NATIVE + if _detect_trtllm_format(data): + return BackendFormat.TRTLLM + if _detect_sglang_format(data): + return BackendFormat.SGLANG + # vLLM detection: list of dicts with 'request_latency' field + if isinstance(data, list) and len(data) > 0: + first = data[0] + if isinstance(first, dict) and "request_latency" in first: + return BackendFormat.VLLM + return BackendFormat.NATIVE + + +def _load_benchmark( + path: str | Path, + fmt: BackendFormat = BackendFormat.AUTO, + num_prefill: int = 1, + num_decode: int = 1, + total_instances: int = 2, + measured_qps: float = 1.0, +) -> tuple[BenchmarkData, str]: + """Load benchmark data from file, auto-detecting format. + + Returns (BenchmarkData, detected_format_name). + """ + raw = json.loads(Path(path).read_text()) + + if fmt == BackendFormat.AUTO: + fmt = _detect_format(raw) + + if fmt == BackendFormat.NATIVE: + bd = BenchmarkData.model_validate(raw) + return bd, "native" + elif fmt == BackendFormat.VLLM: + importer = VLLMImporter() + result = importer.import_data(raw) + return result.benchmark_data, "vllm" + elif fmt == BackendFormat.SGLANG: + result = import_sglang_data( + raw, + num_prefill_instances=num_prefill, + num_decode_instances=num_decode, + total_instances=total_instances, + measured_qps=measured_qps, + ) + return result.benchmark_data, "sglang" + elif fmt == BackendFormat.TRTLLM: + result = import_trtllm_data( + raw, + num_prefill_instances=num_prefill, + num_decode_instances=num_decode, + total_instances=total_instances, + measured_qps=measured_qps, + ) + return result.benchmark_data, "trtllm" + else: + raise ValueError(f"Unsupported format: {fmt}") + + +def _compute_metrics( + bd: BenchmarkData, label: str, format_name: str +) -> BackendMetrics: + """Compute aggregated metrics from benchmark data.""" + ttfts = np.array([r.ttft_ms for r in bd.requests]) + tpots = np.array([r.tpot_ms for r in bd.requests]) + totals = np.array([r.total_latency_ms for r in bd.requests]) + prompts = np.array([r.prompt_tokens for r in bd.requests]) + outputs = np.array([r.output_tokens for r in bd.requests]) + + return BackendMetrics( + backend_label=label, + format_detected=format_name, + num_prefill_instances=bd.metadata.num_prefill_instances, + num_decode_instances=bd.metadata.num_decode_instances, + total_instances=bd.metadata.total_instances, + measured_qps=bd.metadata.measured_qps, + request_count=len(bd.requests), + ttft_p50_ms=float(np.percentile(ttfts, 50)), + ttft_p95_ms=float(np.percentile(ttfts, 95)), + ttft_p99_ms=float(np.percentile(ttfts, 99)), + tpot_p50_ms=float(np.percentile(tpots, 50)), + tpot_p95_ms=float(np.percentile(tpots, 95)), + tpot_p99_ms=float(np.percentile(tpots, 99)), + total_latency_p50_ms=float(np.percentile(totals, 50)), + total_latency_p95_ms=float(np.percentile(totals, 95)), + total_latency_p99_ms=float(np.percentile(totals, 99)), + throughput_rps=bd.metadata.measured_qps, + avg_prompt_tokens=float(np.mean(prompts)), + avg_output_tokens=float(np.mean(outputs)), + ) + + +def _check_sla( + metrics: BackendMetrics, config: BackendComparisonConfig +) -> SLAResult: + """Check SLA compliance for a backend.""" + ttft_pass = True + tpot_pass = True + total_pass = True + + if config.sla_ttft_p99_ms is not None: + ttft_pass = metrics.ttft_p99_ms <= config.sla_ttft_p99_ms + if config.sla_tpot_p99_ms is not None: + tpot_pass = metrics.tpot_p99_ms <= config.sla_tpot_p99_ms + if config.sla_total_latency_p99_ms is not None: + total_pass = metrics.total_latency_p99_ms <= config.sla_total_latency_p99_ms + + return SLAResult( + backend_label=metrics.backend_label, + ttft_p99_pass=ttft_pass, + tpot_p99_pass=tpot_pass, + total_latency_p99_pass=total_pass, + meets_all=ttft_pass and tpot_pass and total_pass, + ) + + +def _rank_backends( + metrics_list: list[BackendMetrics], + criteria: RankCriteria, +) -> list[BackendRanking]: + """Rank backends by the chosen criteria.""" + score_map = { + RankCriteria.TTFT_P99: lambda m: m.ttft_p99_ms, + RankCriteria.TPOT_P99: lambda m: m.tpot_p99_ms, + RankCriteria.TOTAL_LATENCY_P99: lambda m: m.total_latency_p99_ms, + RankCriteria.THROUGHPUT: lambda m: -m.throughput_rps, # negative: higher is better + } + + key_fn = score_map[criteria] + scored = [(m, key_fn(m)) for m in metrics_list] + scored.sort(key=lambda x: x[1]) + + rankings = [] + for rank_idx, (m, score) in enumerate(scored): + if rank_idx == 0: + rec = f"Best: {m.backend_label} wins on {criteria.value}" + else: + rec = f"{m.backend_label}: rank {rank_idx + 1} by {criteria.value}" + rankings.append( + BackendRanking( + backend_label=m.backend_label, + rank=rank_idx + 1, + criteria=criteria.value, + score=abs(score), + recommendation=rec, + ) + ) + + return rankings + + +class BackendComparator: + """Compare benchmark results across different serving backends.""" + + def __init__(self, config: Optional[BackendComparisonConfig] = None): + self._config = config or BackendComparisonConfig() + + def compare( + self, + benchmark_paths: list[str], + backend_labels: list[str], + formats: Optional[list[BackendFormat]] = None, + ) -> BackendComparisonReport: + """Compare multiple backends and produce a comparison report. + + Args: + benchmark_paths: Paths to benchmark JSON files. + backend_labels: Labels for each backend (e.g., "vllm", "sglang"). + formats: Optional format hints per file (default: auto-detect all). + + Returns: + BackendComparisonReport with metrics, rankings, and SLA results. + """ + if len(benchmark_paths) != len(backend_labels): + raise ValueError( + f"Number of benchmarks ({len(benchmark_paths)}) must match " + f"number of labels ({len(backend_labels)})" + ) + if len(benchmark_paths) < 2: + raise ValueError("At least 2 benchmarks required for comparison") + + if formats is None: + formats = [BackendFormat.AUTO] * len(benchmark_paths) + elif len(formats) != len(benchmark_paths): + raise ValueError( + f"Number of formats ({len(formats)}) must match " + f"number of benchmarks ({len(benchmark_paths)})" + ) + + metrics_list: list[BackendMetrics] = [] + for path, label, fmt in zip(benchmark_paths, backend_labels, formats): + bd, detected = _load_benchmark(path, fmt) + m = _compute_metrics(bd, label, detected) + metrics_list.append(m) + + rankings = _rank_backends(metrics_list, self._config.rank_by) + + sla_results = [_check_sla(m, self._config) for m in metrics_list] + + best = rankings[0].backend_label if rankings else backend_labels[0] + + return BackendComparisonReport( + metrics=metrics_list, + rankings=rankings, + sla_results=sla_results, + best_backend=best, + rank_criteria=self._config.rank_by.value, + ) + + +def compare_backends( + benchmark_paths: list[str], + backend_labels: list[str], + rank_by: str = "ttft_p99", + formats: Optional[list[str]] = None, + sla_ttft_p99_ms: Optional[float] = None, + sla_tpot_p99_ms: Optional[float] = None, + sla_total_latency_p99_ms: Optional[float] = None, +) -> BackendComparisonReport: + """Programmatic API: compare backends across benchmark files. + + Args: + benchmark_paths: Paths to benchmark JSON files. + backend_labels: Labels for each backend. + rank_by: Ranking criteria (ttft_p99, tpot_p99, total_latency_p99, throughput). + formats: Optional format per file (auto, native, vllm, sglang, trtllm). + sla_ttft_p99_ms: SLA threshold for TTFT P99. + sla_tpot_p99_ms: SLA threshold for TPOT P99. + sla_total_latency_p99_ms: SLA threshold for total latency P99. + + Returns: + BackendComparisonReport with metrics, rankings, and SLA results. + """ + config = BackendComparisonConfig( + rank_by=RankCriteria(rank_by), + sla_ttft_p99_ms=sla_ttft_p99_ms, + sla_tpot_p99_ms=sla_tpot_p99_ms, + sla_total_latency_p99_ms=sla_total_latency_p99_ms, + ) + fmt_enums = None + if formats: + fmt_enums = [BackendFormat(f) for f in formats] + + comparator = BackendComparator(config) + return comparator.compare(benchmark_paths, backend_labels, fmt_enums) diff --git a/src/xpyd_plan/cli/_compare_backends.py b/src/xpyd_plan/cli/_compare_backends.py new file mode 100644 index 0000000..e461735 --- /dev/null +++ b/src/xpyd_plan/cli/_compare_backends.py @@ -0,0 +1,186 @@ +"""CLI compare-backends command.""" + +from __future__ import annotations + +import argparse + +from rich.console import Console +from rich.table import Table + + +def _cmd_compare_backends(args: argparse.Namespace) -> None: + """Execute the compare-backends subcommand.""" + + from xpyd_plan.backend_compare import compare_backends + + console = Console() + + benchmarks = args.benchmarks + labels = args.labels.split(",") + + if len(benchmarks) != len(labels): + console.print( + f"[red]Error: {len(benchmarks)} benchmark file(s) but {len(labels)} label(s). " + "Counts must match.[/red]" + ) + raise SystemExit(1) + + formats_list = None + if args.formats: + formats_list = args.formats.split(",") + if len(formats_list) != len(benchmarks): + console.print( + f"[red]Error: {len(formats_list)} format(s) but {len(benchmarks)} benchmark(s). " + "Counts must match.[/red]" + ) + raise SystemExit(1) + + result = compare_backends( + benchmark_paths=benchmarks, + backend_labels=labels, + rank_by=args.rank_by, + formats=formats_list, + sla_ttft_p99_ms=getattr(args, "sla_ttft_p99", None), + sla_tpot_p99_ms=getattr(args, "sla_tpot_p99", None), + sla_total_latency_p99_ms=getattr(args, "sla_total_latency_p99", None), + ) + + if args.output_format == "json": + console.print_json(result.model_dump_json(indent=2)) + return + + # --- Table output --- + console.print("\n[bold]🔬 Multi-Backend Comparison Report[/bold]") + + # Metrics table + metrics_table = Table(title="Backend Metrics") + metrics_table.add_column("Backend", style="cyan") + metrics_table.add_column("Format", style="dim") + metrics_table.add_column("P:D Ratio", justify="center") + metrics_table.add_column("QPS", justify="right") + metrics_table.add_column("TTFT P99", justify="right") + metrics_table.add_column("TPOT P99", justify="right") + metrics_table.add_column("Total P99", justify="right") + metrics_table.add_column("Requests", justify="right") + + for m in result.metrics: + metrics_table.add_row( + m.backend_label, + m.format_detected, + f"{m.num_prefill_instances}:{m.num_decode_instances}", + f"{m.measured_qps:.1f}", + f"{m.ttft_p99_ms:.1f}", + f"{m.tpot_p99_ms:.1f}", + f"{m.total_latency_p99_ms:.1f}", + str(m.request_count), + ) + + console.print(metrics_table) + + # SLA results (if any SLA configured) + if any(not s.meets_all for s in result.sla_results) or any( + s.meets_all for s in result.sla_results + ): + has_sla = any( + getattr(args, a, None) is not None + for a in ("sla_ttft_p99", "sla_tpot_p99", "sla_total_latency_p99") + ) + if has_sla: + sla_table = Table(title="\nSLA Compliance") + sla_table.add_column("Backend", style="cyan") + sla_table.add_column("TTFT P99", justify="center") + sla_table.add_column("TPOT P99", justify="center") + sla_table.add_column("Total P99", justify="center") + sla_table.add_column("Overall", justify="center") + + for s in result.sla_results: + sla_table.add_row( + s.backend_label, + "✅" if s.ttft_p99_pass else "❌", + "✅" if s.tpot_p99_pass else "❌", + "✅" if s.total_latency_p99_pass else "❌", + "[green]PASS[/green]" if s.meets_all else "[red]FAIL[/red]", + ) + + console.print(sla_table) + + # Rankings + rank_table = Table(title=f"\nRankings (by {result.rank_criteria})") + rank_table.add_column("Rank", justify="center", style="bold") + rank_table.add_column("Backend", style="cyan") + rank_table.add_column("Score", justify="right") + rank_table.add_column("Recommendation") + + for r in result.rankings: + rank_table.add_row( + str(r.rank), + r.backend_label, + f"{r.score:.1f}", + r.recommendation, + ) + + console.print(rank_table) + + console.print(f"\n[bold green]🏆 Best backend: {result.best_backend}[/bold green]") + + +def register_compare_backends(subparsers: argparse._SubParsersAction) -> None: + """Register the compare-backends subcommand.""" + + p = subparsers.add_parser( + "compare-backends", + help="Compare benchmark results across serving backends", + ) + p.add_argument( + "--benchmark", + dest="benchmarks", + action="append", + required=True, + help="Path to benchmark JSON file (repeat for each backend)", + ) + p.add_argument( + "--labels", + required=True, + help="Comma-separated backend labels (e.g., 'vllm,sglang,trtllm')", + ) + p.add_argument( + "--formats", + default=None, + help="Comma-separated format hints (auto,native,vllm,sglang,trtllm)", + ) + p.add_argument( + "--rank-by", + dest="rank_by", + default="ttft_p99", + choices=["ttft_p99", "tpot_p99", "total_latency_p99", "throughput"], + help="Ranking criteria (default: ttft_p99)", + ) + p.add_argument( + "--sla-ttft-p99", + dest="sla_ttft_p99", + type=float, + default=None, + help="SLA threshold for TTFT P99 (ms)", + ) + p.add_argument( + "--sla-tpot-p99", + dest="sla_tpot_p99", + type=float, + default=None, + help="SLA threshold for TPOT P99 (ms)", + ) + p.add_argument( + "--sla-total-latency-p99", + dest="sla_total_latency_p99", + type=float, + default=None, + help="SLA threshold for total latency P99 (ms)", + ) + p.add_argument( + "--format", + dest="output_format", + choices=["table", "json"], + default="table", + help="Output format (default: table)", + ) + p.set_defaults(func=_cmd_compare_backends) diff --git a/src/xpyd_plan/cli/_main.py b/src/xpyd_plan/cli/_main.py index ae72e78..0033a3e 100644 --- a/src/xpyd_plan/cli/_main.py +++ b/src/xpyd_plan/cli/_main.py @@ -19,6 +19,7 @@ from xpyd_plan.cli._cdf import _cmd_cdf, add_cdf_parser from xpyd_plan.cli._cold_start import add_cold_start_parser from xpyd_plan.cli._compare import _cmd_compare +from xpyd_plan.cli._compare_backends import register_compare_backends from xpyd_plan.cli._concurrency_util import _cmd_concurrency_util, add_concurrency_util_parser from xpyd_plan.cli._confidence import _cmd_confidence from xpyd_plan.cli._config import _add_config_flag, _apply_config_defaults, _cmd_config @@ -965,6 +966,7 @@ def main(argv: list[str] | None = None) -> None: register_vllm_commands(subparsers) register_sglang_commands(subparsers) register_trtllm_commands(subparsers) + register_compare_backends(subparsers) add_rate_limit_parser(subparsers) add_batch_analysis_parser(subparsers) add_stat_summary_parser(subparsers) @@ -1314,6 +1316,10 @@ def main(argv: list[str] | None = None) -> None: from xpyd_plan.cli._trtllm_commands import _cmd_trtllm_commands _cmd_trtllm_commands(args) + elif args.command == "compare-backends": + from xpyd_plan.cli._compare_backends import _cmd_compare_backends + + _cmd_compare_backends(args) else: parser.print_help() sys.exit(1) diff --git a/tests/test_backend_compare.py b/tests/test_backend_compare.py new file mode 100644 index 0000000..7f06938 --- /dev/null +++ b/tests/test_backend_compare.py @@ -0,0 +1,436 @@ +"""Tests for backend_compare module.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from xpyd_plan.backend_compare import ( + BackendComparator, + BackendComparisonConfig, + BackendComparisonReport, + BackendFormat, + BackendMetrics, + RankCriteria, + _compute_metrics, + _detect_format, + _load_benchmark, + _rank_backends, + compare_backends, +) +from xpyd_plan.benchmark_models import BenchmarkData + + +def _make_benchmark( + num_requests: int = 20, + ttft_base: float = 50.0, + tpot_base: float = 10.0, + total_base: float = 200.0, + num_prefill: int = 2, + num_decode: int = 2, + total_instances: int = 4, + measured_qps: float = 10.0, +) -> dict: + """Create a native benchmark data dict.""" + requests = [] + for i in range(num_requests): + requests.append( + { + "request_id": f"req-{i}", + "prompt_tokens": 100 + i, + "output_tokens": 50 + i, + "ttft_ms": ttft_base + i * 2.0, + "tpot_ms": tpot_base + i * 0.5, + "total_latency_ms": total_base + i * 5.0, + "timestamp": 1000.0 + i * 0.1, + } + ) + return { + "metadata": { + "num_prefill_instances": num_prefill, + "num_decode_instances": num_decode, + "total_instances": total_instances, + "measured_qps": measured_qps, + }, + "requests": requests, + } + + +def _write_json(data: dict, path: Path) -> str: + """Write JSON to a temp file and return path string.""" + path.write_text(json.dumps(data)) + return str(path) + + +class TestDetectFormat: + """Tests for format auto-detection.""" + + def test_detect_native(self): + data = {"metadata": {}, "requests": []} + assert _detect_format(data) == BackendFormat.NATIVE + + def test_detect_vllm(self): + data = [{"request_latency": 1.0, "prompt_len": 10, "output_len": 5}] + assert _detect_format(data) == BackendFormat.VLLM + + def test_detect_sglang(self): + data = [ + {"latency": 1.0, "itl": [0.01, 0.02], "success": True, "ttft": 0.1, "prompt_len": 10} + ] + assert _detect_format(data) == BackendFormat.SGLANG + + def test_detect_trtllm(self): + data = [ + { + "input_tokens": 10, + "output_tokens": 5, + "first_token_latency": 50.0, + "inter_token_latencies": [10.0, 12.0], + "end_to_end_latency": 200.0, + } + ] + assert _detect_format(data) == BackendFormat.TRTLLM + + def test_detect_empty_list_native(self): + data = [] + assert _detect_format(data) == BackendFormat.NATIVE + + +class TestLoadBenchmark: + """Tests for loading benchmark data.""" + + def test_load_native(self, tmp_path): + data = _make_benchmark() + path = _write_json(data, tmp_path / "bench.json") + bd, fmt = _load_benchmark(path, BackendFormat.NATIVE) + assert fmt == "native" + assert len(bd.requests) == 20 + + def test_load_auto_native(self, tmp_path): + data = _make_benchmark() + path = _write_json(data, tmp_path / "bench.json") + bd, fmt = _load_benchmark(path) + assert fmt == "native" + + def test_load_invalid_format(self, tmp_path): + data = _make_benchmark() + path = _write_json(data, tmp_path / "bench.json") + with pytest.raises(ValueError, match="Unsupported format"): + _load_benchmark(path, "invalid_format") + + +class TestComputeMetrics: + """Tests for metrics computation.""" + + def test_basic_metrics(self): + data = _make_benchmark(num_requests=10, ttft_base=100.0, tpot_base=20.0) + bd = BenchmarkData.model_validate(data) + m = _compute_metrics(bd, "test-backend", "native") + assert m.backend_label == "test-backend" + assert m.format_detected == "native" + assert m.request_count == 10 + assert m.ttft_p50_ms > 0 + assert m.tpot_p50_ms > 0 + assert m.total_latency_p50_ms > 0 + assert m.throughput_rps == 10.0 + + def test_metrics_percentiles_ordered(self): + data = _make_benchmark(num_requests=100) + bd = BenchmarkData.model_validate(data) + m = _compute_metrics(bd, "backend", "native") + assert m.ttft_p50_ms <= m.ttft_p95_ms <= m.ttft_p99_ms + assert m.tpot_p50_ms <= m.tpot_p95_ms <= m.tpot_p99_ms + assert m.total_latency_p50_ms <= m.total_latency_p95_ms <= m.total_latency_p99_ms + + +class TestSLACheck: + """Tests for SLA compliance checking.""" + + def test_sla_all_pass(self, tmp_path): + data = _make_benchmark(ttft_base=10.0, tpot_base=5.0, total_base=50.0) + path1 = _write_json(data, tmp_path / "b1.json") + path2 = _write_json(data, tmp_path / "b2.json") + result = compare_backends( + [path1, path2], + ["a", "b"], + sla_ttft_p99_ms=999.0, + sla_tpot_p99_ms=999.0, + sla_total_latency_p99_ms=9999.0, + ) + for s in result.sla_results: + assert s.meets_all is True + + def test_sla_fail(self, tmp_path): + data = _make_benchmark(ttft_base=100.0) + path1 = _write_json(data, tmp_path / "b1.json") + path2 = _write_json(data, tmp_path / "b2.json") + result = compare_backends( + [path1, path2], + ["a", "b"], + sla_ttft_p99_ms=1.0, # impossibly low + ) + for s in result.sla_results: + assert s.ttft_p99_pass is False + assert s.meets_all is False + + def test_sla_no_thresholds_all_pass(self, tmp_path): + data = _make_benchmark() + path1 = _write_json(data, tmp_path / "b1.json") + path2 = _write_json(data, tmp_path / "b2.json") + result = compare_backends([path1, path2], ["a", "b"]) + for s in result.sla_results: + assert s.meets_all is True + + +class TestRankBackends: + """Tests for backend ranking.""" + + def test_rank_by_ttft(self): + m1 = BackendMetrics( + backend_label="fast", + format_detected="native", + num_prefill_instances=1, + num_decode_instances=1, + total_instances=2, + measured_qps=10.0, + request_count=100, + ttft_p50_ms=10.0, + ttft_p95_ms=20.0, + ttft_p99_ms=30.0, + tpot_p50_ms=5.0, + tpot_p95_ms=10.0, + tpot_p99_ms=15.0, + total_latency_p50_ms=100.0, + total_latency_p95_ms=200.0, + total_latency_p99_ms=300.0, + throughput_rps=10.0, + avg_prompt_tokens=100.0, + avg_output_tokens=50.0, + ) + m2 = m1.model_copy( + update={"backend_label": "slow", "ttft_p99_ms": 100.0} + ) + rankings = _rank_backends([m1, m2], RankCriteria.TTFT_P99) + assert rankings[0].backend_label == "fast" + assert rankings[1].backend_label == "slow" + assert rankings[0].rank == 1 + assert rankings[1].rank == 2 + + def test_rank_by_throughput(self): + m1 = BackendMetrics( + backend_label="high-tp", + format_detected="native", + num_prefill_instances=1, + num_decode_instances=1, + total_instances=2, + measured_qps=100.0, + request_count=100, + ttft_p50_ms=10.0, + ttft_p95_ms=20.0, + ttft_p99_ms=30.0, + tpot_p50_ms=5.0, + tpot_p95_ms=10.0, + tpot_p99_ms=15.0, + total_latency_p50_ms=100.0, + total_latency_p95_ms=200.0, + total_latency_p99_ms=300.0, + throughput_rps=100.0, + avg_prompt_tokens=100.0, + avg_output_tokens=50.0, + ) + m2 = m1.model_copy( + update={"backend_label": "low-tp", "throughput_rps": 10.0, "measured_qps": 10.0} + ) + rankings = _rank_backends([m1, m2], RankCriteria.THROUGHPUT) + assert rankings[0].backend_label == "high-tp" + + def test_rank_by_total_latency(self): + m1 = BackendMetrics( + backend_label="a", + format_detected="native", + num_prefill_instances=1, + num_decode_instances=1, + total_instances=2, + measured_qps=10.0, + request_count=50, + ttft_p50_ms=10.0, + ttft_p95_ms=20.0, + ttft_p99_ms=30.0, + tpot_p50_ms=5.0, + tpot_p95_ms=10.0, + tpot_p99_ms=15.0, + total_latency_p50_ms=100.0, + total_latency_p95_ms=200.0, + total_latency_p99_ms=300.0, + throughput_rps=10.0, + avg_prompt_tokens=100.0, + avg_output_tokens=50.0, + ) + m2 = m1.model_copy( + update={"backend_label": "b", "total_latency_p99_ms": 150.0} + ) + rankings = _rank_backends([m2, m1], RankCriteria.TOTAL_LATENCY_P99) + assert rankings[0].backend_label == "b" + + +class TestBackendComparator: + """Tests for BackendComparator class.""" + + def test_compare_two_backends(self, tmp_path): + fast = _make_benchmark(ttft_base=20.0, tpot_base=5.0, total_base=100.0) + slow = _make_benchmark(ttft_base=80.0, tpot_base=15.0, total_base=400.0) + p1 = _write_json(fast, tmp_path / "fast.json") + p2 = _write_json(slow, tmp_path / "slow.json") + + comparator = BackendComparator() + report = comparator.compare([p1, p2], ["fast-backend", "slow-backend"]) + + assert isinstance(report, BackendComparisonReport) + assert len(report.metrics) == 2 + assert len(report.rankings) == 2 + assert report.best_backend == "fast-backend" + + def test_compare_three_backends(self, tmp_path): + b1 = _make_benchmark(ttft_base=10.0) + b2 = _make_benchmark(ttft_base=50.0) + b3 = _make_benchmark(ttft_base=30.0) + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + p3 = _write_json(b3, tmp_path / "b3.json") + + report = compare_backends([p1, p2, p3], ["a", "b", "c"]) + assert len(report.metrics) == 3 + assert report.best_backend == "a" + + def test_compare_with_explicit_formats(self, tmp_path): + b1 = _make_benchmark() + b2 = _make_benchmark() + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + + report = compare_backends( + [p1, p2], ["a", "b"], formats=["native", "native"] + ) + assert len(report.metrics) == 2 + for m in report.metrics: + assert m.format_detected == "native" + + def test_mismatched_counts_error(self, tmp_path): + b = _make_benchmark() + p = _write_json(b, tmp_path / "b.json") + with pytest.raises(ValueError, match="must match"): + compare_backends([p], ["a", "b"]) + + def test_single_benchmark_error(self, tmp_path): + b = _make_benchmark() + p = _write_json(b, tmp_path / "b.json") + with pytest.raises(ValueError, match="At least 2"): + compare_backends([p], ["a"]) + + def test_mismatched_formats_error(self, tmp_path): + b1 = _make_benchmark() + b2 = _make_benchmark() + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + with pytest.raises(ValueError, match="must match"): + compare_backends([p1, p2], ["a", "b"], formats=["native"]) + + def test_rank_by_tpot(self, tmp_path): + fast = _make_benchmark(tpot_base=2.0) + slow = _make_benchmark(tpot_base=30.0) + p1 = _write_json(fast, tmp_path / "fast.json") + p2 = _write_json(slow, tmp_path / "slow.json") + + report = compare_backends([p1, p2], ["fast", "slow"], rank_by="tpot_p99") + assert report.best_backend == "fast" + assert report.rank_criteria == "tpot_p99" + + def test_rank_by_throughput(self, tmp_path): + high = _make_benchmark(measured_qps=100.0) + low = _make_benchmark(measured_qps=5.0) + p1 = _write_json(high, tmp_path / "high.json") + p2 = _write_json(low, tmp_path / "low.json") + + report = compare_backends([p1, p2], ["high", "low"], rank_by="throughput") + assert report.best_backend == "high" + + +class TestCompareBackendsAPI: + """Tests for the compare_backends() programmatic API.""" + + def test_basic_api(self, tmp_path): + b1 = _make_benchmark() + b2 = _make_benchmark(ttft_base=200.0) + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + + result = compare_backends([p1, p2], ["backend-a", "backend-b"]) + assert isinstance(result, BackendComparisonReport) + assert result.best_backend == "backend-a" + + def test_api_with_sla(self, tmp_path): + b1 = _make_benchmark(ttft_base=10.0) + b2 = _make_benchmark(ttft_base=500.0) + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + + result = compare_backends( + [p1, p2], + ["a", "b"], + sla_ttft_p99_ms=100.0, + ) + # First should pass, second should fail + assert result.sla_results[0].ttft_p99_pass is True + assert result.sla_results[1].ttft_p99_pass is False + + def test_report_serializable(self, tmp_path): + b1 = _make_benchmark() + b2 = _make_benchmark() + p1 = _write_json(b1, tmp_path / "b1.json") + p2 = _write_json(b2, tmp_path / "b2.json") + + result = compare_backends([p1, p2], ["a", "b"]) + j = result.model_dump_json() + parsed = json.loads(j) + assert "metrics" in parsed + assert "rankings" in parsed + assert "best_backend" in parsed + + +class TestBackendComparisonConfig: + """Tests for BackendComparisonConfig model.""" + + def test_defaults(self): + config = BackendComparisonConfig() + assert config.rank_by == RankCriteria.TTFT_P99 + assert config.sla_ttft_p99_ms is None + + def test_custom_config(self): + config = BackendComparisonConfig( + rank_by=RankCriteria.THROUGHPUT, + sla_ttft_p99_ms=100.0, + sla_tpot_p99_ms=50.0, + ) + assert config.rank_by == RankCriteria.THROUGHPUT + assert config.sla_ttft_p99_ms == 100.0 + + +class TestBackendFormat: + """Tests for BackendFormat enum.""" + + def test_values(self): + assert BackendFormat.AUTO == "auto" + assert BackendFormat.NATIVE == "native" + assert BackendFormat.VLLM == "vllm" + assert BackendFormat.SGLANG == "sglang" + assert BackendFormat.TRTLLM == "trtllm" + + +class TestRankCriteria: + """Tests for RankCriteria enum.""" + + def test_values(self): + assert RankCriteria.TTFT_P99 == "ttft_p99" + assert RankCriteria.THROUGHPUT == "throughput"