diff --git a/ROADMAP.md b/ROADMAP.md index 337dbf1..898ef4f 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1505,9 +1505,9 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - Programmatic `generate_sglang_commands()` API - 25 new tests -### M112 🔄 TensorRT-LLM Benchmark Format Importer +### M112 ✅ TensorRT-LLM Benchmark Format Importer -*In progress — PR #TBD* +*Completed — PR #248* - `TRTLLMImporter` module in `trtllm_import.py` - `TRTLLMRequest`, `TRTLLMBenchmarkData`, `TRTLLMImportConfig`, `TRTLLMImportResult` Pydantic models @@ -1517,3 +1517,16 @@ Help users find the **optimal Prefill:Decode instance ratio** based on **real be - CLI `import --format trtllm` support - Programmatic `import_trtllm()` and `import_trtllm_data()` API - 25+ new tests + +### M113 🔄 TensorRT-LLM Benchmark Command Generator + +*In progress* + +- `TRTLLMCommandGenerator` class in `trtllm_commands.py` +- `TRTLLMCommandConfig`, `TRTLLMServerCommand`, `TRTLLMBenchmarkCommand`, `TRTLLMCommandSet` Pydantic models +- Generate TRT-LLM engine build (`trtllm-build`) and server launch commands for each P:D ratio +- TRT-LLM specific options: max_batch_size, kv_cache_free_gpu_mem_fraction, pp_size, dtype, engine_dir +- Shell script output with engine build + server + benchmark lifecycle +- CLI `trtllm-commands` subcommand with table + JSON output +- Programmatic `generate_trtllm_commands()` API +- 29 new tests diff --git a/docs/iterations/current.md b/docs/iterations/current.md index a948599..8eccdd2 100644 --- a/docs/iterations/current.md +++ b/docs/iterations/current.md @@ -53,7 +53,7 @@ The project has completed **110 milestones**, covering the full feature chain fr - Closed-loop integration with xPyD-proxy auto-tuning - Web UI dashboard (replacing TUI) - Richer visualizations (interactive charts) -- Support additional benchmark tool formats (TensorRT-LLM) +- Support additional benchmark tool formats (TensorRT-LLM) ✅ ## Iteration History @@ -64,4 +64,5 @@ The project has completed **110 milestones**, covering the full feature chain fr | 3 | 2026-04-06 | M109 vLLM Benchmark Command Generator | ✅ merged | PR #242 | | 4 | 2026-04-06 | M110 SGLang Benchmark Format Importer | ✅ merged | PR #244 | | 5 | 2026-04-06 | M111 SGLang Benchmark Command Generator | ✅ merged | PR #246 | -| 6 | 2026-04-06 | M112 TensorRT-LLM Benchmark Format Importer | ⏳ pending review | Issue #247 | +| 6 | 2026-04-06 | M112 TensorRT-LLM Benchmark Format Importer | ✅ merged | PR #248, both bots approved | +| 7 | 2026-04-06 | M113 TensorRT-LLM Benchmark Command Generator | ⏳ pending review | Issue #249 | diff --git a/src/xpyd_plan/__init__.py b/src/xpyd_plan/__init__.py index b70ab1e..c0f8a6d 100644 --- a/src/xpyd_plan/__init__.py +++ b/src/xpyd_plan/__init__.py @@ -1412,6 +1412,10 @@ from xpyd_plan.sglang_commands import SGLangCommandGenerator # noqa: E402 from xpyd_plan.sglang_commands import SGLangCommandSet # noqa: E402 from xpyd_plan.sglang_commands import generate_sglang_commands # noqa: E402 +from xpyd_plan.trtllm_commands import TRTLLMCommandConfig # noqa: E402, I001 +from xpyd_plan.trtllm_commands import TRTLLMCommandGenerator # noqa: E402 +from xpyd_plan.trtllm_commands import TRTLLMCommandSet # noqa: E402 +from xpyd_plan.trtllm_commands import generate_trtllm_commands # noqa: E402 from xpyd_plan.vllm_commands import BenchmarkCommand # noqa: E402 from xpyd_plan.vllm_commands import CommandGenerator # noqa: E402 from xpyd_plan.vllm_commands import CommandSet # noqa: E402 @@ -1430,6 +1434,10 @@ "SGLangCommandConfig", "SGLangCommandGenerator", "SGLangCommandSet", + "generate_trtllm_commands", + "TRTLLMCommandConfig", + "TRTLLMCommandGenerator", + "TRTLLMCommandSet", ] from xpyd_plan.vllm_import import ( # noqa: E402 diff --git a/src/xpyd_plan/cli/_main.py b/src/xpyd_plan/cli/_main.py index 758fec3..ae72e78 100644 --- a/src/xpyd_plan/cli/_main.py +++ b/src/xpyd_plan/cli/_main.py @@ -96,6 +96,7 @@ from xpyd_plan.cli._token_budget import add_token_budget_parser from xpyd_plan.cli._token_efficiency import add_token_efficiency_parser, handle_token_efficiency from xpyd_plan.cli._trend import _cmd_trend +from xpyd_plan.cli._trtllm_commands import register_trtllm_commands from xpyd_plan.cli._validate import _cmd_validate from xpyd_plan.cli._variance import _cmd_variance, add_variance_parser from xpyd_plan.cli._vllm_commands import register_vllm_commands @@ -963,6 +964,7 @@ def main(argv: list[str] | None = None) -> None: add_import_parser(subparsers) register_vllm_commands(subparsers) register_sglang_commands(subparsers) + register_trtllm_commands(subparsers) add_rate_limit_parser(subparsers) add_batch_analysis_parser(subparsers) add_stat_summary_parser(subparsers) @@ -1308,6 +1310,10 @@ def main(argv: list[str] | None = None) -> None: from xpyd_plan.cli._sglang_commands import _cmd_sglang_commands _cmd_sglang_commands(args) + elif args.command == "trtllm-commands": + from xpyd_plan.cli._trtllm_commands import _cmd_trtllm_commands + + _cmd_trtllm_commands(args) else: parser.print_help() sys.exit(1) diff --git a/src/xpyd_plan/cli/_trtllm_commands.py b/src/xpyd_plan/cli/_trtllm_commands.py new file mode 100644 index 0000000..0fd3967 --- /dev/null +++ b/src/xpyd_plan/cli/_trtllm_commands.py @@ -0,0 +1,162 @@ +"""CLI trtllm-commands subcommand.""" + +from __future__ import annotations + +import argparse +import json + +from rich.console import Console +from rich.table import Table + +from xpyd_plan.trtllm_commands import TRTLLMCommandConfig, TRTLLMCommandGenerator + + +def _cmd_trtllm_commands(args: argparse.Namespace) -> None: + """Handle the 'trtllm-commands' subcommand.""" + console = Console() + + qps_levels = [float(q) for q in args.qps.split(",")] + + config = TRTLLMCommandConfig( + model=args.model, + total_instances=args.total_instances, + qps_levels=qps_levels, + tp_size=getattr(args, "tp_size", 1), + pp_size=getattr(args, "pp_size", 1), + max_batch_size=getattr(args, "max_batch_size", 256), + max_input_len=getattr(args, "max_input_len", 2048), + max_output_len=getattr(args, "max_output_len", 2048), + kv_cache_free_gpu_mem_fraction=getattr( + args, "kv_cache_free_gpu_mem_fraction", 0.9 + ), + dtype=getattr(args, "dtype", "float16"), + dataset=getattr(args, "dataset", None), + num_prompts=getattr(args, "num_prompts", 1000), + host=getattr(args, "host", "localhost"), + port=getattr(args, "port", 8000), + engine_dir=getattr(args, "engine_dir", "./engines"), + ) + + gen = TRTLLMCommandGenerator(config) + result = gen.generate() + + if args.output_script: + script = _to_shell_script(result, config) + with open(args.output_script, "w") as f: + f.write(script) + console.print( + f"[green]Shell script written to {args.output_script} " + f"({len(result)} ratios)[/green]" + ) + return + + output_format = getattr(args, "output_format", "table") + + if output_format == "json": + print(json.dumps([cs.model_dump() for cs in result], indent=2, default=str)) + return + + # Table output + console.print("\n[bold]TensorRT-LLM Benchmark Commands[/bold]") + console.print( + f"Model: {config.model} | Instances: {config.total_instances} | " + f"Ratios: {len(result)}\n" + ) + + table = Table(title="Benchmark Runs") + table.add_column("P:D Ratio") + table.add_column("Prefill", justify="right") + table.add_column("Decode", justify="right") + table.add_column("QPS Levels") + + for cs in result: + table.add_row( + cs.server.ratio, + str(cs.server.prefill_instances), + str(cs.server.decode_instances), + ", ".join(f"{b.qps}" for b in cs.benchmarks), + ) + + console.print(table) + console.print( + "\n[dim]Use --output-script to generate an executable shell script[/dim]" + ) + + +def _to_shell_script( + command_sets: list, + config: TRTLLMCommandConfig, +) -> str: + """Build a complete shell script from command sets.""" + lines = [ + "#!/usr/bin/env bash", + "set -euo pipefail", + f"# TensorRT-LLM Benchmark Script — {config.model}", + f"# Total instances: {config.total_instances}", + f"# QPS levels: {', '.join(str(q) for q in config.qps_levels)}", + "", + ] + for cs in command_sets: + lines.append(cs.script_snippet) + lines.append("echo 'All benchmarks complete!'") + return "\n".join(lines) + "\n" + + +def register_trtllm_commands(subparsers: argparse._SubParsersAction) -> None: + """Register the trtllm-commands subcommand.""" + p = subparsers.add_parser( + "trtllm-commands", + help="Generate TensorRT-LLM benchmark commands for P:D ratio exploration", + ) + p.add_argument("--model", type=str, required=True, help="HuggingFace model name") + p.add_argument( + "--total-instances", + type=int, + required=True, + help="Total instances (prefill + decode)", + ) + p.add_argument( + "--qps", + type=str, + required=True, + help="Comma-separated QPS levels (e.g. 1,2,4)", + ) + p.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size") + p.add_argument("--pp-size", type=int, default=1, help="Pipeline parallel size") + p.add_argument("--max-batch-size", type=int, default=256, help="Max batch size") + p.add_argument( + "--max-input-len", type=int, default=2048, help="Max input length" + ) + p.add_argument( + "--max-output-len", type=int, default=2048, help="Max output length" + ) + p.add_argument( + "--kv-cache-free-gpu-mem-fraction", + type=float, + default=0.9, + help="KV cache GPU memory fraction", + ) + p.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float16", "bfloat16", "float32"], + help="Data type", + ) + p.add_argument("--dataset", type=str, default=None, help="Dataset path") + p.add_argument("--num-prompts", type=int, default=1000, help="Prompts per run") + p.add_argument("--host", type=str, default="localhost", help="Server host") + p.add_argument("--port", type=int, default=8000, help="Server port") + p.add_argument( + "--engine-dir", type=str, default="./engines", help="Engine output directory" + ) + p.add_argument( + "--output-script", type=str, default=None, help="Write shell script" + ) + p.add_argument( + "--output-format", + choices=["table", "json"], + default="table", + help="Output format", + ) + p.set_defaults(func=_cmd_trtllm_commands) diff --git a/src/xpyd_plan/trtllm_commands.py b/src/xpyd_plan/trtllm_commands.py new file mode 100644 index 0000000..70d9924 --- /dev/null +++ b/src/xpyd_plan/trtllm_commands.py @@ -0,0 +1,209 @@ +"""TensorRT-LLM Benchmark Command Generator — generate ready-to-run TRT-LLM commands. + +Given total instances, model name, and QPS levels, generate TensorRT-LLM +engine build and trtllm-bench commands for each planned P:D ratio configuration. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class TRTLLMCommandConfig(BaseModel): + """Configuration for TensorRT-LLM command generation.""" + + model: str = Field(..., min_length=1, description="HuggingFace model name") + total_instances: int = Field(..., ge=2, description="Total instances (P+D)") + qps_levels: list[float] = Field( + ..., min_length=1, description="QPS levels to benchmark" + ) + tp_size: int = Field(1, ge=1, description="Tensor parallel size") + pp_size: int = Field(1, ge=1, description="Pipeline parallel size") + max_batch_size: int = Field(256, ge=1, description="Max batch size") + max_input_len: int = Field(2048, ge=1, description="Max input sequence length") + max_output_len: int = Field(2048, ge=1, description="Max output sequence length") + kv_cache_free_gpu_mem_fraction: float = Field( + 0.9, gt=0.0, le=1.0, description="KV cache GPU memory fraction" + ) + dtype: str = Field("float16", description="Data type (float16, bfloat16, float32)") + dataset: str | None = Field(None, description="Dataset path") + num_prompts: int = Field(1000, ge=1, description="Number of prompts per run") + host: str = Field("localhost", description="Server host") + port: int = Field(8000, ge=1, le=65535, description="Server port") + engine_dir: str = Field("./engines", description="Engine output directory") + + +class TRTLLMServerCommand(BaseModel): + """A TensorRT-LLM server launch command for one P:D ratio.""" + + ratio: str = Field(..., description="e.g. '2P:3D'") + prefill_instances: int = Field(..., ge=1) + decode_instances: int = Field(..., ge=1) + engine_build_command: str = Field(..., description="Engine build command") + server_command: str = Field(..., description="Server launch command") + + +class TRTLLMBenchmarkCommand(BaseModel): + """A trtllm-bench invocation command.""" + + ratio: str = Field(..., description="P:D ratio string") + qps: float = Field(..., gt=0) + command: str = Field(..., description="Shell command") + + +class TRTLLMCommandSet(BaseModel): + """Complete command set for one P:D ratio.""" + + server: TRTLLMServerCommand + benchmarks: list[TRTLLMBenchmarkCommand] + script_snippet: str = Field("", description="Combined shell script snippet") + + +class TRTLLMCommandGenerator: + """Generate TensorRT-LLM engine build, server, and benchmark commands.""" + + def __init__(self, config: TRTLLMCommandConfig) -> None: + self._config = config + + def generate(self) -> list[TRTLLMCommandSet]: + """Generate command sets for all valid P:D ratios.""" + total = self._config.total_instances + results: list[TRTLLMCommandSet] = [] + + for p in range(1, total): + d = total - p + if d < 1: + continue + ratio_str = f"{p}P:{d}D" + server_cmd = self._build_server_command(p, d, ratio_str) + bench_cmds = [ + self._build_benchmark_command(ratio_str, qps) + for qps in self._config.qps_levels + ] + snippet = self._build_script_snippet(server_cmd, bench_cmds) + results.append( + TRTLLMCommandSet( + server=server_cmd, + benchmarks=bench_cmds, + script_snippet=snippet, + ) + ) + + return results + + def _build_server_command( + self, prefill: int, decode: int, ratio: str + ) -> TRTLLMServerCommand: + cfg = self._config + engine_path = f"{cfg.engine_dir}/{ratio.replace(':', '_')}" + + build_parts = [ + "trtllm-build", + f"--model_dir {cfg.model}", + f"--output_dir {engine_path}", + f"--tp_size {cfg.tp_size}", + f"--pp_size {cfg.pp_size}", + f"--max_batch_size {cfg.max_batch_size}", + f"--max_input_len {cfg.max_input_len}", + f"--max_output_len {cfg.max_output_len}", + f"--dtype {cfg.dtype}", + ] + + server_parts = [ + "python3 -m tensorrt_llm.serve", + f"--engine_dir {engine_path}", + f"--host {cfg.host}", + f"--port {cfg.port}", + f"--kv_cache_free_gpu_mem_fraction {cfg.kv_cache_free_gpu_mem_fraction}", + ] + + return TRTLLMServerCommand( + ratio=ratio, + prefill_instances=prefill, + decode_instances=decode, + engine_build_command=" \\\n ".join(build_parts), + server_command=" \\\n ".join(server_parts), + ) + + def _build_benchmark_command( + self, ratio: str, qps: float + ) -> TRTLLMBenchmarkCommand: + cfg = self._config + parts = [ + "trtllm-bench", + f"--host {cfg.host}", + f"--port {cfg.port}", + f"--num-prompts {cfg.num_prompts}", + f"--request-rate {qps}", + ] + if cfg.dataset is not None: + parts.append(f"--dataset {cfg.dataset}") + output_file = f"bench_{ratio.replace(':', '_')}_qps{qps}.json" + parts.append(f"--output-file {output_file}") + return TRTLLMBenchmarkCommand( + ratio=ratio, + qps=qps, + command=" \\\n ".join(parts), + ) + + def _build_script_snippet( + self, + server: TRTLLMServerCommand, + benchmarks: list[TRTLLMBenchmarkCommand], + ) -> str: + lines = [ + f"# --- {server.ratio} ---", + f"echo 'Building engine for {server.ratio}...'", + f"{server.engine_build_command}", + "", + f"{server.server_command} &", + "SERVER_PID=$!", + "sleep 60 # wait for TRT-LLM server to load engine", + "", + ] + for bench in benchmarks: + lines.append(f"{bench.command}") + lines.append("") + lines.append("kill $SERVER_PID") + lines.append(f"echo 'Done with {server.ratio}'") + lines.append("") + return "\n".join(lines) + + +def generate_trtllm_commands( + model: str, + total_instances: int, + qps_levels: list[float], + *, + tp_size: int = 1, + pp_size: int = 1, + max_batch_size: int = 256, + max_input_len: int = 2048, + max_output_len: int = 2048, + kv_cache_free_gpu_mem_fraction: float = 0.9, + dtype: str = "float16", + dataset: str | None = None, + num_prompts: int = 1000, + host: str = "localhost", + port: int = 8000, + engine_dir: str = "./engines", +) -> list[TRTLLMCommandSet]: + """Programmatic API: generate TensorRT-LLM benchmark commands.""" + config = TRTLLMCommandConfig( + model=model, + total_instances=total_instances, + qps_levels=qps_levels, + tp_size=tp_size, + pp_size=pp_size, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + kv_cache_free_gpu_mem_fraction=kv_cache_free_gpu_mem_fraction, + dtype=dtype, + dataset=dataset, + num_prompts=num_prompts, + host=host, + port=port, + engine_dir=engine_dir, + ) + return TRTLLMCommandGenerator(config).generate() diff --git a/tests/test_trtllm_commands.py b/tests/test_trtllm_commands.py new file mode 100644 index 0000000..d52bbd6 --- /dev/null +++ b/tests/test_trtllm_commands.py @@ -0,0 +1,322 @@ +"""Tests for TensorRT-LLM Benchmark Command Generator (M113).""" + +from __future__ import annotations + +import json +import subprocess +import tempfile +from pathlib import Path + +import pytest + +from xpyd_plan.trtllm_commands import ( + TRTLLMBenchmarkCommand, + TRTLLMCommandConfig, + TRTLLMCommandGenerator, + TRTLLMCommandSet, + TRTLLMServerCommand, + generate_trtllm_commands, +) + +# --- Config model tests --- + + +class TestTRTLLMCommandConfig: + def test_valid_config(self): + cfg = TRTLLMCommandConfig( + model="meta-llama/Llama-2-7b", + total_instances=4, + qps_levels=[1.0, 2.0], + ) + assert cfg.model == "meta-llama/Llama-2-7b" + assert cfg.total_instances == 4 + assert cfg.qps_levels == [1.0, 2.0] + + def test_defaults(self): + cfg = TRTLLMCommandConfig( + model="m", total_instances=2, qps_levels=[1.0] + ) + assert cfg.tp_size == 1 + assert cfg.pp_size == 1 + assert cfg.max_batch_size == 256 + assert cfg.max_input_len == 2048 + assert cfg.max_output_len == 2048 + assert cfg.kv_cache_free_gpu_mem_fraction == 0.9 + assert cfg.dtype == "float16" + assert cfg.host == "localhost" + assert cfg.port == 8000 + assert cfg.engine_dir == "./engines" + assert cfg.num_prompts == 1000 + + def test_invalid_total_instances(self): + with pytest.raises(Exception): + TRTLLMCommandConfig( + model="m", total_instances=1, qps_levels=[1.0] + ) + + def test_empty_model(self): + with pytest.raises(Exception): + TRTLLMCommandConfig( + model="", total_instances=2, qps_levels=[1.0] + ) + + def test_empty_qps(self): + with pytest.raises(Exception): + TRTLLMCommandConfig( + model="m", total_instances=2, qps_levels=[] + ) + + def test_invalid_kv_cache_fraction(self): + with pytest.raises(Exception): + TRTLLMCommandConfig( + model="m", + total_instances=2, + qps_levels=[1.0], + kv_cache_free_gpu_mem_fraction=0.0, + ) + + +# --- Generator tests --- + + +class TestTRTLLMCommandGenerator: + def _make_config(self, **kwargs): + defaults = dict( + model="meta-llama/Llama-2-7b", + total_instances=4, + qps_levels=[1.0, 2.0], + ) + defaults.update(kwargs) + return TRTLLMCommandConfig(**defaults) + + def test_generates_correct_ratio_count(self): + gen = TRTLLMCommandGenerator(self._make_config(total_instances=4)) + result = gen.generate() + assert len(result) == 3 # 1P:3D, 2P:2D, 3P:1D + + def test_ratio_strings(self): + gen = TRTLLMCommandGenerator(self._make_config(total_instances=3)) + result = gen.generate() + ratios = [cs.server.ratio for cs in result] + assert ratios == ["1P:2D", "2P:1D"] + + def test_server_command_contains_model(self): + gen = TRTLLMCommandGenerator(self._make_config()) + result = gen.generate() + assert "meta-llama/Llama-2-7b" in result[0].server.engine_build_command + + def test_server_command_contains_engine_dir(self): + gen = TRTLLMCommandGenerator(self._make_config(engine_dir="/my/engines")) + result = gen.generate() + assert "/my/engines" in result[0].server.engine_build_command + assert "/my/engines" in result[0].server.server_command + + def test_benchmark_commands_per_qps(self): + gen = TRTLLMCommandGenerator( + self._make_config(qps_levels=[1.0, 2.0, 4.0]) + ) + result = gen.generate() + for cs in result: + assert len(cs.benchmarks) == 3 + + def test_benchmark_command_contains_qps(self): + gen = TRTLLMCommandGenerator(self._make_config(qps_levels=[5.5])) + result = gen.generate() + assert "5.5" in result[0].benchmarks[0].command + + def test_benchmark_output_file(self): + gen = TRTLLMCommandGenerator(self._make_config(qps_levels=[1.0])) + result = gen.generate() + assert "bench_1P_3D_qps1.0.json" in result[0].benchmarks[0].command + + def test_dataset_in_command(self): + gen = TRTLLMCommandGenerator(self._make_config(dataset="/data/test.json")) + result = gen.generate() + assert "/data/test.json" in result[0].benchmarks[0].command + + def test_no_dataset_by_default(self): + gen = TRTLLMCommandGenerator(self._make_config()) + result = gen.generate() + assert "--dataset" not in result[0].benchmarks[0].command + + def test_script_snippet_contains_engine_build(self): + gen = TRTLLMCommandGenerator(self._make_config()) + result = gen.generate() + assert "trtllm-build" in result[0].script_snippet + + def test_script_snippet_contains_kill(self): + gen = TRTLLMCommandGenerator(self._make_config()) + result = gen.generate() + assert "kill $SERVER_PID" in result[0].script_snippet + + def test_custom_tp_pp(self): + gen = TRTLLMCommandGenerator(self._make_config(tp_size=2, pp_size=4)) + result = gen.generate() + assert "--tp_size 2" in result[0].server.engine_build_command + assert "--pp_size 4" in result[0].server.engine_build_command + + def test_custom_dtype(self): + gen = TRTLLMCommandGenerator(self._make_config(dtype="bfloat16")) + result = gen.generate() + assert "--dtype bfloat16" in result[0].server.engine_build_command + + def test_kv_cache_fraction_in_server(self): + gen = TRTLLMCommandGenerator( + self._make_config(kv_cache_free_gpu_mem_fraction=0.85) + ) + result = gen.generate() + assert "0.85" in result[0].server.server_command + + def test_two_instances_gives_one_ratio(self): + gen = TRTLLMCommandGenerator(self._make_config(total_instances=2)) + result = gen.generate() + assert len(result) == 1 + assert result[0].server.ratio == "1P:1D" + + +# --- Programmatic API tests --- + + +class TestGenerateTRTLLMCommands: + def test_basic_api(self): + result = generate_trtllm_commands( + model="test-model", + total_instances=3, + qps_levels=[1.0], + ) + assert len(result) == 2 + assert isinstance(result[0], TRTLLMCommandSet) + + def test_api_with_options(self): + result = generate_trtllm_commands( + model="test-model", + total_instances=4, + qps_levels=[1.0, 2.0], + tp_size=2, + pp_size=2, + max_batch_size=128, + dtype="bfloat16", + engine_dir="/tmp/engines", + ) + assert len(result) == 3 + assert "--tp_size 2" in result[0].server.engine_build_command + assert "--max_batch_size 128" in result[0].server.engine_build_command + + +# --- Model tests --- + + +class TestModels: + def test_server_command_fields(self): + cmd = TRTLLMServerCommand( + ratio="1P:2D", + prefill_instances=1, + decode_instances=2, + engine_build_command="trtllm-build --model test", + server_command="python3 -m tensorrt_llm.serve --engine_dir test", + ) + assert cmd.ratio == "1P:2D" + assert cmd.prefill_instances == 1 + assert cmd.decode_instances == 2 + + def test_benchmark_command_fields(self): + cmd = TRTLLMBenchmarkCommand( + ratio="1P:2D", qps=5.0, command="trtllm-bench --qps 5" + ) + assert cmd.qps == 5.0 + + def test_command_set_fields(self): + server = TRTLLMServerCommand( + ratio="1P:1D", + prefill_instances=1, + decode_instances=1, + engine_build_command="build", + server_command="serve", + ) + cs = TRTLLMCommandSet( + server=server, benchmarks=[], script_snippet="echo hi" + ) + assert cs.script_snippet == "echo hi" + + +# --- CLI integration test --- + + +class TestCLIIntegration: + def test_trtllm_commands_json(self): + """Test CLI trtllm-commands subcommand with JSON output.""" + + result = subprocess.run( + [ + + + "xpyd-plan", + "trtllm-commands", + "--model", + "test-model", + "--total-instances", + "3", + "--qps", + "1.0,2.0", + "--output-format", + "json", + ], + capture_output=True, + text=True, + ) + assert result.returncode == 0 + data = json.loads(result.stdout) + assert len(data) == 2 # 1P:2D, 2P:1D + + def test_trtllm_commands_table(self): + + result = subprocess.run( + [ + + + "xpyd-plan", + "trtllm-commands", + "--model", + "test-model", + "--total-instances", + "3", + "--qps", + "1.0", + "--output-format", + "table", + ], + capture_output=True, + text=True, + ) + assert result.returncode == 0 + assert "P:D Ratio" in result.stdout + + def test_trtllm_commands_output_script(self): + + with tempfile.NamedTemporaryFile(suffix=".sh", delete=False) as f: + script_path = f.name + + result = subprocess.run( + [ + + + "xpyd-plan", + "trtllm-commands", + "--model", + "test-model", + "--total-instances", + "3", + "--qps", + "1.0", + "--output-script", + script_path, + ], + capture_output=True, + text=True, + ) + assert result.returncode == 0 + content = Path(script_path).read_text() + assert "#!/usr/bin/env bash" in content + assert "trtllm-build" in content + assert "tensorrt_llm.serve" in content + Path(script_path).unlink()