diff --git a/ROADMAP.md b/ROADMAP.md index cf2e90b..d220f59 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1597,7 +1597,7 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - Programmatic `assess_sla_risk()` API - ~22 new tests -### M119 🔄 Deployment Readiness Report +### M119 ✅ Deployment Readiness Report *In progress* @@ -1610,3 +1610,16 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - CLI `readiness` subcommand with `--benchmark`, `--sla-ttft`, `--sla-tpot`, `--sla-total`, `--cost-model`, table + JSON output - Programmatic `assess_readiness()` API - ~24 new tests + +### M120 🔄 Benchmark Dataset Catalog + +- `DatasetCatalog` class in `catalog.py` +- `CatalogEntry`, `CatalogQuery`, `CatalogReport` Pydantic models +- SQLite-backed local catalog indexing benchmark files with extracted metadata +- Metadata extracted on `add`: GPU type, model name, P:D ratio, QPS, request count, total instances, file path, file hash (SHA-256), date added +- Duplicate detection via file hash +- Query API: filter by GPU type, QPS range, P:D ratio, date range, model name +- CLI `catalog` subcommand: `catalog add`, `catalog list`, `catalog search`, `catalog remove`, `catalog show` +- Auto-add from directory scan (integrate with `discover`) +- Programmatic `manage_catalog()` API +- ~22 new tests diff --git a/docs/iterations/current.md b/docs/iterations/current.md index fbc8b4d..51d8fb4 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -71,4 +71,5 @@ The project has completed **110 milestones**, covering the full feature chain fr | 10 | 2026-04-06 | M116 GPU Hour Calculator | ✅ merged | PR #256, both bots approved | | 11 | 2026-04-06 | M117 Benchmark Quality Gate | ✅ merged | PR #258 | | 12 | 2026-04-06 | M118 SLA Risk Score | ✅ merged | PR #260, both bots approved | -| 13 | 2026-04-06 | M119 Deployment Readiness Report | ⏳ pending review | PR TBD | +| 13 | 2026-04-06 | M119 Deployment Readiness Report | ✅ merged | PR #262, both bots approved | +| 14 | 2026-04-06 | M120 Benchmark Dataset Catalog | ⏳ pending review | PR TBD | diff --git a/src/xpyd_plan/__init__.py b/src/xpyd_plan/__init__.py index 4950045..ec50e7e 100644 --- a/src/xpyd_plan/__init__.py +++ b/src/xpyd_plan/__init__.py @@ -1591,3 +1591,19 @@ "SLARiskScorer", "assess_sla_risk", ] + +from xpyd_plan.catalog import ( # noqa: E402 + CatalogEntry, + CatalogQuery, + CatalogReport, + DatasetCatalog, + manage_catalog, +) + +__all__ += [ + "CatalogEntry", + "CatalogQuery", + "CatalogReport", + "DatasetCatalog", + "manage_catalog", +] diff --git a/src/xpyd_plan/catalog.py b/src/xpyd_plan/catalog.py new file mode 100644 index 0000000..0fcaa9a --- /dev/null +++ b/src/xpyd_plan/catalog.py @@ -0,0 +1,315 @@ +"""Benchmark Dataset Catalog — SQLite-backed local index for benchmark files.""" + +from __future__ import annotations + +import hashlib +import json +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field + + +class CatalogEntry(BaseModel): + """A single indexed benchmark file.""" + + id: int = 0 + file_path: str + file_hash: str + gpu_type: str = "" + model_name: str = "" + prefill_instances: int = 0 + decode_instances: int = 0 + total_instances: int = 0 + pd_ratio: str = "" + measured_qps: float = 0.0 + request_count: int = 0 + date_added: str = "" + notes: str = "" + + +class CatalogQuery(BaseModel): + """Query filters for catalog search.""" + + gpu_type: Optional[str] = None + model_name: Optional[str] = None + min_qps: Optional[float] = None + max_qps: Optional[float] = None + pd_ratio: Optional[str] = None + min_instances: Optional[int] = None + max_instances: Optional[int] = None + added_after: Optional[str] = None + added_before: Optional[str] = None + + +class CatalogReport(BaseModel): + """Result of a catalog operation.""" + + entries: list[CatalogEntry] = Field(default_factory=list) + total_count: int = 0 + message: str = "" + + +class DatasetCatalog: + """SQLite-backed catalog for indexing benchmark files.""" + + def __init__(self, db_path: str = "~/.xpyd-plan/catalog.db") -> None: + self._db_path = Path(db_path).expanduser() + self._db_path.parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect(str(self._db_path)) + self._conn.row_factory = sqlite3.Row + self._create_tables() + + def _create_tables(self) -> None: + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS catalog ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT NOT NULL, + file_hash TEXT NOT NULL UNIQUE, + gpu_type TEXT DEFAULT '', + model_name TEXT DEFAULT '', + prefill_instances INTEGER DEFAULT 0, + decode_instances INTEGER DEFAULT 0, + total_instances INTEGER DEFAULT 0, + pd_ratio TEXT DEFAULT '', + measured_qps REAL DEFAULT 0.0, + request_count INTEGER DEFAULT 0, + date_added TEXT NOT NULL, + notes TEXT DEFAULT '' + ) + """) + # Indexes for common queries + self._conn.execute( + "CREATE INDEX IF NOT EXISTS idx_gpu_type ON catalog(gpu_type)" + ) + self._conn.execute( + "CREATE INDEX IF NOT EXISTS idx_pd_ratio ON catalog(pd_ratio)" + ) + self._conn.execute( + "CREATE INDEX IF NOT EXISTS idx_measured_qps ON catalog(measured_qps)" + ) + self._conn.execute( + "CREATE INDEX IF NOT EXISTS idx_date_added ON catalog(date_added)" + ) + self._conn.commit() + + @staticmethod + def _file_hash(path: Path) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + return h.hexdigest() + + @staticmethod + def _extract_metadata(path: Path) -> dict: + """Extract metadata from a benchmark JSON file.""" + with open(path) as f: + data = json.load(f) + + config = data.get("config", data.get("cluster_config", {})) + metadata = data.get("metadata", {}) + requests = data.get("requests", []) + + prefill = config.get("num_prefill_instances", 0) + decode = config.get("num_decode_instances", 0) + total = config.get("total_instances", prefill + decode) + + if prefill > 0 and decode > 0: + pd_ratio = f"{prefill}:{decode}" + else: + pd_ratio = "" + + return { + "gpu_type": metadata.get("gpu_type", config.get("gpu_type", "")), + "model_name": metadata.get("model_name", config.get("model_name", "")), + "prefill_instances": prefill, + "decode_instances": decode, + "total_instances": total, + "pd_ratio": pd_ratio, + "measured_qps": float(data.get("measured_qps", 0.0)), + "request_count": len(requests), + } + + def add(self, file_path: str, notes: str = "") -> CatalogEntry: + """Add a benchmark file to the catalog. Raises ValueError on duplicate.""" + path = Path(file_path).resolve() + if not path.exists(): + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + fhash = self._file_hash(path) + + # Check duplicate + row = self._conn.execute( + "SELECT id FROM catalog WHERE file_hash = ?", (fhash,) + ).fetchone() + if row: + msg = f"Duplicate file (hash={fhash[:12]}...) already in catalog as id={row['id']}" + raise ValueError(msg) + + meta = self._extract_metadata(path) + now = datetime.now(timezone.utc).isoformat() + + cursor = self._conn.execute( + """INSERT INTO catalog + (file_path, file_hash, gpu_type, model_name, + prefill_instances, decode_instances, total_instances, + pd_ratio, measured_qps, request_count, date_added, notes) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + str(path), + fhash, + meta["gpu_type"], + meta["model_name"], + meta["prefill_instances"], + meta["decode_instances"], + meta["total_instances"], + meta["pd_ratio"], + meta["measured_qps"], + meta["request_count"], + now, + notes, + ), + ) + self._conn.commit() + + return CatalogEntry( + id=cursor.lastrowid or 0, + file_path=str(path), + file_hash=fhash, + date_added=now, + notes=notes, + **meta, + ) + + def remove(self, entry_id: int) -> bool: + """Remove an entry by ID. Returns True if removed.""" + cursor = self._conn.execute("DELETE FROM catalog WHERE id = ?", (entry_id,)) + self._conn.commit() + return cursor.rowcount > 0 + + def get(self, entry_id: int) -> Optional[CatalogEntry]: + """Get a single entry by ID.""" + row = self._conn.execute( + "SELECT * FROM catalog WHERE id = ?", (entry_id,) + ).fetchone() + if not row: + return None + return self._row_to_entry(row) + + def list_all(self) -> CatalogReport: + """List all entries.""" + rows = self._conn.execute( + "SELECT * FROM catalog ORDER BY date_added DESC" + ).fetchall() + entries = [self._row_to_entry(r) for r in rows] + return CatalogReport( + entries=entries, total_count=len(entries), message="All catalog entries" + ) + + def search(self, query: CatalogQuery) -> CatalogReport: + """Search with filters.""" + conditions: list[str] = [] + params: list = [] + + if query.gpu_type: + conditions.append("gpu_type = ?") + params.append(query.gpu_type) + if query.model_name: + conditions.append("model_name = ?") + params.append(query.model_name) + if query.min_qps is not None: + conditions.append("measured_qps >= ?") + params.append(query.min_qps) + if query.max_qps is not None: + conditions.append("measured_qps <= ?") + params.append(query.max_qps) + if query.pd_ratio: + conditions.append("pd_ratio = ?") + params.append(query.pd_ratio) + if query.min_instances is not None: + conditions.append("total_instances >= ?") + params.append(query.min_instances) + if query.max_instances is not None: + conditions.append("total_instances <= ?") + params.append(query.max_instances) + if query.added_after: + conditions.append("date_added >= ?") + params.append(query.added_after) + if query.added_before: + conditions.append("date_added <= ?") + params.append(query.added_before) + + where = " AND ".join(conditions) if conditions else "1=1" + sql = f"SELECT * FROM catalog WHERE {where} ORDER BY date_added DESC" # noqa: S608 + rows = self._conn.execute(sql, params).fetchall() + entries = [self._row_to_entry(r) for r in rows] + return CatalogReport( + entries=entries, + total_count=len(entries), + message=f"Found {len(entries)} matching entries", + ) + + def close(self) -> None: + self._conn.close() + + @staticmethod + def _row_to_entry(row: sqlite3.Row) -> CatalogEntry: + return CatalogEntry( + id=row["id"], + file_path=row["file_path"], + file_hash=row["file_hash"], + gpu_type=row["gpu_type"], + model_name=row["model_name"], + prefill_instances=row["prefill_instances"], + decode_instances=row["decode_instances"], + total_instances=row["total_instances"], + pd_ratio=row["pd_ratio"], + measured_qps=row["measured_qps"], + request_count=row["request_count"], + date_added=row["date_added"], + notes=row["notes"], + ) + + +def manage_catalog( + action: str, + db_path: str = "~/.xpyd-plan/catalog.db", + file_path: str = "", + entry_id: int = 0, + query: Optional[CatalogQuery] = None, + notes: str = "", +) -> CatalogReport: + """Programmatic API for catalog management.""" + catalog = DatasetCatalog(db_path=db_path) + try: + if action == "add": + entry = catalog.add(file_path, notes=notes) + return CatalogReport( + entries=[entry], total_count=1, message="Added to catalog" + ) + elif action == "list": + return catalog.list_all() + elif action == "search": + return catalog.search(query or CatalogQuery()) + elif action == "remove": + removed = catalog.remove(entry_id) + msg = f"Removed entry {entry_id}" if removed else f"Entry {entry_id} not found" + return CatalogReport(entries=[], total_count=0, message=msg) + elif action == "show": + entry = catalog.get(entry_id) + if entry: + return CatalogReport( + entries=[entry], total_count=1, message=f"Entry {entry_id}" + ) + return CatalogReport( + entries=[], total_count=0, message=f"Entry {entry_id} not found" + ) + else: + msg = f"Unknown action: {action}" + raise ValueError(msg) + finally: + catalog.close() diff --git a/src/xpyd_plan/cli/_catalog.py b/src/xpyd_plan/cli/_catalog.py new file mode 100644 index 0000000..5f7698e --- /dev/null +++ b/src/xpyd_plan/cli/_catalog.py @@ -0,0 +1,172 @@ +"""CLI catalog command.""" + +from __future__ import annotations + +import argparse +import json +import sys + +from rich.console import Console +from rich.table import Table + +from xpyd_plan.catalog import CatalogQuery, DatasetCatalog + + +def _cmd_catalog(args: argparse.Namespace) -> None: + """Handle the 'catalog' subcommand.""" + console = Console() + db_path = getattr(args, "db_path", "~/.xpyd-plan/catalog.db") + output_format = getattr(args, "output_format", "table") + catalog = DatasetCatalog(db_path=db_path) + + try: + sub = args.catalog_action + + if sub == "add": + try: + entry = catalog.add(args.file, notes=getattr(args, "notes", "")) + except (FileNotFoundError, ValueError) as e: + console.print(f"[red]Error:[/red] {e}") + sys.exit(1) + if output_format == "json": + json.dump(entry.model_dump(), sys.stdout, indent=2) + sys.stdout.write("\n") + else: + console.print( + f"[green]Added[/green] {entry.file_path} " + f"(id={entry.id}, {entry.request_count} requests, " + f"P:D={entry.pd_ratio or 'N/A'}, QPS={entry.measured_qps:.1f})" + ) + + elif sub == "list": + report = catalog.list_all() + if output_format == "json": + json.dump(report.model_dump(), sys.stdout, indent=2) + sys.stdout.write("\n") + else: + title = f"Catalog ({report.total_count} entries)" + _print_entries_table(console, report.entries, title) + + elif sub == "search": + query = CatalogQuery( + gpu_type=getattr(args, "gpu_type", None), + model_name=getattr(args, "model_name", None), + min_qps=getattr(args, "min_qps", None), + max_qps=getattr(args, "max_qps", None), + pd_ratio=getattr(args, "pd_ratio", None), + min_instances=getattr(args, "min_instances", None), + max_instances=getattr(args, "max_instances", None), + ) + report = catalog.search(query) + if output_format == "json": + json.dump(report.model_dump(), sys.stdout, indent=2) + sys.stdout.write("\n") + else: + title = f"Search Results ({report.total_count})" + _print_entries_table(console, report.entries, title) + + elif sub == "show": + entry = catalog.get(args.id) + if not entry: + console.print(f"[red]Entry {args.id} not found[/red]") + sys.exit(1) + if output_format == "json": + json.dump(entry.model_dump(), sys.stdout, indent=2) + sys.stdout.write("\n") + else: + console.print(f"\n[bold]Entry {entry.id}[/bold]") + console.print(f" File: {entry.file_path}") + console.print(f" Hash: {entry.file_hash[:16]}...") + console.print(f" GPU: {entry.gpu_type or 'N/A'}") + console.print(f" Model: {entry.model_name or 'N/A'}") + console.print(f" P:D Ratio: {entry.pd_ratio or 'N/A'}") + p = entry.prefill_instances + d = entry.decode_instances + console.print( + f" Instances: {entry.total_instances}" + f" (P={p}, D={d})" + ) + console.print(f" QPS: {entry.measured_qps:.1f}") + console.print(f" Requests: {entry.request_count}") + console.print(f" Added: {entry.date_added}") + if entry.notes: + console.print(f" Notes: {entry.notes}") + + elif sub == "remove": + removed = catalog.remove(args.id) + if removed: + console.print(f"[green]Removed[/green] entry {args.id}") + else: + console.print(f"[red]Entry {args.id} not found[/red]") + sys.exit(1) + + else: + console.print(f"[red]Unknown catalog action: {sub}[/red]") + sys.exit(1) + + finally: + catalog.close() + + +def _print_entries_table(console: Console, entries: list, title: str) -> None: + """Print entries as a Rich table.""" + table = Table(title=title) + table.add_column("ID", justify="right") + table.add_column("P:D", justify="center") + table.add_column("QPS", justify="right") + table.add_column("Requests", justify="right") + table.add_column("Instances", justify="right") + table.add_column("GPU", justify="left") + table.add_column("File", justify="left", max_width=40) + + for e in entries: + table.add_row( + str(e.id), + e.pd_ratio or "N/A", + f"{e.measured_qps:.1f}", + str(e.request_count), + str(e.total_instances), + e.gpu_type or "N/A", + e.file_path.split("/")[-1] if "/" in e.file_path else e.file_path, + ) + + console.print(table) + + +def register(subparsers: argparse._SubParsersAction) -> None: + """Register catalog subcommand.""" + parser = subparsers.add_parser("catalog", help="Manage benchmark dataset catalog") + parser.add_argument( + "--db-path", + default="~/.xpyd-plan/catalog.db", + help="Catalog database path", + ) + parser.add_argument("--output-format", choices=["table", "json"], default="table") + + sub = parser.add_subparsers(dest="catalog_action") + + # add + add_p = sub.add_parser("add", help="Add a benchmark file to catalog") + add_p.add_argument("file", help="Path to benchmark JSON file") + add_p.add_argument("--notes", default="", help="Optional notes") + + # list + sub.add_parser("list", help="List all catalog entries") + + # search + search_p = sub.add_parser("search", help="Search catalog") + search_p.add_argument("--gpu-type", help="Filter by GPU type") + search_p.add_argument("--model-name", help="Filter by model name") + search_p.add_argument("--min-qps", type=float, help="Minimum QPS") + search_p.add_argument("--max-qps", type=float, help="Maximum QPS") + search_p.add_argument("--pd-ratio", help="Filter by P:D ratio (e.g. 2:6)") + search_p.add_argument("--min-instances", type=int, help="Minimum total instances") + search_p.add_argument("--max-instances", type=int, help="Maximum total instances") + + # show + show_p = sub.add_parser("show", help="Show details for a catalog entry") + show_p.add_argument("id", type=int, help="Entry ID") + + # remove + rm_p = sub.add_parser("remove", help="Remove a catalog entry") + rm_p.add_argument("id", type=int, help="Entry ID") diff --git a/src/xpyd_plan/cli/_main.py b/src/xpyd_plan/cli/_main.py index 4f72a75..26ed8a7 100644 --- a/src/xpyd_plan/cli/_main.py +++ b/src/xpyd_plan/cli/_main.py @@ -16,6 +16,7 @@ from xpyd_plan.cli._budget import _cmd_budget from xpyd_plan.cli._budget_tracker import _cmd_budget_tracker from xpyd_plan.cli._capacity import _cmd_plan_capacity +from xpyd_plan.cli._catalog import register as register_catalog from xpyd_plan.cli._cdf import _cmd_cdf, add_cdf_parser from xpyd_plan.cli._cold_start import add_cold_start_parser from xpyd_plan.cli._compare import _cmd_compare @@ -976,6 +977,7 @@ def main(argv: list[str] | None = None) -> None: register_quality_gate(subparsers) register_sla_risk(subparsers) register_readiness(subparsers) + register_catalog(subparsers) register_workload_mix(subparsers) add_rate_limit_parser(subparsers) add_batch_analysis_parser(subparsers) @@ -1350,6 +1352,10 @@ def main(argv: list[str] | None = None) -> None: from xpyd_plan.cli._readiness import _cmd_readiness _cmd_readiness(args) + elif args.command == "catalog": + from xpyd_plan.cli._catalog import _cmd_catalog + + _cmd_catalog(args) else: parser.print_help() sys.exit(1) diff --git a/tests/test_catalog.py b/tests/test_catalog.py new file mode 100644 index 0000000..9c1b818 --- /dev/null +++ b/tests/test_catalog.py @@ -0,0 +1,246 @@ +"""Tests for catalog module.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from xpyd_plan.catalog import CatalogQuery, DatasetCatalog, manage_catalog + + +def _make_benchmark(tmp_path: Path, name: str = "bench.json", **overrides) -> Path: + """Create a minimal benchmark JSON file.""" + data = { + "config": { + "num_prefill_instances": overrides.get("prefill", 2), + "num_decode_instances": overrides.get("decode", 6), + "total_instances": overrides.get("total", 8), + }, + "metadata": { + "gpu_type": overrides.get("gpu_type", "H100-80G"), + "model_name": overrides.get("model_name", "llama-70b"), + }, + "measured_qps": overrides.get("qps", 10.5), + "requests": [ + { + "request_id": f"r{i}", + "prompt_tokens": 100, + "output_tokens": 50, + "ttft_ms": 20.0, + "tpot_ms": 10.0, + "total_latency_ms": 520.0, + "timestamp": 1000.0 + i, + } + for i in range(overrides.get("n_requests", 5)) + ], + } + path = tmp_path / name + path.write_text(json.dumps(data)) + return path + + +@pytest.fixture() +def catalog(tmp_path): + db_path = str(tmp_path / "test_catalog.db") + cat = DatasetCatalog(db_path=db_path) + yield cat + cat.close() + + +class TestDatasetCatalog: + def test_add_and_get(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path) + entry = catalog.add(str(bench)) + assert entry.id > 0 + assert entry.gpu_type == "H100-80G" + assert entry.model_name == "llama-70b" + assert entry.prefill_instances == 2 + assert entry.decode_instances == 6 + assert entry.pd_ratio == "2:6" + assert entry.measured_qps == 10.5 + assert entry.request_count == 5 + + got = catalog.get(entry.id) + assert got is not None + assert got.file_hash == entry.file_hash + + def test_duplicate_detection(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path) + catalog.add(str(bench)) + with pytest.raises(ValueError, match="Duplicate"): + catalog.add(str(bench)) + + def test_remove(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path) + entry = catalog.add(str(bench)) + assert catalog.remove(entry.id) is True + assert catalog.get(entry.id) is None + + def test_remove_nonexistent(self, catalog): + assert catalog.remove(999) is False + + def test_list_all(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", qps=5.0) + b2 = _make_benchmark(tmp_path, "b2.json", qps=15.0) + catalog.add(str(b1)) + catalog.add(str(b2)) + report = catalog.list_all() + assert report.total_count == 2 + assert len(report.entries) == 2 + + def test_search_by_gpu(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", gpu_type="H100-80G") + b2 = _make_benchmark(tmp_path, "b2.json", gpu_type="A100-80G") + catalog.add(str(b1)) + catalog.add(str(b2)) + report = catalog.search(CatalogQuery(gpu_type="H100-80G")) + assert report.total_count == 1 + assert report.entries[0].gpu_type == "H100-80G" + + def test_search_by_qps_range(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", qps=5.0) + b2 = _make_benchmark(tmp_path, "b2.json", qps=15.0) + b3 = _make_benchmark(tmp_path, "b3.json", qps=25.0) + catalog.add(str(b1)) + catalog.add(str(b2)) + catalog.add(str(b3)) + report = catalog.search(CatalogQuery(min_qps=10.0, max_qps=20.0)) + assert report.total_count == 1 + assert report.entries[0].measured_qps == 15.0 + + def test_search_by_pd_ratio(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", prefill=2, decode=6) + b2 = _make_benchmark(tmp_path, "b2.json", prefill=4, decode=4) + catalog.add(str(b1)) + catalog.add(str(b2)) + report = catalog.search(CatalogQuery(pd_ratio="4:4")) + assert report.total_count == 1 + + def test_search_by_model(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", model_name="llama-70b") + b2 = _make_benchmark(tmp_path, "b2.json", model_name="mistral-7b") + catalog.add(str(b1)) + catalog.add(str(b2)) + report = catalog.search(CatalogQuery(model_name="mistral-7b")) + assert report.total_count == 1 + + def test_search_by_instances(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json", total=4, prefill=1, decode=3) + b2 = _make_benchmark(tmp_path, "b2.json", total=8, prefill=2, decode=6) + b3 = _make_benchmark(tmp_path, "b3.json", total=16, prefill=4, decode=12) + catalog.add(str(b1)) + catalog.add(str(b2)) + catalog.add(str(b3)) + report = catalog.search(CatalogQuery(min_instances=6, max_instances=10)) + assert report.total_count == 1 + assert report.entries[0].total_instances == 8 + + def test_search_empty_result(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path) + catalog.add(str(bench)) + report = catalog.search(CatalogQuery(gpu_type="NONEXISTENT")) + assert report.total_count == 0 + + def test_search_no_filters(self, catalog, tmp_path): + b1 = _make_benchmark(tmp_path, "b1.json") + b2 = _make_benchmark(tmp_path, "b2.json", qps=20.0) + catalog.add(str(b1)) + catalog.add(str(b2)) + report = catalog.search(CatalogQuery()) + assert report.total_count == 2 + + def test_file_not_found(self, catalog): + with pytest.raises(FileNotFoundError): + catalog.add("/nonexistent/file.json") + + def test_get_nonexistent(self, catalog): + assert catalog.get(999) is None + + def test_add_with_notes(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path) + entry = catalog.add(str(bench), notes="baseline run") + assert entry.notes == "baseline run" + got = catalog.get(entry.id) + assert got.notes == "baseline run" + + def test_no_pd_ratio_when_zero_instances(self, catalog, tmp_path): + bench = _make_benchmark(tmp_path, prefill=0, decode=0) + entry = catalog.add(str(bench)) + assert entry.pd_ratio == "" + + def test_metadata_from_cluster_config(self, tmp_path): + """Test extraction from cluster_config key (alternative format).""" + data = { + "cluster_config": { + "num_prefill_instances": 3, + "num_decode_instances": 5, + "total_instances": 8, + }, + "measured_qps": 8.0, + "requests": [ + { + "request_id": "r0", + "prompt_tokens": 100, + "output_tokens": 50, + "ttft_ms": 20, + "tpot_ms": 10, + "total_latency_ms": 500, + "timestamp": 1000, + } + ], + } + path = tmp_path / "alt.json" + path.write_text(json.dumps(data)) + + db_path = str(tmp_path / "cat.db") + cat = DatasetCatalog(db_path=db_path) + entry = cat.add(str(path)) + assert entry.pd_ratio == "3:5" + assert entry.request_count == 1 + cat.close() + + +class TestManageCatalogAPI: + def test_add_action(self, tmp_path): + bench = _make_benchmark(tmp_path) + db = str(tmp_path / "api.db") + report = manage_catalog("add", db_path=db, file_path=str(bench)) + assert report.total_count == 1 + assert report.entries[0].gpu_type == "H100-80G" + + def test_list_action(self, tmp_path): + bench = _make_benchmark(tmp_path) + db = str(tmp_path / "api.db") + manage_catalog("add", db_path=db, file_path=str(bench)) + report = manage_catalog("list", db_path=db) + assert report.total_count == 1 + + def test_search_action(self, tmp_path): + bench = _make_benchmark(tmp_path) + db = str(tmp_path / "api.db") + manage_catalog("add", db_path=db, file_path=str(bench)) + report = manage_catalog("search", db_path=db, query=CatalogQuery(gpu_type="H100-80G")) + assert report.total_count == 1 + + def test_show_action(self, tmp_path): + bench = _make_benchmark(tmp_path) + db = str(tmp_path / "api.db") + add_report = manage_catalog("add", db_path=db, file_path=str(bench)) + eid = add_report.entries[0].id + report = manage_catalog("show", db_path=db, entry_id=eid) + assert report.total_count == 1 + + def test_remove_action(self, tmp_path): + bench = _make_benchmark(tmp_path) + db = str(tmp_path / "api.db") + add_report = manage_catalog("add", db_path=db, file_path=str(bench)) + eid = add_report.entries[0].id + report = manage_catalog("remove", db_path=db, entry_id=eid) + assert "Removed" in report.message + + def test_unknown_action(self, tmp_path): + db = str(tmp_path / "api.db") + with pytest.raises(ValueError, match="Unknown action"): + manage_catalog("invalid", db_path=db)