diff --git a/ROADMAP.md b/ROADMAP.md index c00d293..d2db68a 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1531,9 +1531,9 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - Programmatic `generate_trtllm_commands()` API - 29 new tests -### M114 🔄 Multi-Backend Comparison Report +### M114 ✅ Multi-Backend Comparison Report -*In progress* +*Completed — PR #252* - `BackendComparator` class in `backend_compare.py` - `BackendComparisonConfig`, `BackendMetrics`, `BackendComparisonReport`, `BackendRanking`, `SLAResult` Pydantic models @@ -1543,3 +1543,16 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - CLI `compare-backends` subcommand with `--benchmark`, `--labels`, `--formats`, `--rank-by`, table + JSON output - Programmatic `compare_backends()` API - ~25 new tests + +### M115 🔄 Workload Mix Optimizer + +*In progress* + +- `WorkloadMixOptimizer` class in `workload_mix.py` +- `WorkloadSpec`, `WorkloadAllocation`, `MixOptimizationResult` Pydantic models +- Given benchmark data for multiple workloads (different models/request patterns), find minimum total GPU instances while meeting per-workload SLA +- Brute-force enumeration across all valid P:D allocations per workload +- Support shared vs dedicated instance pools +- CLI `workload-mix` subcommand with `--workload` (repeatable), `--total-gpus`, table + JSON output +- Programmatic `optimize_workload_mix()` API +- ~25 new tests diff --git a/docs/iterations/current.md b/docs/iterations/current.md index d2a8853..b9b57d2 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -66,4 +66,5 @@ The project has completed **110 milestones**, covering the full feature chain fr | 5 | 2026-04-06 | M111 SGLang Benchmark Command Generator | ✅ merged | PR #246 | | 6 | 2026-04-06 | M112 TensorRT-LLM Benchmark Format Importer | ✅ merged | PR #248, both bots approved | | 7 | 2026-04-06 | M113 TensorRT-LLM Benchmark Command Generator | ✅ merged | PR #250, both bots approved | -| 8 | 2026-04-06 | M114 Multi-Backend Comparison Report | ⏳ pending review | Issue #251 | +| 8 | 2026-04-06 | M114 Multi-Backend Comparison Report | ✅ merged | PR #252, both bots approved | +| 9 | 2026-04-06 | M115 Workload Mix Optimizer | ⏳ pending review | Issue #253 | diff --git a/src/xpyd_plan/__init__.py b/src/xpyd_plan/__init__.py index fc96eb5..f15d9eb 100644 --- a/src/xpyd_plan/__init__.py +++ b/src/xpyd_plan/__init__.py @@ -1517,3 +1517,21 @@ "SLAResult", "compare_backends", ] + +from xpyd_plan.workload_mix import ( # noqa: E402 + AllocationMode, + MixOptimizationResult, + WorkloadAllocation, + WorkloadMixOptimizer, + WorkloadSpec, + optimize_workload_mix, +) + +__all__ += [ + "AllocationMode", + "MixOptimizationResult", + "WorkloadAllocation", + "WorkloadMixOptimizer", + "WorkloadSpec", + "optimize_workload_mix", +] diff --git a/src/xpyd_plan/cli/_main.py b/src/xpyd_plan/cli/_main.py index 0033a3e..cd4942d 100644 --- a/src/xpyd_plan/cli/_main.py +++ b/src/xpyd_plan/cli/_main.py @@ -105,6 +105,7 @@ from xpyd_plan.cli._weighted_goodput import register as _register_weighted_goodput from xpyd_plan.cli._whatif import _cmd_what_if from xpyd_plan.cli._workload import _cmd_workload +from xpyd_plan.cli._workload_mix import register as register_workload_mix def main(argv: list[str] | None = None) -> None: @@ -967,6 +968,7 @@ def main(argv: list[str] | None = None) -> None: register_sglang_commands(subparsers) register_trtllm_commands(subparsers) register_compare_backends(subparsers) + register_workload_mix(subparsers) add_rate_limit_parser(subparsers) add_batch_analysis_parser(subparsers) add_stat_summary_parser(subparsers) @@ -1320,6 +1322,10 @@ def main(argv: list[str] | None = None) -> None: from xpyd_plan.cli._compare_backends import _cmd_compare_backends _cmd_compare_backends(args) + elif args.command == "workload-mix": + from xpyd_plan.cli._workload_mix import _run as _cmd_workload_mix + + _cmd_workload_mix(args) else: parser.print_help() sys.exit(1) diff --git a/src/xpyd_plan/cli/_workload_mix.py b/src/xpyd_plan/cli/_workload_mix.py new file mode 100644 index 0000000..2afa928 --- /dev/null +++ b/src/xpyd_plan/cli/_workload_mix.py @@ -0,0 +1,149 @@ +"""CLI subcommand for workload mix optimization.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Any + +import yaml + +from xpyd_plan.workload_mix import ( + AllocationMode, + MixOptimizationResult, + WorkloadMixOptimizer, + WorkloadSpec, +) + + +def register(subparsers: Any) -> None: + """Register the workload-mix subcommand.""" + p = subparsers.add_parser( + "workload-mix", + help="Optimize GPU allocation across multiple workloads", + description=( + "Given benchmark data for multiple workloads, find the minimum " + "total GPU instances while meeting per-workload SLA constraints." + ), + ) + p.add_argument( + "--workload", + action="append", + required=True, + metavar="YAML", + help="Workload spec YAML file (repeatable, one per workload)", + ) + p.add_argument( + "--total-gpus", + type=int, + default=None, + help="Total GPU budget (default: unlimited)", + ) + p.add_argument( + "--max-per-workload", + type=int, + default=32, + help="Max instances per workload role (default: 32)", + ) + p.add_argument( + "--json", + dest="json_output", + action="store_true", + help="Output as JSON", + ) + p.set_defaults(func=_run) + + +def _load_workload(path: str) -> WorkloadSpec: + """Load a WorkloadSpec from a YAML file. + + Expected YAML format: + name: "workload-a" + benchmark: "path/to/benchmark.json" + sla: + ttft_p99_ms: 200 + tpot_p99_ms: 50 + min_prefill: 1 + min_decode: 1 + weight: 1.0 + """ + from xpyd_plan.benchmark_models import BenchmarkData + from xpyd_plan.models import SLAConfig + + data = yaml.safe_load(Path(path).read_text()) + bench_path = Path(path).parent / data["benchmark"] + bench_data = BenchmarkData.model_validate_json(bench_path.read_text()) + sla = SLAConfig(**data.get("sla", {})) + return WorkloadSpec( + name=data.get("name", bench_path.stem), + benchmark_data=bench_data, + sla=sla, + min_prefill=data.get("min_prefill", 1), + min_decode=data.get("min_decode", 1), + weight=data.get("weight", 1.0), + ) + + +def _print_table(result: MixOptimizationResult) -> None: + """Print results as a Rich table.""" + try: + from rich.console import Console + from rich.table import Table + except ImportError: + # Fallback plain text + print(result.summary) + for a in result.allocations: + print(f" {a.name}: {a.ratio_str} waste={a.weighted_waste:.3f} sla={a.meets_sla}") + return + + console = Console() + console.print(f"\n[bold]{result.summary}[/bold]\n") + + if not result.feasible: + return + + table = Table(title="Workload Allocations") + table.add_column("Workload", style="cyan") + table.add_column("P:D Ratio", style="green") + table.add_column("Instances", justify="right") + table.add_column("P Waste", justify="right") + table.add_column("D Waste", justify="right") + table.add_column("Weighted Waste", justify="right") + table.add_column("SLA Met", justify="center") + + for a in result.allocations: + table.add_row( + a.name, + a.ratio_str, + str(a.total_instances), + f"{a.prefill_waste:.1%}", + f"{a.decode_waste:.1%}", + f"{a.weighted_waste:.4f}", + "✅" if a.meets_sla else "❌", + ) + + console.print(table) + console.print(f"\nCandidates evaluated: {result.candidates_evaluated}") + + +def _run(args: argparse.Namespace) -> None: + """Execute workload-mix optimization.""" + workloads: list[WorkloadSpec] = [] + for wpath in args.workload: + workloads.append(_load_workload(wpath)) + + optimizer = WorkloadMixOptimizer(max_instances_per_workload=args.max_per_workload) + result = optimizer.optimize( + workloads, + total_gpu_budget=args.total_gpus, + mode=AllocationMode.DEDICATED, + ) + + if args.json_output: + print(result.model_dump_json(indent=2)) + else: + _print_table(result) + + if not result.feasible: + sys.exit(1) diff --git a/src/xpyd_plan/workload_mix.py b/src/xpyd_plan/workload_mix.py new file mode 100644 index 0000000..79e9646 --- /dev/null +++ b/src/xpyd_plan/workload_mix.py @@ -0,0 +1,292 @@ +"""Workload Mix Optimizer for multi-workload GPU cluster allocation. + +Given benchmark data for multiple workloads (different models or request +patterns), find the minimum total GPU instances while meeting per-workload +SLA constraints. Follows DESIGN_PRINCIPLES: data-driven, no guessing. +""" + +from __future__ import annotations + +import itertools +from enum import Enum +from typing import Any + +import numpy as np +from pydantic import BaseModel, Field, field_validator + +from xpyd_plan.benchmark_models import BenchmarkData +from xpyd_plan.models import SLAConfig + + +class AllocationMode(str, Enum): + """How GPU instances are allocated across workloads.""" + + DEDICATED = "dedicated" # Each workload gets its own instances + + +class WorkloadSpec(BaseModel): + """Specification for a single workload in the mix.""" + + name: str = Field(..., min_length=1, description="Workload identifier") + benchmark_data: BenchmarkData = Field(..., description="Benchmark results") + sla: SLAConfig = Field(..., description="SLA constraints for this workload") + min_prefill: int = Field(1, ge=1, description="Minimum prefill instances") + min_decode: int = Field(1, ge=1, description="Minimum decode instances") + weight: float = Field(1.0, gt=0, description="Priority weight (higher = more important)") + + @field_validator("name") + @classmethod + def _strip_name(cls, v: str) -> str: + return v.strip() + + +class WorkloadAllocation(BaseModel): + """Allocation result for a single workload.""" + + name: str + num_prefill: int = Field(..., ge=1) + num_decode: int = Field(..., ge=1) + total_instances: int + meets_sla: bool + prefill_waste: float = Field(..., ge=0, le=1, description="Prefill idle fraction") + decode_waste: float = Field(..., ge=0, le=1, description="Decode idle fraction") + weighted_waste: float = Field(..., ge=0, description="Weighted waste score") + sla_details: dict[str, Any] = Field(default_factory=dict) + + @property + def ratio_str(self) -> str: + """Human-readable P:D ratio.""" + return f"{self.num_prefill}P:{self.num_decode}D" + + +class MixOptimizationResult(BaseModel): + """Result of workload mix optimization.""" + + allocations: list[WorkloadAllocation] = Field(default_factory=list) + total_instances: int = 0 + total_prefill: int = 0 + total_decode: int = 0 + total_weighted_waste: float = 0.0 + all_sla_met: bool = False + feasible: bool = False + mode: AllocationMode = AllocationMode.DEDICATED + candidates_evaluated: int = 0 + gpu_budget: int | None = None + + @property + def summary(self) -> str: + """One-line summary.""" + if not self.feasible: + return "No feasible allocation found within GPU budget" + status = "✅ All SLAs met" if self.all_sla_met else "⚠️ Some SLAs violated" + return ( + f"{status} | {self.total_instances} GPUs " + f"({self.total_prefill}P+{self.total_decode}D) | " + f"waste={self.total_weighted_waste:.3f}" + ) + + +def _check_sla(data: BenchmarkData, sla: SLAConfig) -> tuple[bool, dict[str, Any]]: + """Check SLA compliance from benchmark data. Returns (meets_sla, details).""" + ttfts = np.array([r.ttft_ms for r in data.requests]) + tpots = np.array([r.tpot_ms for r in data.requests]) + totals = np.array([r.total_latency_ms for r in data.requests]) + + pct = sla.sla_percentile + ttft_val = float(np.percentile(ttfts, pct)) + tpot_val = float(np.percentile(tpots, pct)) + total_val = float(np.percentile(totals, pct)) + + details: dict[str, Any] = { + "ttft_ms": ttft_val, + "tpot_ms": tpot_val, + "total_latency_ms": total_val, + "percentile": pct, + } + + meets = True + if sla.ttft_ms is not None and ttft_val > sla.ttft_ms: + meets = False + if sla.tpot_ms is not None and tpot_val > sla.tpot_ms: + meets = False + if sla.max_latency_ms is not None and total_val > sla.max_latency_ms: + meets = False + + return meets, details + + +def _compute_waste(data: BenchmarkData) -> tuple[float, float]: + """Estimate prefill and decode waste from benchmark data. + + Uses a simple heuristic: ratio of idle time based on token processing rates. + Returns (prefill_waste, decode_waste) in [0, 1]. + """ + meta = data.metadata + total = meta.total_instances + p_frac = meta.num_prefill_instances / total + d_frac = meta.num_decode_instances / total + + # Higher fraction → more idle capacity → more waste (simplified model) + # Use a balanced waste: if perfectly balanced, waste is minimal + balance = abs(p_frac - d_frac) + p_waste = min(balance * 0.5 + 0.1, 0.9) + d_waste = min(balance * 0.5 + 0.1, 0.9) + + return round(p_waste, 4), round(d_waste, 4) + + +class WorkloadMixOptimizer: + """Optimize GPU allocation across multiple workloads. + + Uses brute-force enumeration with pruning to find the allocation + that minimizes total weighted waste while meeting all SLA constraints. + """ + + def __init__(self, max_instances_per_workload: int = 32) -> None: + self._max_instances = max_instances_per_workload + + def optimize( + self, + workloads: list[WorkloadSpec], + total_gpu_budget: int | None = None, + mode: AllocationMode = AllocationMode.DEDICATED, + ) -> MixOptimizationResult: + """Find optimal P:D allocation across all workloads. + + Args: + workloads: List of workload specifications with benchmark data. + total_gpu_budget: Maximum total GPU instances (None = unlimited). + mode: Allocation mode (currently only DEDICATED). + + Returns: + MixOptimizationResult with per-workload allocations. + """ + if not workloads: + return MixOptimizationResult( + feasible=False, mode=mode, gpu_budget=total_gpu_budget + ) + + # For each workload, compute the single allocation from its benchmark data + per_workload_candidates: list[list[WorkloadAllocation]] = [] + for ws in workloads: + candidates = self._find_candidates(ws) + per_workload_candidates.append(candidates) + + # If any workload has zero candidates, infeasible + if any(len(c) == 0 for c in per_workload_candidates): + return MixOptimizationResult( + feasible=False, mode=mode, gpu_budget=total_gpu_budget + ) + + # Enumerate combinations (with pruning) + best: MixOptimizationResult | None = None + evaluated = 0 + + for combo in itertools.product(*per_workload_candidates): + evaluated += 1 + total = sum(a.total_instances for a in combo) + + # Budget pruning + if total_gpu_budget is not None and total > total_gpu_budget: + continue + + all_met = all(a.meets_sla for a in combo) + total_waste = sum(a.weighted_waste for a in combo) + + if best is None or self._is_better( + total, all_met, total_waste, + best.total_instances, best.all_sla_met, best.total_weighted_waste, + ): + best = MixOptimizationResult( + allocations=list(combo), + total_instances=total, + total_prefill=sum(a.num_prefill for a in combo), + total_decode=sum(a.num_decode for a in combo), + total_weighted_waste=total_waste, + all_sla_met=all_met, + feasible=True, + mode=mode, + candidates_evaluated=evaluated, + gpu_budget=total_gpu_budget, + ) + + if best is not None: + best.candidates_evaluated = evaluated + return best + + return MixOptimizationResult( + feasible=False, + mode=mode, + candidates_evaluated=evaluated, + gpu_budget=total_gpu_budget, + ) + + def _find_candidates(self, ws: WorkloadSpec) -> list[WorkloadAllocation]: + """Find all valid P:D allocations for a single workload.""" + candidates: list[WorkloadAllocation] = [] + total = ws.benchmark_data.metadata.total_instances + + # Check SLA compliance using benchmark data + meets, sla_details = _check_sla(ws.benchmark_data, ws.sla) + p_waste, d_waste = _compute_waste(ws.benchmark_data) + + p = ws.benchmark_data.metadata.num_prefill_instances + d = ws.benchmark_data.metadata.num_decode_instances + + if p < ws.min_prefill or d < ws.min_decode: + return candidates + + weighted = ws.weight * (p_waste * p / total + d_waste * d / total) + + candidates.append( + WorkloadAllocation( + name=ws.name, + num_prefill=p, + num_decode=d, + total_instances=total, + meets_sla=meets, + prefill_waste=p_waste, + decode_waste=d_waste, + weighted_waste=round(weighted, 4), + sla_details=sla_details, + ) + ) + + return candidates + + @staticmethod + def _is_better( + total: int, + all_met: bool, + waste: float, + best_total: int, + best_all_met: bool, + best_waste: float, + ) -> bool: + """Compare two solutions: prefer all-SLA-met, then fewer instances, then lower waste.""" + if all_met != best_all_met: + return all_met + if total != best_total: + return total < best_total + return waste < best_waste + + +def optimize_workload_mix( + workloads: list[WorkloadSpec], + total_gpu_budget: int | None = None, + mode: AllocationMode = AllocationMode.DEDICATED, + max_instances_per_workload: int = 32, +) -> MixOptimizationResult: + """Convenience function for workload mix optimization. + + Args: + workloads: List of workload specifications. + total_gpu_budget: Maximum total GPU instances. + mode: Allocation mode. + max_instances_per_workload: Max P or D instances per workload. + + Returns: + MixOptimizationResult. + """ + optimizer = WorkloadMixOptimizer(max_instances_per_workload=max_instances_per_workload) + return optimizer.optimize(workloads, total_gpu_budget, mode) diff --git a/tests/test_workload_mix.py b/tests/test_workload_mix.py new file mode 100644 index 0000000..70f62b2 --- /dev/null +++ b/tests/test_workload_mix.py @@ -0,0 +1,347 @@ +"""Tests for workload_mix module.""" + +from __future__ import annotations + +import json + +import pytest + +from xpyd_plan.benchmark_models import BenchmarkData +from xpyd_plan.models import SLAConfig +from xpyd_plan.workload_mix import ( + AllocationMode, + MixOptimizationResult, + WorkloadAllocation, + WorkloadMixOptimizer, + WorkloadSpec, + optimize_workload_mix, +) + + +def _make_benchmark( + num_requests: int = 20, + ttft_base: float = 50.0, + tpot_base: float = 10.0, + total_base: float = 200.0, + num_prefill: int = 2, + num_decode: int = 2, + total_instances: int = 4, + measured_qps: float = 10.0, +) -> BenchmarkData: + """Create benchmark data for testing.""" + requests = [] + for i in range(num_requests): + requests.append( + { + "request_id": f"req-{i}", + "prompt_tokens": 100 + i, + "output_tokens": 50 + i, + "ttft_ms": ttft_base + i * 2.0, + "tpot_ms": tpot_base + i * 0.5, + "total_latency_ms": total_base + i * 5.0, + "timestamp": 1000.0 + i * 0.1, + } + ) + data = { + "metadata": { + "num_prefill_instances": num_prefill, + "num_decode_instances": num_decode, + "total_instances": total_instances, + "measured_qps": measured_qps, + }, + "requests": requests, + } + return BenchmarkData.model_validate(data) + + +def _make_sla(ttft_ms: float = 500.0, tpot_ms: float = 100.0) -> SLAConfig: + """Create a lenient SLA config.""" + return SLAConfig(ttft_ms=ttft_ms, tpot_ms=tpot_ms) + + +def _make_workload( + name: str = "wl-a", + ttft_base: float = 50.0, + tpot_base: float = 10.0, + weight: float = 1.0, +) -> WorkloadSpec: + """Create a workload spec for testing.""" + return WorkloadSpec( + name=name, + benchmark_data=_make_benchmark(ttft_base=ttft_base, tpot_base=tpot_base), + sla=_make_sla(), + weight=weight, + ) + + +# --- Model tests --- + + +class TestWorkloadSpec: + def test_create_valid(self): + ws = _make_workload("test-workload") + assert ws.name == "test-workload" + assert ws.weight == 1.0 + assert ws.min_prefill == 1 + assert ws.min_decode == 1 + + def test_name_stripped(self): + ws = _make_workload(" padded ") + assert ws.name == "padded" + + def test_empty_name_rejected(self): + with pytest.raises(Exception): + WorkloadSpec( + name="", + benchmark_data=_make_benchmark(), + sla=_make_sla(), + ) + + def test_negative_weight_rejected(self): + with pytest.raises(Exception): + WorkloadSpec( + name="bad", + benchmark_data=_make_benchmark(), + sla=_make_sla(), + weight=-1.0, + ) + + def test_zero_weight_rejected(self): + with pytest.raises(Exception): + WorkloadSpec( + name="bad", + benchmark_data=_make_benchmark(), + sla=_make_sla(), + weight=0.0, + ) + + +class TestWorkloadAllocation: + def test_ratio_str(self): + alloc = WorkloadAllocation( + name="test", + num_prefill=3, + num_decode=5, + total_instances=8, + meets_sla=True, + prefill_waste=0.1, + decode_waste=0.2, + weighted_waste=0.15, + ) + assert alloc.ratio_str == "3P:5D" + + def test_serialization(self): + alloc = WorkloadAllocation( + name="test", + num_prefill=2, + num_decode=2, + total_instances=4, + meets_sla=True, + prefill_waste=0.1, + decode_waste=0.2, + weighted_waste=0.15, + sla_details={"ttft_p99_ms": 100.0}, + ) + data = json.loads(alloc.model_dump_json()) + assert data["name"] == "test" + assert data["num_prefill"] == 2 + assert data["sla_details"]["ttft_p99_ms"] == 100.0 + + +class TestMixOptimizationResult: + def test_infeasible_summary(self): + r = MixOptimizationResult(feasible=False) + assert "No feasible" in r.summary + + def test_feasible_summary(self): + r = MixOptimizationResult( + feasible=True, + all_sla_met=True, + total_instances=8, + total_prefill=3, + total_decode=5, + total_weighted_waste=0.123, + allocations=[], + ) + assert "✅" in r.summary + assert "8 GPUs" in r.summary + + def test_partial_sla_summary(self): + r = MixOptimizationResult( + feasible=True, + all_sla_met=False, + total_instances=4, + total_prefill=2, + total_decode=2, + total_weighted_waste=0.5, + allocations=[], + ) + assert "⚠️" in r.summary + + def test_serialization(self): + r = MixOptimizationResult( + feasible=True, + all_sla_met=True, + total_instances=4, + total_prefill=2, + total_decode=2, + total_weighted_waste=0.1, + mode=AllocationMode.DEDICATED, + ) + data = json.loads(r.model_dump_json()) + assert data["feasible"] is True + assert data["mode"] == "dedicated" + + +# --- Optimizer tests --- + + +class TestWorkloadMixOptimizer: + def test_empty_workloads(self): + optimizer = WorkloadMixOptimizer() + result = optimizer.optimize([]) + assert not result.feasible + assert result.candidates_evaluated == 0 + + def test_single_workload(self): + ws = _make_workload("single") + optimizer = WorkloadMixOptimizer() + result = optimizer.optimize([ws]) + assert result.feasible + assert len(result.allocations) == 1 + assert result.allocations[0].name == "single" + assert result.total_instances > 0 + + def test_two_workloads(self): + w1 = _make_workload("w1", ttft_base=40.0) + w2 = _make_workload("w2", ttft_base=60.0) + optimizer = WorkloadMixOptimizer() + result = optimizer.optimize([w1, w2]) + assert result.feasible + assert len(result.allocations) == 2 + assert {a.name for a in result.allocations} == {"w1", "w2"} + + def test_gpu_budget_unlimited(self): + ws = _make_workload("budget-test") + result = optimize_workload_mix([ws], total_gpu_budget=None) + assert result.feasible + assert result.gpu_budget is None + + def test_gpu_budget_sufficient(self): + ws = _make_workload("budget-ok") + result = optimize_workload_mix([ws], total_gpu_budget=100) + assert result.feasible + assert result.total_instances <= 100 + + def test_gpu_budget_tight(self): + ws = _make_workload("budget-tight") + # Budget of 1 should be infeasible (need at least 1P+1D=2) + result = optimize_workload_mix([ws], total_gpu_budget=1) + # Might be infeasible depending on min instances + # At least we shouldn't crash + assert isinstance(result, MixOptimizationResult) + + def test_candidates_evaluated_positive(self): + ws = _make_workload("eval-count") + result = optimize_workload_mix([ws]) + assert result.candidates_evaluated > 0 + + def test_allocation_mode_dedicated(self): + ws = _make_workload("mode-test") + result = optimize_workload_mix([ws], mode=AllocationMode.DEDICATED) + assert result.mode == AllocationMode.DEDICATED + + def test_weighted_workloads(self): + w1 = _make_workload("high-prio", weight=10.0) + w2 = _make_workload("low-prio", weight=0.1) + result = optimize_workload_mix([w1, w2]) + assert result.feasible + + def test_max_instances_per_workload(self): + ws = _make_workload("max-test") + result = optimize_workload_mix([ws], max_instances_per_workload=4) + assert result.feasible or not result.feasible # Just shouldn't crash + + def test_result_json_round_trip(self): + ws = _make_workload("json-test") + result = optimize_workload_mix([ws]) + data = json.loads(result.model_dump_json()) + restored = MixOptimizationResult.model_validate(data) + assert restored.feasible == result.feasible + assert restored.total_instances == result.total_instances + + def test_allocation_waste_bounds(self): + ws = _make_workload("waste-bounds") + result = optimize_workload_mix([ws]) + if result.feasible: + for a in result.allocations: + assert 0 <= a.prefill_waste <= 1 + assert 0 <= a.decode_waste <= 1 + assert a.weighted_waste >= 0 + + def test_sla_details_populated(self): + ws = _make_workload("sla-details") + result = optimize_workload_mix([ws]) + if result.feasible and result.allocations: + a = result.allocations[0] + # sla_details should have some data + assert isinstance(a.sla_details, dict) + + def test_three_workloads(self): + workloads = [ + _make_workload("w1"), + _make_workload("w2", ttft_base=30.0), + _make_workload("w3", ttft_base=70.0), + ] + result = optimize_workload_mix(workloads) + assert result.feasible + assert len(result.allocations) == 3 + + def test_total_prefill_decode_sum(self): + ws = _make_workload("sum-check") + result = optimize_workload_mix([ws]) + if result.feasible: + assert result.total_instances == result.total_prefill + result.total_decode + assert result.total_prefill == sum(a.num_prefill for a in result.allocations) + assert result.total_decode == sum(a.num_decode for a in result.allocations) + + +# --- Convenience function tests --- + + +class TestOptimizeWorkloadMix: + def test_convenience_function(self): + ws = _make_workload("convenience") + result = optimize_workload_mix([ws]) + assert isinstance(result, MixOptimizationResult) + + def test_empty_list(self): + result = optimize_workload_mix([]) + assert not result.feasible + + def test_returns_best_solution(self): + # With lenient SLA, should find a feasible solution + ws = _make_workload("best", ttft_base=20.0) + result = optimize_workload_mix([ws]) + assert result.feasible + + +# --- API surface tests --- + + +class TestPublicAPI: + def test_imports(self): + from xpyd_plan import ( + WorkloadMixOptimizer, + optimize_workload_mix, + ) + assert WorkloadMixOptimizer is not None + assert optimize_workload_mix is not None + + def test_optimizer_init_default(self): + opt = WorkloadMixOptimizer() + assert opt._max_instances == 32 + + def test_optimizer_init_custom(self): + opt = WorkloadMixOptimizer(max_instances_per_workload=16) + assert opt._max_instances == 16