From 2522c43ccb56c6dd4a141d3ae4c0fd467bcb11b0 Mon Sep 17 00:00:00 2001 From: hlin99 Date: Mon, 6 Apr 2026 20:20:27 +0800 Subject: [PATCH] refactor: remove integration/stress tests, sim_adapter, xpyd-sim dep Migrated to xPyD-integration repo: - tests/integration/ (~115 tests) - tests/stress/ (~10 tests) Also removed: - sim_adapter.py (no longer needed) - tests/conftest.py (only used by integration tests) - xpyd-sim dev dependency - Fixed build job CI to run unit tests --- .github/workflows/ci.yml | 12 +- pyproject.toml | 1 - sim_adapter.py | 26 -- tests/conftest.py | 90 ---- tests/integration/__init__.py | 0 tests/integration/test_cli_and_discovery.py | 316 ------------- .../integration/test_completions_endpoint.py | 44 -- tests/integration/test_concurrent_requests.py | 186 -------- tests/integration/test_dual_routing.py | 422 ------------------ tests/integration/test_large_payload.py | 176 -------- tests/integration/test_multi_model_routing.py | 349 --------------- tests/integration/test_proxy.py | 271 ----------- tests/integration/test_proxy_matrix.py | 224 ---------- .../test_resilience_integration.py | 365 --------------- .../test_scheduling_integration.py | 398 ----------------- tests/integration/test_sim_nodes.py | 131 ------ tests/integration/test_status_instances.py | 117 ----- tests/integration/test_streaming_edge.py | 145 ------ .../test_xpyd_start_proxy_integration.py | 202 --------- tests/integration/test_xpyd_start_proxy_sh.py | 335 -------------- tests/integration/test_yaml_config.py | 390 ---------------- tests/integration/test_yaml_integration.py | 222 --------- tests/stress/__init__.py | 0 tests/stress/test_benchmark_e2e.py | 395 ---------------- tests/stress/test_benchmark_integration.py | 338 -------------- tests/unit/test_metrics.py | 31 -- 26 files changed, 2 insertions(+), 5184 deletions(-) delete mode 100644 sim_adapter.py delete mode 100644 tests/conftest.py delete mode 100644 tests/integration/__init__.py delete mode 100644 tests/integration/test_cli_and_discovery.py delete mode 100644 tests/integration/test_completions_endpoint.py delete mode 100644 tests/integration/test_concurrent_requests.py delete mode 100644 tests/integration/test_dual_routing.py delete mode 100644 tests/integration/test_large_payload.py delete mode 100644 tests/integration/test_multi_model_routing.py delete mode 100644 tests/integration/test_proxy.py delete mode 100644 tests/integration/test_proxy_matrix.py delete mode 100644 tests/integration/test_resilience_integration.py delete mode 100644 tests/integration/test_scheduling_integration.py delete mode 100644 tests/integration/test_sim_nodes.py delete mode 100644 tests/integration/test_status_instances.py delete mode 100644 tests/integration/test_streaming_edge.py delete mode 100644 tests/integration/test_xpyd_start_proxy_integration.py delete mode 100644 tests/integration/test_xpyd_start_proxy_sh.py delete mode 100644 tests/integration/test_yaml_config.py delete mode 100644 tests/integration/test_yaml_integration.py delete mode 100644 tests/stress/__init__.py delete mode 100644 tests/stress/test_benchmark_e2e.py delete mode 100644 tests/stress/test_benchmark_integration.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a2f497..b3922ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - run: pip install -e ".[dev]" - - run: python -m pytest tests/unit/ tests/integration/ -v --tb=short --timeout=120 + - run: python -m pytest tests/unit/ -v --tb=short --timeout=120 lint: runs-on: ubuntu-latest steps: @@ -36,12 +36,4 @@ jobs: - run: pip install build && python -m build - run: pip install dist/*.whl - run: xpyd --help && xpyd --version - - run: pip install pytest pytest-asyncio aiohttp requests xpyd-sim && python -m pytest tests/integration/test_cli_and_discovery.py -v --tb=short - benchmark: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: {python-version: "3.12"} - - run: pip install -e ".[dev]" - - run: python -m pytest tests/stress/ -v --tb=short -m benchmark -s + - run: pip install pytest pytest-asyncio && python -m pytest tests/unit/ -v --tb=short \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9aff729..0531086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dev = [ "isort>=5.13.0", "tiktoken>=0.7.0", "pytest-timeout>=2.3.0", - "xpyd-sim>=0.2.0", ] [project.scripts] diff --git a/sim_adapter.py b/sim_adapter.py deleted file mode 100644 index d011a72..0000000 --- a/sim_adapter.py +++ /dev/null @@ -1,26 +0,0 @@ -"""sim_adapter — drop-in for dummy_nodes using real xpyd-sim. - -The default model_name points to the test tokenizer bundled in this repo. -Override via make_sim_app(model_name=...) for custom models. -""" - -from pathlib import Path - -from xpyd_sim.server import ServerConfig, create_app - -_REPO_ROOT = Path(__file__).resolve().parent -_DEFAULT_MODEL = str(_REPO_ROOT / "tokenizers/DeepSeek-R1") - - -def make_sim_app(model_name=None, mode="dual"): - """Create a real xpyd-sim app. Defaults to test tokenizer model.""" - return create_app(ServerConfig( - mode=mode, model_name=model_name or _DEFAULT_MODEL, prefill_delay_ms=0, - kv_transfer_delay_ms=0, decode_delay_per_token_ms=0, - eos_min_ratio=1.0, max_model_len=131072, - )) - - -# Default apps for subprocess (uvicorn sim_adapter:prefill_app) -prefill_app = make_sim_app(mode="prefill") -decode_app = make_sim_app(mode="decode") diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 90ea2e2..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Shared test fixtures and utilities.""" - -import os -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - -# Create apps with the correct model name (must match proxy config) -_prefill_app = make_sim_app(mode="prefill") -_decode_app = make_sim_app(mode="decode") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -_PREFILL_PORT = _free_port() -_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread(target=_run_server, args=(_prefill_app, _PREFILL_PORT), daemon=True).start() -threading.Thread(target=_run_server, args=(_decode_app, _DECODE_PORT), daemon=True).start() - -# Wait for readiness -import httpx as _httpx # noqa: E402 - -for _port in (_PREFILL_PORT, _DECODE_PORT): - for _ in range(50): - try: - if _httpx.get(f"http://127.0.0.1:{_port}/health", timeout=1).status_code == 200: - break - except Exception: - time.sleep(0.2) - else: - raise RuntimeError(f"Server on port {_port} failed to start") - - -def _make_proxy_app(): - proxy = Proxy( - prefill_instances=[f"127.0.0.1:{_PREFILL_PORT}"], - decode_instances=[f"127.0.0.1:{_DECODE_PORT}"], - model=_TOKENIZER_PATH, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=False, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, allow_origins=["*"], allow_credentials=False, - allow_methods=["*"], allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -@pytest.fixture -def dummy_ports(): - return _PREFILL_PORT, _DECODE_PORT - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def client(): - app = _make_proxy_app() - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/integration/test_cli_and_discovery.py b/tests/integration/test_cli_and_discovery.py deleted file mode 100644 index e09db27..0000000 --- a/tests/integration/test_cli_and_discovery.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Tests for Task 8/16: CLI packaging, subcommand parser, startup discovery.""" - -from __future__ import annotations - -import os -import socket -import textwrap -import threading -import time -from unittest.mock import patch - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.responses import JSONResponse, PlainTextResponse -from httpx import ASGITransport, AsyncClient - -from xpyd.config import ProxyConfig -from xpyd.discovery import DiscoveryTimeout, NodeDiscovery -from xpyd.proxy import _build_parser, _resolve_config_path - -# ------------------------------------------------------------------ -# CLI argument parsing -# ------------------------------------------------------------------ - - -class TestCLIParsing: - def test_version_flag(self, capsys): - parser = _build_parser() - with pytest.raises(SystemExit) as exc_info: - parser.parse_args(["--version"]) - assert exc_info.value.code == 0 - - def test_config_arg(self): - parser = _build_parser() - args = parser.parse_args(["proxy", "--config", "test.yaml"]) - assert args.config == "test.yaml" - - def test_validate_config_arg(self): - parser = _build_parser() - args = parser.parse_args(["proxy", "--validate-config", "test.yaml"]) - assert args.validate_config == "test.yaml" - - def test_log_level_arg(self): - parser = _build_parser() - args = parser.parse_args(["proxy", "--log-level", "debug"]) - assert args.log_level == "debug" - - def test_proxy_subcommand_with_all_args(self): - parser = _build_parser() - args = parser.parse_args( - [ - "proxy", - "--config", - "c.yaml", - "--port", - "9000", - "--log-level", - "info", - ] - ) - assert args.command == "proxy" - assert args.config == "c.yaml" - assert args.port == 9000 - assert args.log_level == "info" - - def test_old_args_not_accepted(self): - """Old CLI args (--model, --prefill, etc.) must be rejected.""" - parser = _build_parser() - for flag in ( - "--model", - "--prefill", - "--decode", - "--roundrobin", - "--generator_on_p_node", - ): - with pytest.raises(SystemExit): - parser.parse_args(["proxy", flag, "value"]) - - -# ------------------------------------------------------------------ -# Config resolution: --config > XPYD_CONFIG > ./xpyd.yaml -# ------------------------------------------------------------------ - - -class TestConfigResolution: - def test_cli_config_wins(self): - parser = _build_parser() - args = parser.parse_args(["proxy", "--config", "cli.yaml"]) - assert _resolve_config_path(args) == "cli.yaml" - - def test_env_var_fallback(self): - parser = _build_parser() - args = parser.parse_args(["proxy"]) - with patch.dict(os.environ, {"XPYD_CONFIG": "env.yaml"}): - assert _resolve_config_path(args) == "env.yaml" - - def test_default_file_fallback(self, tmp_path, monkeypatch): - (tmp_path / "xpyd.yaml").write_text("model: test\n") - monkeypatch.chdir(tmp_path) - parser = _build_parser() - args = parser.parse_args(["proxy"]) - env = {k: v for k, v in os.environ.items() if k != "XPYD_CONFIG"} - with patch.dict(os.environ, env, clear=True): - result = _resolve_config_path(args) - assert result is not None - assert result.endswith("xpyd.yaml") - - def test_no_config_exits(self, tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - parser = _build_parser() - args = parser.parse_args(["proxy"]) - env = {k: v for k, v in os.environ.items() if k != "XPYD_CONFIG"} - with patch.dict(os.environ, env, clear=True): - with pytest.raises(SystemExit) as exc_info: - _resolve_config_path(args) - assert exc_info.value.code == 1 - - -# ------------------------------------------------------------------ -# --validate-config -# ------------------------------------------------------------------ - - -class TestValidateConfig: - def test_valid_config(self, tmp_path): - p = tmp_path / "valid.yaml" - p.write_text( - textwrap.dedent( - """\ - model: /path/model - decode: - - "10.0.0.1:8000" - """ - ) - ) - config = ProxyConfig.from_yaml(str(p)) - assert config.model == "/path/model" - - def test_invalid_config(self, tmp_path): - p = tmp_path / "bad.yaml" - p.write_text("not_a_field: oops\n") - with pytest.raises((ValueError, Exception), match=".*"): - ProxyConfig.from_yaml(str(p)) - - -# ------------------------------------------------------------------ -# Startup config in YAML -# ------------------------------------------------------------------ - - -class TestStartupConfig: - def test_startup_section(self, tmp_path): - p = tmp_path / "config.yaml" - p.write_text( - textwrap.dedent( - """\ - model: /m - decode: - - "10.0.0.1:8000" - startup: - wait_timeout_seconds: 120 - probe_interval_seconds: 5 - """ - ) - ) - cfg = ProxyConfig.from_yaml(str(p)) - assert cfg.wait_timeout_seconds == 120 - assert cfg.probe_interval_seconds == 5 - - def test_startup_defaults(self, tmp_path): - p = tmp_path / "config.yaml" - p.write_text( - textwrap.dedent( - """\ - model: /m - decode: - - "10.0.0.1:8000" - """ - ) - ) - cfg = ProxyConfig.from_yaml(str(p)) - assert cfg.wait_timeout_seconds == 600 - assert cfg.probe_interval_seconds == 10 - - def test_startup_unknown_keys(self, tmp_path): - p = tmp_path / "config.yaml" - p.write_text( - textwrap.dedent( - """\ - model: /m - decode: - - "10.0.0.1:8000" - startup: - bad_key: 1 - """ - ) - ) - with pytest.raises(ValueError, match="Unknown keys in startup"): - ProxyConfig.from_yaml(str(p)) - - -# ------------------------------------------------------------------ -# Node discovery -# ------------------------------------------------------------------ - - -def _free_port(): - with socket.socket() as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _start_health_server(port: int): - """Start a tiny server that responds 200 on /health.""" - app = FastAPI() - - @app.get("/health") - async def health(): - return PlainTextResponse("ok") - - def _run(): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - threading.Thread(target=_run, daemon=True).start() - time.sleep(1) - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.mark.anyio -async def test_discovery_finds_healthy_nodes(): - """Discovery should detect healthy nodes and become ready.""" - p_port = _free_port() - d_port = _free_port() - _start_health_server(p_port) - _start_health_server(d_port) - - disc = NodeDiscovery( - prefill_instances=[f"127.0.0.1:{p_port}"], - decode_instances=[f"127.0.0.1:{d_port}"], - probe_interval=0.5, - wait_timeout=10, - ) - await disc.start() - ready = await disc.wait_until_ready() - await disc.stop() - - assert ready is True - assert disc.is_ready - assert f"127.0.0.1:{p_port}" in disc.healthy_prefill - assert f"127.0.0.1:{d_port}" in disc.healthy_decode - - -@pytest.mark.anyio -async def test_discovery_timeout_when_no_nodes(): - """Discovery should raise DiscoveryTimeout when nodes are unreachable.""" - disc = NodeDiscovery( - prefill_instances=["127.0.0.1:1"], - decode_instances=["127.0.0.1:2"], - probe_interval=0.2, - wait_timeout=1.0, - ) - await disc.start() - disc._task.remove_done_callback(disc._on_probe_done) - - ready = await disc.wait_until_ready() - assert ready is False - assert not disc.is_ready - - with pytest.raises(DiscoveryTimeout): - await disc._task - - -@pytest.mark.anyio -async def test_503_before_ready(): - """Proxy should return 503 before discovery reports ready.""" - disc = NodeDiscovery( - prefill_instances=["127.0.0.1:1"], - decode_instances=["127.0.0.1:2"], - probe_interval=60, - wait_timeout=600, - ) - - app = FastAPI() - - @app.middleware("http") - async def check_readiness(request, call_next): - path = request.url.path - if path in ("/health", "/ping", "/status", "/metrics"): - return await call_next(request) - if not disc.is_ready: - return JSONResponse({"error": "waiting for backend nodes"}, status_code=503) - return await call_next(request) - - @app.get("/health") - async def health(): - return PlainTextResponse("ok") - - @app.post("/v1/completions") - async def completions(): - return JSONResponse({"choices": []}) - - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - resp = await client.get("/health") - assert resp.status_code == 200 - - resp = await client.post("/v1/completions", json={}) - assert resp.status_code == 503 - assert "waiting for backend nodes" in resp.json()["error"] diff --git a/tests/integration/test_completions_endpoint.py b/tests/integration/test_completions_endpoint.py deleted file mode 100644 index 2d2c1a8..0000000 --- a/tests/integration/test_completions_endpoint.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for /v1/completions endpoint.""" - -import pytest -from httpx import AsyncClient - -COMPLETION_PAYLOAD = { - "model": "dummy", - "prompt": "Once upon a time", - "max_tokens": 5, - "stream": False, -} - - -@pytest.mark.anyio -async def test_non_streaming_completion(client: AsyncClient): - resp = await client.post("/v1/completions", json=COMPLETION_PAYLOAD) - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "text_completion" - assert len(data["choices"]) >= 1 - assert data["choices"][0]["finish_reason"] in ("stop", "length") - assert len(data["choices"][0]["text"]) > 0 - - -@pytest.mark.anyio -async def test_streaming_completion(client: AsyncClient): - payload = {**COMPLETION_PAYLOAD, "stream": True} - resp = await client.post("/v1/completions", json=payload) - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - assert len(data_lines) >= 2 - assert data_lines[-1] == "data: [DONE]" - - -@pytest.mark.anyio -async def test_completion_max_tokens(client: AsyncClient): - payload = {**COMPLETION_PAYLOAD, "max_tokens": 3, "stream": False} - resp = await client.post("/v1/completions", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert data["usage"]["completion_tokens"] == 3 diff --git a/tests/integration/test_concurrent_requests.py b/tests/integration/test_concurrent_requests.py deleted file mode 100644 index c9f47e1..0000000 --- a/tests/integration/test_concurrent_requests.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Tests for concurrent requests.""" - -import asyncio -import json -import os -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -_PREFILL_PORT = _free_port() -_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread( - target=_run_server, args=(prefill_app, _PREFILL_PORT), daemon=True -).start() -threading.Thread( - target=_run_server, args=(decode_app, _DECODE_PORT), daemon=True -).start() -time.sleep(2) - - -def _make_proxy_app(): - proxy = Proxy( - prefill_instances=[f"127.0.0.1:{_PREFILL_PORT}"], - decode_instances=[f"127.0.0.1:{_DECODE_PORT}"], - model=_TOKENIZER_PATH, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=False, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -CHAT_PAYLOAD = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - "stream": False, -} - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def client(): - app = _make_proxy_app() - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.mark.anyio -async def test_concurrent_chat_completions(client: AsyncClient): - """15 concurrent non-streaming requests should all succeed with unique ids.""" - concurrency = 15 - payloads = [ - {**CHAT_PAYLOAD, "messages": [{"role": "user", "content": f"Hello {idx}"}]} - for idx in range(concurrency) - ] - - tasks = [client.post("/v1/chat/completions", json=p) for p in payloads] - responses = await asyncio.gather(*tasks) - - ids = set() - for resp in responses: - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "chat.completion" - assert len(data["choices"]) >= 1 - ids.add(data["id"]) - - # Every response must have a unique id - assert len(ids) == concurrency - - -@pytest.mark.anyio -async def test_concurrent_streaming(client: AsyncClient): - """10 concurrent streaming requests should all produce valid SSE.""" - concurrency = 10 - payload = {**CHAT_PAYLOAD, "stream": True} - - tasks = [ - client.post("/v1/chat/completions", json=payload) for _ in range(concurrency) - ] - responses = await asyncio.gather(*tasks) - - for resp in responses: - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - assert len(data_lines) >= 2 - assert data_lines[-1] == "data: [DONE]" - - -@pytest.mark.anyio -async def test_mixed_concurrent_streaming_and_non_streaming(client: AsyncClient): - """Streaming and non-streaming requests running concurrently.""" - non_stream_tasks = [ - client.post( - "/v1/chat/completions", - json={ - **CHAT_PAYLOAD, - "messages": [{"role": "user", "content": f"ns-{idx}"}], - }, - ) - for idx in range(8) - ] - stream_tasks = [ - client.post( - "/v1/chat/completions", - json={ - **CHAT_PAYLOAD, - "stream": True, - "messages": [{"role": "user", "content": f"s-{idx}"}], - }, - ) - for idx in range(8) - ] - - responses = await asyncio.gather(*non_stream_tasks, *stream_tasks) - non_stream_responses = responses[:8] - stream_responses = responses[8:] - - ns_ids = set() - for resp in non_stream_responses: - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "chat.completion" - assert data["usage"]["completion_tokens"] == CHAT_PAYLOAD["max_tokens"] - ns_ids.add(data["id"]) - - assert len(ns_ids) == 8, "Non-streaming response ids must be unique" - - for resp in stream_responses: - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - assert data_lines[-1] == "data: [DONE]" - # Verify each chunk is valid JSON - for line in data_lines[:-1]: - chunk = json.loads(line.removeprefix("data: ")) - assert chunk["object"] == "chat.completion.chunk" diff --git a/tests/integration/test_dual_routing.py b/tests/integration/test_dual_routing.py deleted file mode 100644 index 117eb80..0000000 --- a/tests/integration/test_dual_routing.py +++ /dev/null @@ -1,422 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Integration tests for dual-role instances with mock servers.""" - -import os -import random -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from xpyd.proxy import Proxy -from xpyd.registry import InstanceRegistry -from xpyd.scheduler import RoundRobinSchedulingPolicy - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -def _make_dummy_app(model_id: str): - from sim_adapter import make_sim_app - return make_sim_app(model_id) - - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="module") -def dual_nodes(): - """Start 8 dummy nodes for dual/PD tests.""" - import httpx - - node_models = { - "dual1": "qwen-2", - "dual2": "qwen-2", - "p1": "llama-3", - "p2": "llama-3", - "d1": "llama-3", - "d2": "llama-3", - "dual3": "deepseek-r1", - "dual4": "deepseek-r1", - } - ports = {k: _free_port() for k in node_models} - servers = [] - - for name, port in ports.items(): - model = node_models[name] - app = _make_dummy_app(model) - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - srv = uvicorn.Server(config) - servers.append(srv) - threading.Thread(target=srv.run, daemon=True).start() - - deadline = time.monotonic() + 10 - for _name, port in ports.items(): - url = f"http://127.0.0.1:{port}/health" - while time.monotonic() < deadline: - try: - r = httpx.get(url, timeout=1) - if r.status_code == 200: - break - except Exception: - pass - time.sleep(0.1) - - addrs = {name: f"127.0.0.1:{port}" for name, port in ports.items()} - yield addrs, node_models - - for srv in servers: - srv.should_exit = True - - -def _make_dual_proxy_app(addrs, node_models): - """Build a proxy with dual + P/D models.""" - reg = InstanceRegistry() - dual_instances = {} - - # Register dual instances - for name in ["dual1", "dual2", "dual3", "dual4"]: - model = node_models[name] - reg.add("dual", addrs[name], model=model) - dual_instances.setdefault(model, []).append(addrs[name]) - - # Register P/D instances - all_prefill = [] - all_decode = [] - for name in ["p1", "p2"]: - reg.add("prefill", addrs[name], model=node_models[name]) - all_prefill.append(addrs[name]) - for name in ["d1", "d2"]: - reg.add("decode", addrs[name], model=node_models[name]) - all_decode.append(addrs[name]) - - for addr in addrs.values(): - reg.mark_healthy(addr) - - sched = RoundRobinSchedulingPolicy(registry=reg) - proxy = Proxy( - prefill_instances=all_prefill, - decode_instances=all_decode, - model=_TOKENIZER_PATH, - scheduling_policy=sched, - generator_on_p_node=False, - registry=reg, - dual_instances=dual_instances, - ) - - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app, reg - - -@pytest.fixture -async def dual_client(dual_nodes): - addrs, node_models = dual_nodes - app, reg = _make_dual_proxy_app(addrs, node_models) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.fixture -async def dual_client_and_registry(dual_nodes): - addrs, node_models = dual_nodes - app, reg = _make_dual_proxy_app(addrs, node_models) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli, reg - - -# --------------------------------------------------------------------------- -# End-to-end dual tests -# --------------------------------------------------------------------------- - - -@pytest.mark.anyio -async def test_dual_chat_completion(dual_client: AsyncClient): - """Dual model request returns correct response.""" - resp = await dual_client.post( - "/v1/chat/completions", - json={ - "model": "qwen-2", - "messages": [{"role": "user", "content": "hello"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["model"] == "qwen-2" - - -@pytest.mark.anyio -async def test_dual_completions_endpoint(dual_client: AsyncClient): - """Dual model /v1/completions returns correct response.""" - resp = await dual_client.post( - "/v1/completions", - json={ - "model": "qwen-2", - "prompt": "hello world", - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["model"] == "qwen-2" - - -@pytest.mark.anyio -async def test_pd_model_still_works(dual_client: AsyncClient): - """P/D model request still goes through two-phase flow.""" - resp = await dual_client.post( - "/v1/chat/completions", - json={ - "model": "llama-3", - "messages": [{"role": "user", "content": "hello"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["model"] == "llama-3" - - -@pytest.mark.anyio -async def test_dual_unknown_model_error(dual_client: AsyncClient): - """Unknown model returns error.""" - resp = await dual_client.post( - "/v1/chat/completions", - json={ - "model": "nonexistent", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code in (404, 503) - - -@pytest.mark.anyio -async def test_dual_all_down_503(dual_client_and_registry): - """All dual instances down returns 503.""" - cli, reg = dual_client_and_registry - # Mark deepseek-r1 dual instances as unhealthy - for info in reg.get_all_instances(): - if info.model == "deepseek-r1" and info.role == "dual": - reg.mark_unhealthy(info.address) - - resp = await cli.post( - "/v1/chat/completions", - json={ - "model": "deepseek-r1", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 503 - - # Restore - for info in reg.get_all_instances(): - if info.model == "deepseek-r1": - reg.mark_healthy(info.address) - - -@pytest.mark.anyio -async def test_models_lists_dual_models(dual_client: AsyncClient): - """GET /v1/models includes dual models.""" - resp = await dual_client.get("/v1/models") - assert resp.status_code == 200 - data = resp.json() - model_ids = sorted(m["id"] for m in data["data"]) - assert "qwen-2" in model_ids - assert "deepseek-r1" in model_ids - assert "llama-3" in model_ids - - -# --------------------------------------------------------------------------- -# Fixed boundary tests (8 instances) -# --------------------------------------------------------------------------- - - -@pytest.mark.anyio -async def test_all_dual_single_model(dual_nodes): - """8 dual instances serving one model.""" - addrs, _ = dual_nodes - all_addrs = list(addrs.values()) - - reg = InstanceRegistry() - dual_map = {"test-model": all_addrs} - for addr in all_addrs: - reg.add("dual", addr, model="test-model") - reg.mark_healthy(addr) - - sched = RoundRobinSchedulingPolicy(registry=reg) - proxy = Proxy( - prefill_instances=[], - decode_instances=[], - model=_TOKENIZER_PATH, - scheduling_policy=sched, - generator_on_p_node=False, - registry=reg, - dual_instances=dual_map, - ) - - app = FastAPI() - app.include_router(proxy.router) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - resp = await cli.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_three_models_mixed(dual_client: AsyncClient): - """3 models: qwen-2 (dual) + llama-3 (P/D) + deepseek-r1 (dual).""" - for model_name in ["qwen-2", "llama-3", "deepseek-r1"]: - resp = await dual_client.post( - "/v1/chat/completions", - json={ - "model": model_name, - "messages": [{"role": "user", "content": "test"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - assert resp.json()["model"] == model_name - - -# --------------------------------------------------------------------------- -# Randomized mixed deployment (20 seeds) -# --------------------------------------------------------------------------- - - -def _generate_random_deployment(addrs, seed): - """Generate a random valid deployment from 8 addresses.""" - rng = random.Random(seed) - pool = list(addrs.values()) - rng.shuffle(pool) - - num_models = rng.randint(1, 3) - models = [] - - for i in range(num_models): - remaining = num_models - i - max_for_this = len(pool) - (remaining - 1) * 2 - if max_for_this < 2: - max_for_this = 2 - count = rng.randint(2, min(max_for_this, len(pool))) - assigned = pool[:count] - pool = pool[count:] - - mode = rng.choice(["dual", "pd"]) - name = f"model-{i}" - - if mode == "dual": - models.append({"name": name, "mode": "dual", "instances": assigned}) - else: - split = rng.randint(1, len(assigned) - 1) - models.append( - { - "name": name, - "mode": "pd", - "prefill": assigned[:split], - "decode": assigned[split:], - } - ) - - return models - - -@pytest.mark.parametrize("seed", range(20)) -@pytest.mark.anyio -async def test_randomized_deployment(dual_nodes, seed): - """Randomized deployment with seed-based reproducibility.""" - addrs, _ = dual_nodes - deployment = _generate_random_deployment(addrs, seed) - - reg = InstanceRegistry() - dual_map = {} - all_prefill = [] - all_decode = [] - - for model_cfg in deployment: - name = model_cfg["name"] - if model_cfg["mode"] == "dual": - dual_map[name] = model_cfg["instances"] - for addr in model_cfg["instances"]: - if addr not in [i.address for i in reg.get_all_instances()]: - reg.add("dual", addr, model=name) - else: - for addr in model_cfg["prefill"]: - if addr not in [i.address for i in reg.get_all_instances()]: - reg.add("prefill", addr, model=name) - all_prefill.append(addr) - for addr in model_cfg["decode"]: - if addr not in [i.address for i in reg.get_all_instances()]: - reg.add("decode", addr, model=name) - all_decode.append(addr) - - for info in reg.get_all_instances(): - reg.mark_healthy(info.address) - - sched = RoundRobinSchedulingPolicy(registry=reg) - proxy = Proxy( - prefill_instances=all_prefill, - decode_instances=all_decode, - model=_TOKENIZER_PATH, - scheduling_policy=sched, - generator_on_p_node=False, - registry=reg, - dual_instances=dual_map, - ) - - app = FastAPI() - app.include_router(proxy.router) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - for model_cfg in deployment: - name = model_cfg["name"] - # The dummy nodes return model IDs based on their actual model, - # not the name we assigned. For routing verification, just check - # we get a 200 (correct routing to a live server). - resp = await cli.post( - "/v1/chat/completions", - json={ - "model": name, - "messages": [{"role": "user", "content": "test"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200, ( - f"seed={seed}, model={name}, mode={model_cfg['mode']}, " - f"status={resp.status_code}" - ) diff --git a/tests/integration/test_large_payload.py b/tests/integration/test_large_payload.py deleted file mode 100644 index 668d437..0000000 --- a/tests/integration/test_large_payload.py +++ /dev/null @@ -1,176 +0,0 @@ -"""Tests for large payloads and edge cases.""" - -import os -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -_PREFILL_PORT = _free_port() -_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread( - target=_run_server, args=(prefill_app, _PREFILL_PORT), daemon=True -).start() -threading.Thread( - target=_run_server, args=(decode_app, _DECODE_PORT), daemon=True -).start() -time.sleep(2) - - -def _make_proxy_app(): - proxy = Proxy( - prefill_instances=[f"127.0.0.1:{_PREFILL_PORT}"], - decode_instances=[f"127.0.0.1:{_DECODE_PORT}"], - model=_TOKENIZER_PATH, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=False, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def client(): - app = _make_proxy_app() - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.mark.anyio -async def test_large_prompt(client: AsyncClient): - """A 10K+ token prompt should still be handled without crashing.""" - # ~12K tokens: each "word_NNNN " is roughly 2 tokens - long_content = " ".join(f"word_{idx}" for idx in range(6000)) - payload = { - "model": "dummy", - "messages": [{"role": "user", "content": long_content}], - "max_tokens": 5, - "stream": False, - } - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data - assert len(data["choices"]) >= 1 - - -@pytest.mark.anyio -async def test_max_tokens_zero(client: AsyncClient): - """max_tokens=0 → 200 with 0 completion tokens.""" - payload = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 0, - "stream": False, - } - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert data["usage"]["completion_tokens"] == 0 - - -@pytest.mark.anyio -async def test_max_tokens_negative(client: AsyncClient): - """Negative max_tokens is passed through (proxy does not validate).""" - payload = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": -1, - "stream": False, - } - resp = await client.post("/v1/chat/completions", json=payload) - # Proxy does not validate max_tokens; backend handles it - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_max_tokens_very_large(client: AsyncClient): - """Very large max_tokens should succeed (dummy backend caps output).""" - payload = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 1000, - "stream": False, - } - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data - - -@pytest.mark.anyio -async def test_missing_messages_returns_400(client: AsyncClient): - """POST /v1/chat/completions without 'messages' should return 400.""" - resp = await client.post( - "/v1/chat/completions", - json={"model": "dummy", "max_tokens": 5}, - ) - assert resp.status_code == 400 - assert "messages" in resp.json()["error"]["message"].lower() - - -@pytest.mark.anyio -async def test_missing_prompt_returns_400(client: AsyncClient): - """POST /v1/completions without 'prompt' should return 400.""" - resp = await client.post( - "/v1/completions", - json={"model": "dummy", "max_tokens": 5}, - ) - assert resp.status_code == 400 - assert "prompt" in resp.json()["error"]["message"].lower() - - -@pytest.mark.anyio -async def test_invalid_messages_type_returns_400(client: AsyncClient): - """POST /v1/chat/completions with non-list 'messages' should return 400.""" - resp = await client.post( - "/v1/chat/completions", - json={"model": "dummy", "messages": "not a list", "max_tokens": 5}, - ) - assert resp.status_code == 400 - assert "list" in resp.json()["error"]["message"].lower() diff --git a/tests/integration/test_multi_model_routing.py b/tests/integration/test_multi_model_routing.py deleted file mode 100644 index cca71ff..0000000 --- a/tests/integration/test_multi_model_routing.py +++ /dev/null @@ -1,349 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Integration tests for multi-model routing with 4P+4D dummy nodes, 3 models.""" - -import os -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy -from xpyd.registry import InstanceRegistry - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -def _make_dummy_app(model_id: str): - from sim_adapter import make_sim_app - return make_sim_app(model_id) - - - -# --------------------------------------------------------------------------- -# Topology: 4P + 4D, 3 models -# llama-3: p1, p2, d1, d2 -# deepseek-r1: p3, d3 -# qwen-2: p4, d4 -# --------------------------------------------------------------------------- - -_MODEL_MAP = { - "p1": "llama-3", - "p2": "llama-3", - "p3": "deepseek-r1", - "p4": "qwen-2", - "d1": "llama-3", - "d2": "llama-3", - "d3": "deepseek-r1", - "d4": "qwen-2", -} - - -@pytest.fixture(scope="session") -def dummy_nodes(): - """Start all 8 dummy nodes once per session, poll for readiness, and - return a dict mapping node name to ``127.0.0.1:``.""" - import httpx - - ports: dict[str, int] = {k: _free_port() for k in _MODEL_MAP} - servers: list[uvicorn.Server] = [] - - for name, port in ports.items(): - model = _MODEL_MAP[name] - app = _make_dummy_app(model) - config = uvicorn.Config( - app, - host="127.0.0.1", - port=port, - log_level="error", - ) - srv = uvicorn.Server(config) - servers.append(srv) - threading.Thread(target=srv.run, daemon=True).start() - - # Poll for readiness instead of fixed sleep - deadline = time.monotonic() + 10 - for _name, port in ports.items(): - url = f"http://127.0.0.1:{port}/health" - while time.monotonic() < deadline: - try: - r = httpx.get(url, timeout=1) - if r.status_code == 200: - break - except Exception: - pass - time.sleep(0.1) - - addrs = {name: f"127.0.0.1:{port}" for name, port in ports.items()} - yield addrs - - # Teardown: signal servers to shut down - for srv in servers: - srv.should_exit = True - - -def _addr(name, addrs): - return addrs[name] - - -def _make_multi_model_proxy_app(addrs): - """Build a proxy with multi-model registry.""" - all_prefill = [addrs["p1"], addrs["p2"], addrs["p3"], addrs["p4"]] - all_decode = [addrs["d1"], addrs["d2"], addrs["d3"], addrs["d4"]] - - reg = InstanceRegistry() - for name in ["p1", "p2", "p3", "p4"]: - reg.add("prefill", addrs[name], model=_MODEL_MAP[name]) - for name in ["d1", "d2", "d3", "d4"]: - reg.add("decode", addrs[name], model=_MODEL_MAP[name]) - for name in addrs: - reg.mark_healthy(addrs[name]) - - sched = RoundRobinSchedulingPolicy(registry=reg) - - proxy = Proxy( - prefill_instances=all_prefill, - decode_instances=all_decode, - model=_TOKENIZER_PATH, - scheduling_policy=sched, - generator_on_p_node=False, - registry=reg, - ) - - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app, reg - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def multi_model_client(dummy_nodes): - app, _ = _make_multi_model_proxy_app(dummy_nodes) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.fixture -async def multi_model_client_and_registry(dummy_nodes): - app, reg = _make_multi_model_proxy_app(dummy_nodes) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli, reg, dummy_nodes - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -@pytest.mark.anyio -async def test_multi_model_routing_correct(multi_model_client: AsyncClient): - """Send request with model=llama-3, verify response model matches.""" - resp = await multi_model_client.post( - "/v1/chat/completions", - json={ - "model": "llama-3", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["model"] == "llama-3" - - -@pytest.mark.anyio -async def test_multi_model_unknown_model_404(multi_model_client: AsyncClient): - """Request with unknown model should return 404.""" - resp = await multi_model_client.post( - "/v1/chat/completions", - json={ - "model": "nonexistent", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 404 - data = resp.json() - assert "error" in data - - -@pytest.mark.anyio -async def test_models_endpoint_lists_all( - multi_model_client_and_registry, -): - """GET /v1/models returns all 3 models.""" - cli, reg, addrs = multi_model_client_and_registry - resp = await cli.get("/v1/models") - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "list" - model_ids = sorted(m["id"] for m in data["data"]) - assert model_ids == ["deepseek-r1", "llama-3", "qwen-2"] - - -@pytest.mark.anyio -async def test_models_endpoint_format( - multi_model_client_and_registry, -): - """Response format matches OpenAI spec.""" - cli, reg, addrs = multi_model_client_and_registry - resp = await cli.get("/v1/models") - assert resp.status_code == 200 - data = resp.json() - assert "object" in data - assert "data" in data - for model in data["data"]: - assert "id" in model - assert model["object"] == "model" - assert "created" in model - assert "owned_by" in model - - -@pytest.mark.anyio -async def test_multi_model_routing_isolation(multi_model_client: AsyncClient): - """Request with model=deepseek-r1 must NOT hit llama-3 instances.""" - resp = await multi_model_client.post( - "/v1/chat/completions", - json={ - "model": "deepseek-r1", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["model"] == "deepseek-r1" - - -@pytest.mark.anyio -async def test_multi_model_one_model_down( - multi_model_client_and_registry, -): - """When all instances of model B are unhealthy, B returns error but A works.""" - cli, reg, addrs = multi_model_client_and_registry - # Mark deepseek-r1 instances as unhealthy - reg.mark_unhealthy(_addr("p3", addrs)) - reg.mark_unhealthy(_addr("d3", addrs)) - - # deepseek-r1 should fail (no available instances) - resp_b = await cli.post( - "/v1/chat/completions", - json={ - "model": "deepseek-r1", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp_b.status_code in (404, 503) - - # llama-3 should still work - resp_a = await cli.post( - "/v1/chat/completions", - json={ - "model": "llama-3", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp_a.status_code == 200 - assert resp_a.json()["model"] == "llama-3" - - # Restore health - reg.mark_healthy(_addr("p3", addrs)) - reg.mark_healthy(_addr("d3", addrs)) - - -@pytest.mark.anyio -async def test_multi_model_load_balance(multi_model_client: AsyncClient): - """N requests to llama-3 should distribute across d1 and d2.""" - models_seen = set() - for _ in range(10): - resp = await multi_model_client.post( - "/v1/chat/completions", - json={ - "model": "llama-3", - "messages": [{"role": "user", "content": "hi"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - assert resp.json()["model"] == "llama-3" - models_seen.add(resp.json()["model"]) - # All should be llama-3 - assert models_seen == {"llama-3"} - - -@pytest.mark.anyio -async def test_multi_model_prefill_decode_match(multi_model_client: AsyncClient): - """Prefill and decode for same request go to instances of the same model.""" - # Send multiple requests to different models, verify model in response - for model_name in ["llama-3", "deepseek-r1", "qwen-2"]: - resp = await multi_model_client.post( - "/v1/chat/completions", - json={ - "model": model_name, - "messages": [{"role": "user", "content": "test"}], - "max_tokens": 5, - }, - ) - assert resp.status_code == 200 - assert resp.json()["model"] == model_name - - -@pytest.mark.anyio -async def test_models_endpoint_updates_on_instance_change( - multi_model_client_and_registry, -): - """After removing all instances of qwen-2, /v1/models no longer lists it.""" - cli, reg, addrs = multi_model_client_and_registry - - # Verify qwen-2 is listed initially - resp = await cli.get("/v1/models") - model_ids = [m["id"] for m in resp.json()["data"]] - assert "qwen-2" in model_ids - - # Remove all qwen-2 instances - reg.remove(_addr("p4", addrs)) - reg.remove(_addr("d4", addrs)) - - # qwen-2 should no longer be listed - resp = await cli.get("/v1/models") - model_ids = [m["id"] for m in resp.json()["data"]] - assert "qwen-2" not in model_ids - - # Re-add for cleanup (other tests might share fixtures) - reg.add("prefill", _addr("p4", addrs), model="qwen-2") - reg.add("decode", _addr("d4", addrs), model="qwen-2") - reg.mark_healthy(_addr("p4", addrs)) - reg.mark_healthy(_addr("d4", addrs)) diff --git a/tests/integration/test_proxy.py b/tests/integration/test_proxy.py deleted file mode 100644 index 8a19080..0000000 --- a/tests/integration/test_proxy.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Tests for MicroPDProxyServer.""" - -import itertools -import json -import os -import socket -import threading -import time -from unittest.mock import patch - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.proxy import LoadBalancedScheduler, Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -# --------------------------------------------------------------------------- -# Use local tokenizer from repo to avoid network dependency in CI -# --------------------------------------------------------------------------- - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - - -def _make_proxy_app( - prefill_instances: list | None = None, - decode_instances: list | None = None, -) -> FastAPI: - """Create a FastAPI app with a Proxy router for testing.""" - proxy = Proxy( - prefill_instances=prefill_instances or [], - decode_instances=decode_instances or [], - model=_TOKENIZER_PATH, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=False, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -# --------------------------------------------------------------------------- -# Helpers – start real dummy-node servers for the proxy to talk to -# --------------------------------------------------------------------------- - - -def _free_port(): - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -_PREFILL_PORT = _free_port() -_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread( - target=_run_server, args=(prefill_app, _PREFILL_PORT), daemon=True -).start() -threading.Thread( - target=_run_server, args=(decode_app, _DECODE_PORT), daemon=True -).start() -time.sleep(1) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def client(): - proxy_app = _make_proxy_app( - prefill_instances=[f"127.0.0.1:{_PREFILL_PORT}"], - decode_instances=[f"127.0.0.1:{_DECODE_PORT}"], - ) - transport = ASGITransport(app=proxy_app) - async with AsyncClient(transport=transport, base_url="http://test") as c: - yield c - - -CHAT_PAYLOAD = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - "stream": False, -} - - -# --------------------------------------------------------------------------- -# Proxy endpoint tests -# --------------------------------------------------------------------------- - - -@pytest.mark.anyio -async def test_health(client: AsyncClient): - resp = await client.get("/health") - assert resp.status_code == 200 - data = resp.json() - # Proxy /health returns per-instance results keyed by host:port - assert len(data) > 0 - for _inst, info in data.items(): - assert info["status"] == 200 - assert info["data"]["status"] == "ok" - - -@pytest.mark.anyio -async def test_status(client: AsyncClient): - resp = await client.get("/status") - assert resp.status_code == 200 - data = resp.json() - assert data["prefill_node_count"] == 1 - assert data["decode_node_count"] == 1 - assert len(data["prefill_nodes"]) == 1 - assert len(data["decode_nodes"]) == 1 - - -@pytest.mark.anyio -async def test_non_streaming(client: AsyncClient): - resp = await client.post("/v1/chat/completions", json=CHAT_PAYLOAD) - assert resp.status_code == 200 - data = resp.json() - - assert data["object"] == "chat.completion" - assert len(data["choices"]) == 1 - assert data["choices"][0]["finish_reason"] in ("stop", "length") - assert data["choices"][0]["message"]["role"] == "assistant" - assert len(data["choices"][0]["message"]["content"]) > 0 - - assert data["usage"]["completion_tokens"] == 5 - assert data["usage"]["total_tokens"] == data["usage"]["prompt_tokens"] + 5 - - -@pytest.mark.anyio -async def test_streaming(client: AsyncClient): - payload = {**CHAT_PAYLOAD, "stream": True} - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - - # role + content + finish + [DONE] - assert len(data_lines) >= 4 - - assert data_lines[-1] == "data: [DONE]" - - first = json.loads(data_lines[0].removeprefix("data: ")) - assert first["choices"][0]["delta"]["role"] == "assistant" - - content = "" - for line in data_lines[1:-2]: - chunk = json.loads(line.removeprefix("data: ")) - content += chunk["choices"][0]["delta"]["content"] - assert len(content) > 0 - - -@pytest.mark.anyio -async def test_max_tokens_respected(client: AsyncClient): - payload = {**CHAT_PAYLOAD, "max_tokens": 3, "stream": False} - resp = await client.post("/v1/chat/completions", json=payload) - data = resp.json() - assert data["usage"]["completion_tokens"] == 3 - - -@pytest.mark.anyio -async def test_streaming_token_count(client: AsyncClient): - """Verify the number of content tokens in streaming matches max_tokens.""" - payload = {**CHAT_PAYLOAD, "max_tokens": 7, "stream": True} - resp = await client.post("/v1/chat/completions", json=payload) - - lines = resp.text.strip().split("\n") - data_lines = [ - line for line in lines if line.startswith("data: ") and line != "data: [DONE]" - ] - - content_chunks = 0 - for line in data_lines: - chunk = json.loads(line.removeprefix("data: ")) - delta = chunk["choices"][0]["delta"] - if delta.get("content") is not None: - content_chunks += 1 - - assert content_chunks >= 1 - - -# --------------------------------------------------------------------------- -# Unit tests – scheduling policies -# --------------------------------------------------------------------------- - - -def test_round_robin_scheduling(): - policy = RoundRobinSchedulingPolicy() - instances = ["a:1", "b:2", "c:3"] - cycler = itertools.cycle(instances) - results = [policy.schedule(cycler) for _ in range(6)] - assert results == ["a:1", "b:2", "c:3", "a:1", "b:2", "c:3"] - - -def test_round_robin_schedule_with_full_signature(): - """Verify RoundRobin.schedule() accepts the same args Proxy.schedule() passes.""" - policy = RoundRobinSchedulingPolicy() - instances = ["a:1", "b:2"] - cycler = itertools.cycle(instances) - - r1 = policy.schedule(cycler, True, 100, 1) - r2 = policy.schedule(cycler, False, 100, 50) - assert r1 == "a:1" - assert r2 == "b:2" - - -def test_round_robin_schedule_completion_exists(): - """Verify RoundRobin has schedule_completion() (no-op from base class).""" - policy = RoundRobinSchedulingPolicy() - policy.schedule_completion( - prefill_instance="a:1", decode_instance=None, req_len=100 - ) - policy.schedule_completion(prefill_instance=None, decode_instance="b:2", req_len=50) - - -@patch( - "xpyd.scheduler.load_balanced.query_instance_model_len", - return_value=[131072, 131072], -) -def test_load_balanced_scheduling(mock_query): - """Test LoadBalancedScheduler distributes requests across instances.""" - prefill = ["p1:1", "p2:2"] - decode = ["d1:1", "d2:2"] - policy = LoadBalancedScheduler(prefill, decode) - - p_cycler = itertools.cycle(prefill) - d_cycler = itertools.cycle(decode) - - r1 = policy.schedule(p_cycler, is_prompt=True, request_len=100, max_tokens=50) - assert r1 in prefill - - r2 = policy.schedule(p_cycler, is_prompt=True, request_len=100, max_tokens=50) - assert r2 in prefill - assert r2 != r1 - - d1 = policy.schedule(d_cycler, is_prompt=False, request_len=50, max_tokens=50) - assert d1 in decode - d2 = policy.schedule(d_cycler, is_prompt=False, request_len=50, max_tokens=50) - assert d2 in decode - assert d2 != d1 diff --git a/tests/integration/test_proxy_matrix.py b/tests/integration/test_proxy_matrix.py deleted file mode 100644 index b33c6e9..0000000 --- a/tests/integration/test_proxy_matrix.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Integration tests for the proxy/dummy-node matrix from task_openclaw.md. - -These tests intentionally exercise the real ``xpyd/proxy.py`` -server with multiple dummy prefill/decode nodes and the requested proxy -configurations, without changing the core business logic. -""" - -from __future__ import annotations - -import os -import socket -import subprocess -import sys -import time -from contextlib import ExitStack -from pathlib import Path - -import pytest -import requests - -REPO_ROOT = Path(__file__).resolve().parents[2] -PYTHON = sys.executable -TOKENIZER_DIR = str(REPO_ROOT / "tokenizers" / "DeepSeek-R1") -ENV = { - **os.environ, - "PYTHONPATH": str(REPO_ROOT), - "PREFILL_DELAY_PER_TOKEN": "0", - "DECODE_DELAY_PER_TOKEN": "0", -} - -MATRIX = [ - (1, 2, 1), - (2, 2, 1), - (1, 2, 2), - (1, 2, 4), - (1, 2, 8), - (2, 2, 2), - (2, 4, 1), - (2, 4, 2), -] - - -_used_ports: set[int] = set() - - -def _free_port() -> int: - """Find a free TCP port, avoiding previously allocated ports.""" - for _ in range(100): - with socket.socket() as sock: - sock.bind(("127.0.0.1", 0)) - port = sock.getsockname()[1] - if port not in _used_ports: - _used_ports.add(port) - return port - raise RuntimeError("Unable to find a unique free port") - - -def _wait_http_ok(url: str, timeout: float = 40.0) -> None: - deadline = time.time() + timeout - last_error: Exception | None = None - while time.time() < deadline: - try: - response = requests.get(url, timeout=1.5) - if response.status_code == 200: - return - except Exception as exc: # pragma: no cover - best effort polling - last_error = exc - time.sleep(0.2) - raise AssertionError(f"Timed out waiting for {url}; last_error={last_error}") - - -def _spawn_node(mode, port): - app_ref = "sim_adapter:prefill_app" if mode == "prefill" else "sim_adapter:decode_app" - return subprocess.Popen( - [PYTHON, "-m", "uvicorn", app_ref, "--host", "127.0.0.1", "--port", str(port), "--log-level", "warning"], - cwd=REPO_ROOT, env={**os.environ, "PYTHONPATH": str(REPO_ROOT)}, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - ) - - -def _spawn_proxy( - prefill_instances: list[str], decode_instances: list[str], port: int -) -> subprocess.Popen: - # Generate a temporary YAML config for the proxy - import atexit - import tempfile - - import yaml - - config = { - "model": TOKENIZER_DIR, - "port": port, - "decode": decode_instances, - } - if prefill_instances: - config["prefill"] = prefill_instances - - config_file = tempfile.NamedTemporaryFile( - mode="w", suffix=".yaml", delete=False, dir=str(REPO_ROOT) - ) - yaml.dump(config, config_file) - config_file.close() - atexit.register(os.unlink, config_file.name) - - command = [ - PYTHON, - "-m", - "xpyd.proxy", - "proxy", - "--config", - config_file.name, - ] - return subprocess.Popen( - command, - cwd=REPO_ROOT, - env=ENV, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - -def _stop_process(process: subprocess.Popen) -> None: - if process.poll() is not None: - return - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - process.wait(timeout=5) - - -def _drain_process_output(process: subprocess.Popen) -> str: - stdout = "" - stderr = "" - try: - if process.stdout: - stdout = process.stdout.read() or "" - if process.stderr: - stderr = process.stderr.read() or "" - except Exception: - pass - return f"STDOUT:\n{stdout}\nSTDERR:\n{stderr}" - - -@pytest.mark.parametrize("prefill_count,decode_count,tp_size", MATRIX) -def test_proxy_matrix(prefill_count: int, decode_count: int, tp_size: int): - num_decode_ports = 8 // tp_size - prefill_ports = [_free_port() for _ in range(prefill_count)] - decode_ports = [_free_port() for _ in range(decode_count * num_decode_ports)] - proxy_port = _free_port() - - with ExitStack() as stack: - prefill_processes = [] - decode_processes = [] - for port in prefill_ports: - process = _spawn_node("prefill", port) - prefill_processes.append(process) - stack.callback(_stop_process, process) - for port in decode_ports: - process = _spawn_node("decode", port) - decode_processes.append(process) - stack.callback(_stop_process, process) - - for port in prefill_ports: - _wait_http_ok(f"http://127.0.0.1:{port}/v1/models") - for port in decode_ports: - _wait_http_ok(f"http://127.0.0.1:{port}/v1/models") - - prefill_instances = [f"127.0.0.1:{port}" for port in prefill_ports] - decode_instances = [f"127.0.0.1:{port}" for port in decode_ports] - proxy = _spawn_proxy(prefill_instances, decode_instances, proxy_port) - stack.callback(_stop_process, proxy) - - try: - _wait_http_ok(f"http://127.0.0.1:{proxy_port}/status") - except AssertionError: - details = ["Proxy failed to start"] - details.append(_drain_process_output(proxy)) - for process in [*prefill_processes, *decode_processes]: - if process.poll() not in (None, 0): - details.append(_drain_process_output(process)) - pytest.fail("\n".join(details)) - - status = requests.get(f"http://127.0.0.1:{proxy_port}/status", timeout=5).json() - assert status["prefill_node_count"] == prefill_count - assert status["decode_node_count"] == decode_count * num_decode_ports - - payload = { - "messages": [ - { - "role": "user", - "content": f"matrix {prefill_count}-{decode_count}-{tp_size}", - } - ], - "max_tokens": 4, - "stream": False, - } - response = requests.post( - f"http://127.0.0.1:{proxy_port}/v1/chat/completions", - json=payload, - timeout=15, - ) - assert response.status_code == 200, response.text - data = response.json() - assert data["object"] == "chat.completion" - assert data["choices"][0]["message"]["role"] == "assistant" - assert data["usage"]["completion_tokens"] == 4 - - stream_payload = dict(payload) - stream_payload["stream"] = True - stream_response = requests.post( - f"http://127.0.0.1:{proxy_port}/v1/chat/completions", - json=stream_payload, - timeout=15, - ) - assert stream_response.status_code == 200, stream_response.text - assert "data: [DONE]" in stream_response.text - - # Fail fast if any child process crashed during the request. - for process in [*prefill_processes, *decode_processes, proxy]: - if process.poll() not in (None, 0): - pytest.fail(_drain_process_output(process)) diff --git a/tests/integration/test_resilience_integration.py b/tests/integration/test_resilience_integration.py deleted file mode 100644 index 62f90d9..0000000 --- a/tests/integration/test_resilience_integration.py +++ /dev/null @@ -1,365 +0,0 @@ -"""Integration tests for Task 9: registry + circuit breaker + health monitor.""" - -from __future__ import annotations - -import socket -import threading -import time -from pathlib import Path - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.config import ProxyConfig -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -_REPO_ROOT = Path(__file__).resolve().parents[2] -_TOKENIZER_PATH = str(_REPO_ROOT / "tokenizers" / "DeepSeek-R1") - - -def _free_port(): - """Allocate a free port. Keep socket open until caller binds to avoid races.""" - s = socket.socket() - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("127.0.0.1", 0)) - port = s.getsockname()[1] - s.close() - return port - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -@pytest.fixture(scope="session") -def dummy_nodes(): - """Start dummy prefill/decode servers once per test session.""" - prefill_port = _free_port() - decode_port_1 = _free_port() - decode_port_2 = _free_port() - - for app, port in [ - (prefill_app, prefill_port), - (decode_app, decode_port_1), - (decode_app, decode_port_2), - ]: - threading.Thread(target=_run_server, args=(app, port), daemon=True).start() - time.sleep(2) - - return { - "prefill_port": prefill_port, - "decode_port_1": decode_port_1, - "decode_port_2": decode_port_2, - } - - -def _make_config(dummy_nodes, **overrides): - """Build a ProxyConfig for testing.""" - defaults = { - "model": _TOKENIZER_PATH, - "prefill": [f"127.0.0.1:{dummy_nodes['prefill_port']}"], - "decode": [ - f"127.0.0.1:{dummy_nodes['decode_port_1']}", - f"127.0.0.1:{dummy_nodes['decode_port_2']}", - ], - "port": 8000, - } - defaults.update(overrides) - return ProxyConfig(**defaults) - - -def _make_proxy_app(config): - """Create a FastAPI app with Proxy from config.""" - proxy = Proxy( - prefill_instances=config.prefill, - decode_instances=config.decode, - model=config.model, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=config.generator_on_p_node, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -# ------------------------------------------------------------------ -# Scenario A: Baseline — normal requests work -# ------------------------------------------------------------------ - - -@pytest.fixture -async def baseline_client(dummy_nodes): - config = _make_config(dummy_nodes) - app = _make_proxy_app(config) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.mark.anyio -async def test_baseline_health(baseline_client): - resp = await baseline_client.get("/health") - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_baseline_completion(baseline_client): - resp = await baseline_client.post( - "/v1/completions", - json={"model": "test", "prompt": "Hello", "max_tokens": 5, "stream": False}, - ) - assert resp.status_code == 200 - assert "choices" in resp.json() - - -@pytest.mark.anyio -async def test_baseline_chat_completion(baseline_client): - resp = await baseline_client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 - assert "choices" in resp.json() - - -# ------------------------------------------------------------------ -# Scenario B: Registry correctly registers instances -# ------------------------------------------------------------------ - - -class TestRegistry: - def test_registry_registers_all_instances(self, dummy_nodes): - from xpyd.registry import InstanceRegistry - - registry = InstanceRegistry() - registry.add("prefill", f"127.0.0.1:{dummy_nodes['prefill_port']}") - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_2']}") - - # New instances default to UNKNOWN — mark healthy first. - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['prefill_port']}") - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['decode_port_2']}") - - prefill = registry.get_available_instances("prefill") - decode = registry.get_available_instances("decode") - assert len(prefill) == 1 - assert len(decode) == 2 - - def test_registry_mark_unhealthy_removes_from_available(self, dummy_nodes): - from xpyd.registry import InstanceRegistry - - registry = InstanceRegistry() - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_2']}") - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['decode_port_2']}") - registry.mark_unhealthy(f"127.0.0.1:{dummy_nodes['decode_port_1']}") - available = registry.get_available_instances("decode") - assert len(available) == 1 - assert f"127.0.0.1:{dummy_nodes['decode_port_2']}" in available - - def test_registry_mark_healthy_restores(self, dummy_nodes): - from xpyd.registry import InstanceRegistry - - registry = InstanceRegistry() - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.mark_unhealthy(f"127.0.0.1:{dummy_nodes['decode_port_1']}") - assert len(registry.get_available_instances("decode")) == 0 - registry.mark_healthy(f"127.0.0.1:{dummy_nodes['decode_port_1']}") - assert len(registry.get_available_instances("decode")) == 1 - - -# ------------------------------------------------------------------ -# Scenario C: Circuit breaker integration -# ------------------------------------------------------------------ - - -class TestCircuitBreakerIntegration: - def test_circuit_opens_after_failures(self): - from xpyd.registry import InstanceRegistry - - registry = InstanceRegistry( - cb_enabled=True, failure_threshold=2, timeout_duration_seconds=30 - ) - registry.add("decode", "10.0.0.1:8200") - registry.add("decode", "10.0.0.2:8200") - registry.mark_healthy("10.0.0.1:8200") - registry.mark_healthy("10.0.0.2:8200") - - # Two failures → circuit opens - registry.record_failure("10.0.0.1:8200") - registry.record_failure("10.0.0.1:8200") - - available = registry.get_available_instances("decode") - assert "10.0.0.1:8200" not in available - assert "10.0.0.2:8200" in available - - def test_circuit_closes_after_recovery(self): - from xpyd.registry import InstanceRegistry - - t = [0.0] - registry = InstanceRegistry( - cb_enabled=True, - failure_threshold=2, - success_threshold=1, - timeout_duration_seconds=5, - clock=lambda: t[0], - ) - registry.add("decode", "10.0.0.1:8200") - registry.mark_healthy("10.0.0.1:8200") - - # Open circuit - registry.record_failure("10.0.0.1:8200") - registry.record_failure("10.0.0.1:8200") - assert len(registry.get_available_instances("decode")) == 0 - - # Advance time past timeout → half-open - t[0] = 6.0 - # Record success in half-open - registry.record_success("10.0.0.1:8200") - assert len(registry.get_available_instances("decode")) == 1 - - -# ------------------------------------------------------------------ -# Scenario D: Health monitor detects healthy nodes -# ------------------------------------------------------------------ - - -@pytest.mark.anyio -async def test_health_monitor_detects_healthy(dummy_nodes): - from xpyd.health_monitor import HealthMonitor - - results = [] - monitor = HealthMonitor( - nodes=[ - f"127.0.0.1:{dummy_nodes['decode_port_1']}", - f"127.0.0.1:{dummy_nodes['decode_port_2']}", - ], - interval_seconds=60, - timeout_seconds=2, - on_healthy=lambda addr: results.append(("healthy", addr)), - on_unhealthy=lambda addr: results.append(("unhealthy", addr)), - ) - await monitor.check_once() - - healthy = [r for r in results if r[0] == "healthy"] - assert len(healthy) == 2 - - -@pytest.mark.anyio -async def test_health_monitor_detects_unreachable(): - from xpyd.health_monitor import HealthMonitor - - results = [] - monitor = HealthMonitor( - nodes=["127.0.0.1:1"], # nothing listening - interval_seconds=60, - timeout_seconds=1, - on_healthy=lambda addr: results.append(("healthy", addr)), - on_unhealthy=lambda addr: results.append(("unhealthy", addr)), - ) - await monitor.check_once() - - assert len(results) == 1 - assert results[0] == ("unhealthy", "127.0.0.1:1") - - -# ------------------------------------------------------------------ -# Scenario E: Health monitor + registry integration -# ------------------------------------------------------------------ - - -@pytest.mark.anyio -async def test_health_monitor_updates_registry(dummy_nodes): - """Health monitor callbacks should update registry status.""" - from xpyd.health_monitor import HealthMonitor - from xpyd.registry import InstanceRegistry - - registry = InstanceRegistry() - registry.add("decode", f"127.0.0.1:{dummy_nodes['decode_port_1']}") - registry.add("decode", "127.0.0.1:1") # unreachable - - monitor = HealthMonitor( - nodes=[f"127.0.0.1:{dummy_nodes['decode_port_1']}", "127.0.0.1:1"], - interval_seconds=60, - timeout_seconds=1, - on_healthy=registry.mark_healthy, - on_unhealthy=registry.mark_unhealthy, - ) - await monitor.check_once() - - available = registry.get_available_instances("decode") - assert f"127.0.0.1:{dummy_nodes['decode_port_1']}" in available - assert "127.0.0.1:1" not in available - - -# ------------------------------------------------------------------ -# Scenario F: 4xx should not be retried (unit level check) -# ------------------------------------------------------------------ - - -class TestNoRetryOn4xx: - """Verify that 4xx errors are returned directly without retry.""" - - @pytest.mark.anyio - async def test_400_returned_directly(self, baseline_client): - resp = await baseline_client.post( - "/v1/completions", - json={"model": "test"}, # missing 'prompt' field - ) - assert resp.status_code == 400 - - -# ------------------------------------------------------------------ -# Scenario G: All features disabled = backward compatible -# ------------------------------------------------------------------ - - -class TestBackwardCompatibility: - def test_default_config_all_disabled(self, dummy_nodes): - config = _make_config(dummy_nodes) - assert config.circuit_breaker.enabled is False - assert config.health_check.enabled is False - - @pytest.mark.anyio - async def test_proxy_works_with_all_disabled(self, dummy_nodes): - config = _make_config(dummy_nodes) - app = _make_proxy_app(config) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - resp = await cli.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "Hello", - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 diff --git a/tests/integration/test_scheduling_integration.py b/tests/integration/test_scheduling_integration.py deleted file mode 100644 index f2d0729..0000000 --- a/tests/integration/test_scheduling_integration.py +++ /dev/null @@ -1,398 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for Task 10 integration — scheduling context extraction and -YAML-based policy selection.""" - -import itertools -from unittest.mock import MagicMock - -import pytest - -from xpyd.scheduler import ( - CacheAwarePolicy, - ConsistentHashPolicy, - PowerOfTwoPolicy, - RoundRobinSchedulingPolicy, - default_registry, -) -from xpyd.scheduler.cache_aware import CacheAwarePolicy as CacheAwareDirect - -# ------------------------------------------------------------------ # -# Cache-aware policy unit tests -# ------------------------------------------------------------------ # - - -class TestCacheAwarePolicy: - """Unit tests for CacheAwarePolicy.""" - - def test_same_prefix_same_worker(self): - policy = CacheAwarePolicy( - workers=["w1", "w2", "w3"], - prefix_length=5, - ) - # Both prompts share the same first 5 tokens (whitespace-split) - w1 = policy.select(prompt="alpha beta gamma delta epsilon zeta") - w2 = policy.select(prompt="alpha beta gamma delta epsilon different") - assert w1 == w2 - - def test_different_prefix_can_differ(self): - policy = CacheAwarePolicy( - workers=["w1", "w2", "w3"], - prefix_length=256, - ) - selected = set() - for i in range(50): - w = policy.select(prompt=f"Unique prompt {i} " * 100) - selected.add(w) - assert len(selected) > 1 - - def test_single_worker(self): - policy = CacheAwarePolicy(workers=["w1"], prefix_length=256) - assert policy.select(prompt="hello") == "w1" - - def test_no_workers(self): - policy = CacheAwarePolicy(workers=[], prefix_length=256) - assert policy.select(prompt="hello") is None - - def test_none_prompt_deterministic(self): - policy = CacheAwarePolicy(workers=["w1", "w2"], prefix_length=256) - w1 = policy.select(prompt=None) - w2 = policy.select(prompt=None) - assert w1 == w2 - - def test_add_remove_worker(self): - policy = CacheAwarePolicy(workers=["w1", "w2"]) - policy.add_worker("w3") - selected = {policy.select(prompt=f"p{i}") for i in range(100)} - assert "w3" in selected - policy.remove_worker("w3") - # After removal, w3 should never be selected - for i in range(50): - assert policy.select(prompt=f"p{i}") != "w3" - - def test_schedule_interface(self): - """schedule() passes prompt through to select().""" - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - cycler = itertools.cycle(["w1", "w2", "w3"]) - result = policy.schedule(cycler, prompt="hello world") - assert result in {"w1", "w2", "w3"} - - -# ------------------------------------------------------------------ # -# Policy registry — all built-in policies registered -# ------------------------------------------------------------------ # - - -class TestPolicyRegistryIntegration: - """Verify all Task 10 policies are registered in default_registry.""" - - @pytest.mark.parametrize( - "name", - [ - "roundrobin", - "loadbalanced", - "consistent_hash", - "power_of_two", - "cache_aware", - ], - ) - def test_builtin_policies_registered(self, name): - assert default_registry.has(name) - - def test_create_consistent_hash(self): - policy = default_registry.create( - "consistent_hash", - workers=["w1", "w2"], - ) - assert isinstance(policy, ConsistentHashPolicy) - - def test_create_power_of_two(self): - policy = default_registry.create( - "power_of_two", - workers=["w1", "w2"], - ) - assert isinstance(policy, PowerOfTwoPolicy) - - def test_create_cache_aware(self): - policy = default_registry.create( - "cache_aware", - workers=["w1", "w2"], - prefix_length=128, - ) - assert isinstance(policy, CacheAwareDirect) - - def test_create_cache_aware_default_prefix(self): - policy = default_registry.create( - "cache_aware", - workers=["w1"], - ) - assert policy._prefix_length == 256 - - def test_unknown_policy_raises(self): - with pytest.raises(ValueError, match="Unknown scheduling policy"): - default_registry.create("nonexistent", workers=["w1"]) - - -# ------------------------------------------------------------------ # -# Session ID / prompt extraction helpers -# ------------------------------------------------------------------ # - - -def _make_mock_request( - headers=None, - client_host="127.0.0.1", -): - """Create a mock Starlette Request.""" - req = MagicMock() - _headers = headers or {} - req.headers = _headers - if client_host: - req.client = MagicMock() - req.client.host = client_host - else: - req.client = None - return req - - -class TestSessionIdExtraction: - """Test the session_id extraction priority logic used in proxy.""" - - @staticmethod - def _extract_session_id(raw_request, body): - """Replicate the extraction logic from xpyd.proxy.""" - return ( - raw_request.headers.get("x-session-id") - or body.get("user") - or (raw_request.client.host if raw_request.client else None) - ) - - def test_header_takes_priority(self): - req = _make_mock_request( - headers={"x-session-id": "sess-1"}, - client_host="1.2.3.4", - ) - body = {"user": "user-abc"} - assert self._extract_session_id(req, body) == "sess-1" - - def test_user_field_fallback(self): - req = _make_mock_request(headers={}, client_host="1.2.3.4") - body = {"user": "user-abc"} - assert self._extract_session_id(req, body) == "user-abc" - - def test_client_ip_fallback(self): - req = _make_mock_request(headers={}, client_host="10.0.0.1") - body = {} - assert self._extract_session_id(req, body) == "10.0.0.1" - - def test_no_client(self): - req = _make_mock_request(headers={}, client_host=None) - body = {} - assert self._extract_session_id(req, body) is None - - -class TestPromptExtraction: - """Test prompt text extraction from request bodies.""" - - def test_completion_prompt(self): - body = {"prompt": "Hello, world!"} - prompt = body.get("prompt", "") - assert prompt == "Hello, world!" - - def test_chat_messages(self): - body = { - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hi there"}, - ], - } - prompt_text = " ".join( - msg.get("content", "") - for msg in body.get("messages", []) - if isinstance(msg.get("content"), str) - ) - assert "You are helpful." in prompt_text - assert "Hi there" in prompt_text - - -# ------------------------------------------------------------------ # -# YAML config → policy instantiation -# ------------------------------------------------------------------ # - - -class TestYamlConfigPolicySelection: - """Test that YAML scheduling config produces the correct policy.""" - - def test_default_is_loadbalanced(self): - """Without explicit scheduling key, default is loadbalanced.""" - from xpyd.config import ProxyConfig - - config = ProxyConfig( - model="/tmp/model", - decode=["127.0.0.1:8000"], - ) - assert config.scheduling == "loadbalanced" - - def test_scheduling_field_stored(self): - from xpyd.config import ProxyConfig - - config = ProxyConfig( - model="/tmp/model", - decode=["127.0.0.1:8000"], - scheduling="cache_aware", - scheduling_config={"cache_aware": {"prefix_length": 128}}, - ) - assert config.scheduling == "cache_aware" - assert config.scheduling_config["cache_aware"]["prefix_length"] == 128 - - -# ------------------------------------------------------------------ # -# Backward compatibility — kwargs ignored by legacy policies -# ------------------------------------------------------------------ # - - -class TestBackwardCompatibility: - """Verify roundrobin and loadbalanced accept and ignore extra kwargs.""" - - def test_roundrobin_ignores_kwargs(self): - policy = RoundRobinSchedulingPolicy() - cycler = itertools.cycle(["w1", "w2"]) - result = policy.schedule( - cycler, - is_prompt=True, - header="sess-1", - session_id="sess-1", - user="u", - client_ip="1.2.3.4", - prompt="hello", - ) - assert result in {"w1", "w2"} - - def test_consistent_hash_uses_kwargs(self): - policy = ConsistentHashPolicy(workers=["w1", "w2", "w3"]) - cycler = itertools.cycle(["w1", "w2", "w3"]) - r1 = policy.schedule( - cycler, - header="sess-1", - session_id="sess-1", - ) - r2 = policy.schedule( - cycler, - header="sess-1", - session_id="sess-1", - ) - assert r1 == r2 # same session → same worker - - def test_power_of_two_ignores_kwargs(self): - policy = PowerOfTwoPolicy(workers=["w1", "w2"]) - cycler = itertools.cycle(["w1", "w2"]) - result = policy.schedule( - cycler, - header="sess-1", - prompt="hello", - ) - assert result in {"w1", "w2"} - - def test_cache_aware_uses_prompt_kwarg(self): - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - cycler = itertools.cycle(["w1", "w2", "w3"]) - r1 = policy.schedule(cycler, prompt="same prefix " * 100) - r2 = policy.schedule(cycler, prompt="same prefix " * 100) - assert r1 == r2 - - -# ------------------------------------------------------------------ # -# select_from — role-filtered ring routing -# ------------------------------------------------------------------ # - - -class TestConsistentHashSelectFrom: - """Tests for ConsistentHashPolicy.select_from.""" - - def test_normal_selection(self): - policy = ConsistentHashPolicy(workers=["w1", "w2", "w3"]) - result = policy.select_from({"w1", "w2"}, header="sess-1") - assert result in {"w1", "w2"} - - def test_empty_candidates(self): - policy = ConsistentHashPolicy(workers=["w1", "w2", "w3"]) - assert policy.select_from(set(), header="sess-1") is None - - def test_deterministic(self): - policy = ConsistentHashPolicy(workers=["w1", "w2", "w3"]) - r1 = policy.select_from({"w1", "w2"}, header="sess-1") - r2 = policy.select_from({"w1", "w2"}, header="sess-1") - assert r1 == r2 - - def test_all_candidates(self): - policy = ConsistentHashPolicy(workers=["w1", "w2", "w3"]) - result = policy.select_from({"w1", "w2", "w3"}, header="sess-1") - assert result in {"w1", "w2", "w3"} - - def test_candidate_not_on_ring(self): - """Candidate not present on the ring returns None.""" - policy = ConsistentHashPolicy(workers=["w1", "w2"]) - assert policy.select_from({"w99"}, header="sess-1") is None - - def test_schedule_role_filtered(self): - """schedule(is_prompt=True/False) routes to different role pools.""" - from xpyd.registry import InstanceRegistry - - reg = InstanceRegistry() - for p in ["p1", "p2"]: - reg.add("prefill", p) - reg.mark_healthy(p) - for d in ["d1", "d2"]: - reg.add("decode", d) - reg.mark_healthy(d) - policy = ConsistentHashPolicy(workers=["p1", "p2", "d1", "d2"], registry=reg) - cycler = itertools.cycle(["p1", "p2", "d1", "d2"]) - prefill = policy.schedule(cycler, is_prompt=True, header="sess-1") - decode = policy.schedule(cycler, is_prompt=False, header="sess-1") - assert prefill in {"p1", "p2"} - assert decode in {"d1", "d2"} - - -class TestCacheAwareSelectFrom: - """Tests for CacheAwarePolicy.select_from.""" - - def test_normal_selection(self): - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - result = policy.select_from({"w1", "w2"}, prompt="hello world") - assert result in {"w1", "w2"} - - def test_empty_candidates(self): - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - assert policy.select_from(set(), prompt="hello") is None - - def test_deterministic(self): - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - r1 = policy.select_from({"w1", "w2"}, prompt="hello world") - r2 = policy.select_from({"w1", "w2"}, prompt="hello world") - assert r1 == r2 - - def test_all_candidates(self): - policy = CacheAwarePolicy(workers=["w1", "w2", "w3"]) - result = policy.select_from({"w1", "w2", "w3"}, prompt="test") - assert result in {"w1", "w2", "w3"} - - def test_candidate_not_on_ring(self): - """Candidate not present on the ring returns None.""" - policy = CacheAwarePolicy(workers=["w1", "w2"]) - assert policy.select_from({"w99"}, prompt="hello") is None - - def test_schedule_role_filtered(self): - """schedule(is_prompt=True/False) routes to different role pools.""" - from xpyd.registry import InstanceRegistry - - reg = InstanceRegistry() - for p in ["p1", "p2"]: - reg.add("prefill", p) - reg.mark_healthy(p) - for d in ["d1", "d2"]: - reg.add("decode", d) - reg.mark_healthy(d) - policy = CacheAwarePolicy(workers=["p1", "p2", "d1", "d2"], registry=reg) - cycler = itertools.cycle(["p1", "p2", "d1", "d2"]) - prefill = policy.schedule(cycler, is_prompt=True, prompt="hello") - decode = policy.schedule(cycler, is_prompt=False, prompt="hello") - assert prefill in {"p1", "p2"} - assert decode in {"d1", "d2"} diff --git a/tests/integration/test_sim_nodes.py b/tests/integration/test_sim_nodes.py deleted file mode 100644 index feb9ec7..0000000 --- a/tests/integration/test_sim_nodes.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Tests for sim_adapter node apps (replaces test_prefill_node.py + test_decode_node.py).""" - -import json - -import pytest -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def prefill_client(): - app = make_sim_app(mode="prefill") - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: - yield c - - -@pytest.fixture -async def decode_client(): - app = make_sim_app(mode="decode") - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: - yield c - - -CHAT_PAYLOAD = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - "stream": False, -} - - -# --- Prefill node tests --- - - -@pytest.mark.anyio -async def test_prefill_health(prefill_client): - resp = await prefill_client.get("/health") - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - -@pytest.mark.anyio -async def test_prefill_non_streaming(prefill_client): - resp = await prefill_client.post("/v1/chat/completions", json=CHAT_PAYLOAD) - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "chat.completion" - assert len(data["choices"]) == 1 - assert data["choices"][0]["message"]["role"] == "assistant" - assert len(data["choices"][0]["message"]["content"]) > 0 - - -@pytest.mark.anyio -async def test_prefill_streaming(prefill_client): - payload = {**CHAT_PAYLOAD, "stream": True} - resp = await prefill_client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - assert data_lines[-1] == "data: [DONE]" - first = json.loads(data_lines[0].removeprefix("data: ")) - assert first["choices"][0]["delta"]["role"] == "assistant" - - -@pytest.mark.anyio -async def test_prefill_max_tokens(prefill_client): - payload = {**CHAT_PAYLOAD, "max_tokens": 3} - resp = await prefill_client.post("/v1/chat/completions", json=payload) - assert resp.json()["usage"]["completion_tokens"] == 3 - - -@pytest.mark.anyio -async def test_prefill_default_max_tokens(prefill_client): - payload = {"model": "dummy", "messages": [{"role": "user", "content": "Hi"}]} - resp = await prefill_client.post("/v1/chat/completions", json=payload) - assert resp.json()["usage"]["completion_tokens"] == 16 - - -# --- Decode node tests --- - - -@pytest.mark.anyio -async def test_decode_health(decode_client): - resp = await decode_client.get("/health") - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - -@pytest.mark.anyio -async def test_decode_non_streaming(decode_client): - resp = await decode_client.post("/v1/chat/completions", json=CHAT_PAYLOAD) - assert resp.status_code == 200 - data = resp.json() - assert data["object"] == "chat.completion" - assert len(data["choices"]) == 1 - assert data["choices"][0]["message"]["role"] == "assistant" - - -@pytest.mark.anyio -async def test_decode_streaming(decode_client): - payload = {**CHAT_PAYLOAD, "stream": True} - resp = await decode_client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - assert "data: [DONE]" in resp.text - first_data = [line for line in resp.text.split("\n") if line.startswith("data: ") and line != "data: [DONE]"][0] - first = json.loads(first_data.removeprefix("data: ")) - assert first["choices"][0]["delta"]["role"] == "assistant" - - -@pytest.mark.anyio -async def test_decode_max_tokens(decode_client): - payload = {**CHAT_PAYLOAD, "max_tokens": 10} - resp = await decode_client.post("/v1/chat/completions", json=payload) - assert resp.json()["usage"]["completion_tokens"] == 10 - - -@pytest.mark.anyio -async def test_decode_streaming_has_content(decode_client): - payload = {**CHAT_PAYLOAD, "max_tokens": 7, "stream": True} - resp = await decode_client.post("/v1/chat/completions", json=payload) - data_lines = [line for line in resp.text.strip().split("\n") - if line.startswith("data: ") and line != "data: [DONE]"] - content_chunks = sum(1 for line in data_lines - if json.loads(line.removeprefix("data: "))["choices"][0]["delta"].get("content")) - assert content_chunks >= 1 diff --git a/tests/integration/test_status_instances.py b/tests/integration/test_status_instances.py deleted file mode 100644 index 57ad191..0000000 --- a/tests/integration/test_status_instances.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for the /status/instances endpoint.""" - -import pytest -from fastapi import FastAPI -from fastapi.responses import JSONResponse -from httpx import ASGITransport, AsyncClient - -from xpyd.registry import InstanceRegistry - - -def _build_status_app(registry: InstanceRegistry) -> FastAPI: - """Build a minimal FastAPI app with only the /status/instances endpoint. - - This mirrors the endpoint defined in ``MicroPDProxyServer.run`` so we can - test it in isolation without needing a full server configuration. - """ - app = FastAPI() - - @app.get("/status/instances") - async def _instance_status(): - result: dict[str, list] = { - "prefill_instances": [], - "decode_instances": [], - } - for info in registry.get_all_instances(): - result[f"{info.role}_instances"].append( - { - "address": info.address, - "status": info.status.value, - "circuit": info.circuit_breaker_state.value, - "active_requests": info.active_request_count, - "last_check": info.last_health_check, - } - ) - return JSONResponse(result) - - return app - - -@pytest.fixture -def registry() -> InstanceRegistry: - """Create a registry with known prefill and decode instances.""" - reg = InstanceRegistry() - reg.add("prefill", "10.0.0.1:8000") - reg.add("prefill", "10.0.0.2:8000") - reg.add("decode", "10.0.0.3:8000") - # Mark one prefill healthy so we get a mix of statuses - reg.mark_healthy("10.0.0.1:8000") - return reg - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.mark.anyio -async def test_status_instances_returns_all_instances(registry): - """GET /status/instances should list every registered instance.""" - app = _build_status_app(registry) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - resp = await client.get("/status/instances") - - assert resp.status_code == 200 - data = resp.json() - - # Both keys must be present - assert "prefill_instances" in data - assert "decode_instances" in data - - assert len(data["prefill_instances"]) == 2 - assert len(data["decode_instances"]) == 1 - - # Verify required fields on every instance entry - required_fields = {"address", "status", "circuit", "active_requests", "last_check"} - for section in ("prefill_instances", "decode_instances"): - for entry in data[section]: - assert required_fields.issubset(entry.keys()), ( - f"Missing fields in {section} entry: " - f"{required_fields - entry.keys()}" - ) - - # Check specific addresses - prefill_addrs = {e["address"] for e in data["prefill_instances"]} - assert prefill_addrs == {"10.0.0.1:8000", "10.0.0.2:8000"} - - decode_addrs = {e["address"] for e in data["decode_instances"]} - assert decode_addrs == {"10.0.0.3:8000"} - - -@pytest.mark.anyio -async def test_status_instances_reflects_health(registry): - """Healthy instance should report status=healthy; others stay unknown.""" - app = _build_status_app(registry) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - resp = await client.get("/status/instances") - - data = resp.json() - by_addr = {e["address"]: e for e in data["prefill_instances"]} - - assert by_addr["10.0.0.1:8000"]["status"] == "healthy" - assert by_addr["10.0.0.2:8000"]["status"] == "unknown" - - -@pytest.mark.anyio -async def test_status_instances_empty_registry(): - """An empty registry should return empty lists.""" - reg = InstanceRegistry() - app = _build_status_app(reg) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - resp = await client.get("/status/instances") - - data = resp.json() - assert data == {"prefill_instances": [], "decode_instances": []} diff --git a/tests/integration/test_streaming_edge.py b/tests/integration/test_streaming_edge.py deleted file mode 100644 index b6a5339..0000000 --- a/tests/integration/test_streaming_edge.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Tests for streaming edge cases.""" - -import json -import os -import socket -import threading -import time - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -_TOKENIZER_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -_PREFILL_PORT = _free_port() -_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread( - target=_run_server, args=(prefill_app, _PREFILL_PORT), daemon=True -).start() -threading.Thread( - target=_run_server, args=(decode_app, _DECODE_PORT), daemon=True -).start() -time.sleep(2) - - -def _make_proxy_app(): - proxy = Proxy( - prefill_instances=[f"127.0.0.1:{_PREFILL_PORT}"], - decode_instances=[f"127.0.0.1:{_DECODE_PORT}"], - model=_TOKENIZER_PATH, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=False, - ) - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -CHAT_PAYLOAD = { - "model": "dummy", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - "stream": True, -} - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def client(): - app = _make_proxy_app() - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.mark.anyio -async def test_streaming_chunk_structure(client: AsyncClient): - """Each SSE chunk (except [DONE]) should be valid JSON with expected fields.""" - resp = await client.post("/v1/chat/completions", json=CHAT_PAYLOAD) - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - - lines = resp.text.strip().split("\n") - data_lines = [ - line for line in lines if line.startswith("data: ") and line != "data: [DONE]" - ] - assert len(data_lines) >= 1, "Expected at least one data chunk before [DONE]" - - chunk_ids = set() - for line in data_lines: - payload = line.removeprefix("data: ") - chunk = json.loads(payload) - assert "id" in chunk - assert chunk["object"] == "chat.completion.chunk" - assert "choices" in chunk - assert len(chunk["choices"]) >= 1 - assert "delta" in chunk["choices"][0] - chunk_ids.add(chunk["id"]) - - # All chunks in a single response should share the same id - assert len(chunk_ids) == 1, f"Expected one unique id across chunks, got {chunk_ids}" - - -@pytest.mark.anyio -async def test_streaming_max_tokens_one(client: AsyncClient): - """Streaming with max_tokens=1 should produce exactly one content chunk + [DONE].""" - payload = { - "model": "dummy", - "messages": [{"role": "user", "content": "Say something"}], - "max_tokens": 1, - "stream": True, - } - resp = await client.post("/v1/chat/completions", json=payload) - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - - lines = resp.text.strip().split("\n") - data_lines = [line for line in lines if line.startswith("data: ")] - assert data_lines[-1] == "data: [DONE]" - - content_chunks = [line for line in data_lines if line != "data: [DONE]"] - # With max_tokens=1, expect exactly 1 content chunk - assert len(content_chunks) >= 1 - - # Verify the chunk is valid - chunk = json.loads(content_chunks[0].removeprefix("data: ")) - assert chunk["object"] == "chat.completion.chunk" - assert "delta" in chunk["choices"][0] diff --git a/tests/integration/test_xpyd_start_proxy_integration.py b/tests/integration/test_xpyd_start_proxy_integration.py deleted file mode 100644 index 9aa0f3d..0000000 --- a/tests/integration/test_xpyd_start_proxy_integration.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Real integration tests for xpyd_start_proxy.sh with dummy nodes. - -These tests do not stop at validating the generated command string. Instead, -they start dummy prefill/decode nodes locally, launch the proxy through the -shell script itself, and then send real requests through the proxy. -""" - -from __future__ import annotations - -import os -import socket -import subprocess -import sys -import time -from contextlib import ExitStack -from pathlib import Path - -import pytest -import requests - -REPO_ROOT = Path(__file__).resolve().parents[2] -SCRIPT = REPO_ROOT / "xpyd" / "xpyd_start_proxy.sh" -PYTHON = sys.executable -TOKENIZER_DIR = str(REPO_ROOT / "tokenizers" / "DeepSeek-R1") -ENV_BASE = { - **os.environ, - "PYTHONPATH": str(REPO_ROOT), - "model_path": TOKENIZER_DIR, - "PREFILL_DELAY_PER_TOKEN": "0", - "DECODE_DELAY_PER_TOKEN": "0", - "NO_PROXY": "127.0.0.1,localhost", - "no_proxy": "127.0.0.1,localhost", -} - - -def _free_port() -> int: - with socket.socket() as sock: - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -def _wait_http_ok(url: str, timeout: float = 30.0) -> None: - deadline = time.time() + timeout - last_error: Exception | None = None - while time.time() < deadline: - try: - response = requests.get(url, timeout=1.5) - if response.status_code == 200: - return - except Exception as exc: - last_error = exc - time.sleep(0.2) - raise AssertionError(f"Timed out waiting for {url}; last_error={last_error}") - - -def _spawn_node(mode, port): - app_ref = "sim_adapter:prefill_app" if mode == "prefill" else "sim_adapter:decode_app" - return subprocess.Popen( - [PYTHON, "-m", "uvicorn", app_ref, "--host", "127.0.0.1", "--port", str(port), "--log-level", "warning"], - cwd=REPO_ROOT, env={**os.environ, "PYTHONPATH": str(REPO_ROOT)}, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - ) - - -def _stop_process(process: subprocess.Popen) -> None: - if process.poll() is not None: - return - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - process.wait(timeout=5) - - -def _drain_process_output(process: subprocess.Popen) -> str: - stdout = "" - stderr = "" - try: - if process.stdout: - stdout = process.stdout.read() or "" - if process.stderr: - stderr = process.stderr.read() or "" - except Exception: - pass - return f"STDOUT:\n{stdout}\nSTDERR:\n{stderr}" - - -def _launch_proxy_via_script( - prefill_ports: list[int], decode_ports: list[int], proxy_port: int -): - env = { - **ENV_BASE, - "XPYD_PREFILL_IPS": " ".join(["127.0.0.1"] * len(prefill_ports)), - "XPYD_DECODE_IPS": " ".join(["127.0.0.1"] * len(decode_ports)), - "XPYD_PROXY_PORT": str(proxy_port), - "HTTP_PROXY": "", - "HTTPS_PROXY": "", - "http_proxy": "", - "https_proxy": "", - } - command = [ - "bash", - str(SCRIPT), - "-pn", - str(len(prefill_ports)), - "-pt", - "8", - "-pd", - str(len(prefill_ports)), - "-pw", - "8", - "-dn", - str(len(decode_ports)), - "-dt", - "8", - "-dd", - str(len(decode_ports)), - "-dw", - "8", - "--prefill-base-port", - str(prefill_ports[0]), - "--decode-base-port", - str(decode_ports[0]), - ] - return subprocess.Popen( - command, - cwd=REPO_ROOT / "xpyd", - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - -@pytest.mark.parametrize("prefill_count,decode_count", [(1, 2), (2, 2)]) -def test_xpyd_start_proxy_launches_real_proxy_with_dummy_nodes( - prefill_count: int, decode_count: int -): - prefill_ports = [_free_port() for _ in range(prefill_count)] - decode_ports = [_free_port() for _ in range(decode_count)] - proxy_port = _free_port() - - with ExitStack() as stack: - prefill_processes = [] - decode_processes = [] - for port in prefill_ports: - process = _spawn_node("prefill", port) - prefill_processes.append(process) - stack.callback(_stop_process, process) - for port in decode_ports: - process = _spawn_node("decode", port) - decode_processes.append(process) - stack.callback(_stop_process, process) - - for port in prefill_ports: - _wait_http_ok(f"http://127.0.0.1:{port}/v1/models") - for port in decode_ports: - _wait_http_ok(f"http://127.0.0.1:{port}/v1/models") - - proxy = _launch_proxy_via_script(prefill_ports, decode_ports, proxy_port) - stack.callback(_stop_process, proxy) - - try: - _wait_http_ok(f"http://127.0.0.1:{proxy_port}/status") - except AssertionError: - pytest.fail(_drain_process_output(proxy)) - - status = requests.get(f"http://127.0.0.1:{proxy_port}/status", timeout=5).json() - assert status["prefill_node_count"] == prefill_count - assert status["decode_node_count"] == decode_count - - payload = { - "model": TOKENIZER_DIR, - "messages": [{"role": "user", "content": "integration via shell script"}], - "max_tokens": 4, - "stream": False, - } - response = requests.post( - f"http://127.0.0.1:{proxy_port}/v1/chat/completions", - json=payload, - timeout=15, - ) - assert response.status_code == 200, response.text - data = response.json() - assert data["object"] == "chat.completion" - assert data["choices"][0]["message"]["role"] == "assistant" - assert data["usage"]["completion_tokens"] == 4 - - stream_payload = dict(payload) - stream_payload["stream"] = True - stream_response = requests.post( - f"http://127.0.0.1:{proxy_port}/v1/chat/completions", - json=stream_payload, - timeout=15, - ) - assert stream_response.status_code == 200, stream_response.text - assert "data: [DONE]" in stream_response.text - - for process in [*prefill_processes, *decode_processes, proxy]: - if process.poll() not in (None, 0): - pytest.fail(_drain_process_output(process)) diff --git a/tests/integration/test_xpyd_start_proxy_sh.py b/tests/integration/test_xpyd_start_proxy_sh.py deleted file mode 100644 index 8f82a90..0000000 --- a/tests/integration/test_xpyd_start_proxy_sh.py +++ /dev/null @@ -1,335 +0,0 @@ -"""Tests for xpyd/xpyd_start_proxy.sh parameterization and validation.""" - -from __future__ import annotations - -import os -import subprocess -from pathlib import Path - -REPO_ROOT = Path(__file__).resolve().parents[2] -SCRIPT = REPO_ROOT / "xpyd" / "xpyd_start_proxy.sh" - -_MINIMAL_TOPO = [ - "-pn", - "1", - "-pt", - "8", - "-pd", - "1", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", -] - - -def run_script(*args: str, env_overrides: dict | None = None): - env = { - **os.environ, - "model_path": "dummy-model", - "XPYD_DRY_RUN": "1", - } - if env_overrides: - env.update(env_overrides) - return subprocess.run( - ["bash", str(SCRIPT), *args], - cwd=REPO_ROOT / "xpyd", - env=env, - capture_output=True, - text=True, - ) - - -def extract_running_line(stdout: str) -> str: - for line in stdout.splitlines(): - if line.startswith("Running: "): - return line - raise AssertionError(f"Running line not found in stdout:\n{stdout}") - - -def extract_config_from_cmd(stdout: str) -> dict: - """Extract the YAML config file path from the Running line and read it.""" - import re - - import yaml - - cmd = extract_running_line(stdout) - match = re.search(r"--config\s+(\S+)", cmd) - if not match: - raise AssertionError(f"No --config found in command: {cmd}") - config_path = match.group(1) - with open(config_path) as f: - return yaml.safe_load(f) - - -def test_valid_topology_simple_same_node_instances(): - result = run_script( - "-pn", - "2", - "-pt", - "4", - "-pd", - "4", - "-pw", - "8", - "-dn", - "2", - "-dt", - "2", - "-dd", - "8", - "-dw", - "8", - ) - assert result.returncode == 0, result.stderr - config = extract_config_from_cmd(result.stdout) - prefill = config.get("prefill", []) - assert "10.239.129.9:8100" in prefill - assert "10.239.129.9:8101" in prefill - assert "10.239.129.67:8100" in prefill - assert "10.239.129.67:8101" in prefill - decode = config.get("decode", []) - assert "10.239.129.81:8200" in decode - assert "10.239.129.165:8200" in decode - decode = [d for d in decode if d] # filter empty strings - assert len(decode) == 8 - - -def test_valid_topology_cross_node_instance_exposes_main_node_only(): - result = run_script( - "-pn", - "2", - "-pt", - "16", - "-pd", - "1", - "-pw", - "8", - "-dn", - "4", - "-dt", - "8", - "-dd", - "4", - "-dw", - "8", - ) - assert result.returncode == 0, result.stderr - config = extract_config_from_cmd(result.stdout) - prefill = [p for p in config.get("prefill", []) if p] - assert "10.239.129.9:8100" in prefill - assert "10.239.129.67:8100" not in prefill - decode = [d for d in config.get("decode", []) if d] - assert "10.239.129.81:8200" in decode - assert "10.239.129.165:8200" in decode - decode = [d for d in decode if d] # filter empty strings - assert len(decode) == 4 - - -def test_custom_base_ports(): - result = run_script( - "-pn", - "1", - "-pt", - "8", - "-pd", - "1", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - "--prefill-base-port", - "9100", - "--decode-base-port", - "9200", - ) - assert result.returncode == 0, result.stderr - config = extract_config_from_cmd(result.stdout) - prefill = config.get("prefill", []) - decode = config.get("decode", []) - assert "10.239.129.9:9100" in prefill - assert "10.239.129.81:9200" in decode - - -def test_reject_non_power_of_two_tp(): - result = run_script( - "-pn", - "1", - "-pt", - "3", - "-pd", - "8", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill tp_size must be a power of two" in result.stderr - - -def test_reject_non_power_of_two_dp(): - result = run_script( - "-pn", - "1", - "-pt", - "8", - "-pd", - "3", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill dp_size must be a power of two" in result.stderr - - -def test_reject_invalid_topology_product_constraint(): - result = run_script( - "-pn", - "2", - "-pt", - "8", - "-pd", - "1", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill topology invalid" in result.stderr - - -def test_reject_nodes_exceeding_ip_list(): - result = run_script( - "-pn", - "8", - "-pt", - "8", - "-pd", - "8", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill nodes exceeds available IP list length" in result.stderr - - -def test_reject_non_integer_argument(): - result = run_script( - "-pn", - "two", - "-pt", - "8", - "-pd", - "1", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill nodes must be a positive integer" in result.stderr - - -def test_reject_zero_or_negative_argument(): - result = run_script( - "-pn", - "0", - "-pt", - "8", - "-pd", - "1", - "-pw", - "8", - "-dn", - "1", - "-dt", - "8", - "-dd", - "1", - "-dw", - "8", - ) - assert result.returncode != 0 - assert "prefill nodes must be a positive integer" in result.stderr - - -def test_model_cli_arg_overrides_env_var(): - """--model CLI arg should override model_path env var.""" - result = run_script( - *_MINIMAL_TOPO, - "--model", - "/cli/model/path", - env_overrides={"model_path": "/env/model/path"}, - ) - assert result.returncode == 0, result.stderr - config = extract_config_from_cmd(result.stdout) - assert config.get("model") == "/cli/model/path" - - -def test_model_env_var_fallback(): - """When --model is not given, script should use model_path env var.""" - result = run_script( - *_MINIMAL_TOPO, - env_overrides={"model_path": "/env/fallback/model"}, - ) - assert result.returncode == 0, result.stderr - config = extract_config_from_cmd(result.stdout) - assert config.get("model") == "/env/fallback/model" - - -def test_missing_model_errors(): - """When neither --model nor model_path env var is set, script should fail.""" - result = run_script( - *_MINIMAL_TOPO, - env_overrides={"model_path": ""}, - ) - assert result.returncode != 0 - assert "model path is not set" in result.stderr diff --git a/tests/integration/test_yaml_config.py b/tests/integration/test_yaml_config.py deleted file mode 100644 index 623a9b2..0000000 --- a/tests/integration/test_yaml_config.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Tests for YAML config loading, topology expansion, and precedence.""" - -from __future__ import annotations - -import argparse -import os -import textwrap -from pathlib import Path -from unittest.mock import patch - -import pytest - -from xpyd.config import ProxyConfig - - -@pytest.fixture -def tmp_yaml(tmp_path): - """Helper that writes YAML content to a temp file and returns the path.""" - - def _write(content: str) -> Path: - p = tmp_path / "config.yaml" - p.write_text(textwrap.dedent(content)) - return p - - return _write - - -def _make_args(**overrides): - """Build a minimal argparse Namespace with defaults.""" - defaults = { - "config": None, - "model": None, - "prefill": None, - "decode": None, - "port": 8000, - "generator_on_p_node": False, - "roundrobin": False, - "log_level": "warning", - } - defaults.update(overrides) - return argparse.Namespace(**defaults) - - -# ------------------------------------------------------------------ -# YAML loading basics -# ------------------------------------------------------------------ - - -class TestLoadYaml: - def test_valid_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/to/model - decode: - nodes: - - "10.0.0.1:8200" - tp_size: 1 - dp_size: 1 - world_size_per_node: 1 - """ - ) - data = ProxyConfig.load_yaml(p) - assert data["model"] == "/path/to/model" - - def test_file_not_found(self): - with pytest.raises(FileNotFoundError, match="Config file not found"): - ProxyConfig.load_yaml("/nonexistent/path.yaml") - - def test_malformed_yaml(self, tmp_yaml): - p = tmp_yaml("{ bad yaml :::") - with pytest.raises(ValueError, match="Malformed YAML"): - ProxyConfig.load_yaml(p) - - def test_non_dict_yaml(self, tmp_yaml): - p = tmp_yaml("- just\n- a\n- list\n") - with pytest.raises(ValueError, match="must be a mapping"): - ProxyConfig.load_yaml(p) - - def test_unknown_keys_rejected(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - - "10.0.0.1:8000" - unknown_field: oops - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="Unknown keys"): - ProxyConfig.from_args(args) - - -# ------------------------------------------------------------------ -# Topology-style YAML -# ------------------------------------------------------------------ - - -class TestTopologyYaml: - def test_full_topology_expansion(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - prefill: - nodes: - - "10.0.0.1:8100" - tp_size: 8 - dp_size: 1 - world_size_per_node: 8 - decode: - nodes: - - "10.0.0.3:8200" - - "10.0.0.4:8200" - tp_size: 1 - dp_size: 16 - world_size_per_node: 8 - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.prefill == ["10.0.0.1:8100"] - assert len(cfg.decode) == 16 - assert cfg.decode[0] == "10.0.0.3:8200" - assert cfg.decode[8] == "10.0.0.4:8200" - - def test_flat_list_backward_compat(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - - "10.0.0.1:8000" - - "10.0.0.2:8000" - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.decode == ["10.0.0.1:8000", "10.0.0.2:8000"] - - def test_topology_missing_keys(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - nodes: - - "10.0.0.1:8200" - tp_size: 1 - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="missing keys"): - ProxyConfig.from_args(args) - - def test_topology_invalid_constraint(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - nodes: - - "10.0.0.1:8200" - tp_size: 8 - dp_size: 2 - world_size_per_node: 8 - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="topology invalid"): - ProxyConfig.from_args(args) - - def test_topology_unknown_keys_rejected(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - nodes: - - "10.0.0.1:8200" - tp_size: 1 - dp_size: 1 - world_size_per_node: 1 - extra_key: bad - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="unknown keys"): - ProxyConfig.from_args(args) - - -# ------------------------------------------------------------------ -# Precedence: CLI > env > YAML > defaults -# ------------------------------------------------------------------ - - -class TestPrecedence: - def test_yaml_only(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - port: 9000 - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.model == "/yaml/model" - assert cfg.port == 9000 - - def test_cli_overrides_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - port: 9000 - """ - ) - args = _make_args(config=str(p), port=7777, model="/cli/model") - cfg = ProxyConfig.from_args(args) - assert cfg.model == "/cli/model" - assert cfg.port == 7777 - - def test_cli_decode_overrides_yaml_topology(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - nodes: - - "10.0.0.1:8200" - tp_size: 1 - dp_size: 1 - world_size_per_node: 1 - """ - ) - args = _make_args(config=str(p), decode=["10.0.0.99:9999"]) - cfg = ProxyConfig.from_args(args) - assert cfg.decode == ["10.0.0.99:9999"] - - def test_env_overrides_yaml_for_api_keys(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - admin_api_key: yaml-key - """ - ) - args = _make_args(config=str(p)) - with patch.dict(os.environ, {"ADMIN_API_KEY": "env-key"}): - cfg = ProxyConfig.from_args(args) - assert cfg.admin_api_key == "env-key" - - def test_yaml_api_key_fallback(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - admin_api_key: yaml-key - """ - ) - args = _make_args(config=str(p)) - env = {k: v for k, v in os.environ.items() if k != "ADMIN_API_KEY"} - with patch.dict(os.environ, env, clear=True): - cfg = ProxyConfig.from_args(args) - assert cfg.admin_api_key == "yaml-key" - - def test_no_config_uses_cli_only(self): - args = _make_args(model="/cli/model", decode=["10.0.0.1:8000"], port=5555) - cfg = ProxyConfig.from_args(args) - assert cfg.model == "/cli/model" - assert cfg.port == 5555 - - -# ------------------------------------------------------------------ -# Scheduling -# ------------------------------------------------------------------ - - -class TestScheduling: - def test_roundrobin_from_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - scheduling: roundrobin - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.roundrobin is True - - def test_loadbalanced_from_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - scheduling: loadbalanced - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.roundrobin is False - - def test_cli_roundrobin_overrides_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - scheduling: loadbalanced - """ - ) - args = _make_args(config=str(p), roundrobin=True) - cfg = ProxyConfig.from_args(args) - assert cfg.roundrobin is True - - def test_invalid_scheduling_rejected(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - scheduling: typo - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="Invalid scheduling value"): - ProxyConfig.from_args(args) - - -# ------------------------------------------------------------------ -# log_level -# ------------------------------------------------------------------ - - -class TestLogLevel: - def test_log_level_from_yaml(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - - "10.0.0.1:8000" - log_level: debug - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.log_level == "debug" - - def test_invalid_log_level(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /path/model - decode: - - "10.0.0.1:8000" - log_level: verbose - """ - ) - args = _make_args(config=str(p)) - with pytest.raises(ValueError, match="log_level"): - ProxyConfig.from_args(args) - - def test_default_log_level(self): - args = _make_args(model="/m", decode=["10.0.0.1:8000"]) - cfg = ProxyConfig.from_args(args) - assert cfg.log_level == "warning" - - -# ------------------------------------------------------------------ -# Missing model -# ------------------------------------------------------------------ - - -class TestMissingModel: - def test_no_model_anywhere_raises(self): - args = _make_args(decode=["10.0.0.1:8000"]) - with pytest.raises(ValueError): - ProxyConfig.from_args(args) - - def test_model_in_yaml_only(self, tmp_yaml): - p = tmp_yaml( - """\ - model: /yaml/model - decode: - - "10.0.0.1:8000" - """ - ) - args = _make_args(config=str(p)) - cfg = ProxyConfig.from_args(args) - assert cfg.model == "/yaml/model" diff --git a/tests/integration/test_yaml_integration.py b/tests/integration/test_yaml_integration.py deleted file mode 100644 index b00686d..0000000 --- a/tests/integration/test_yaml_integration.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Integration test: start proxy from a YAML config with dummy nodes.""" - -from __future__ import annotations - -import argparse -import socket -import textwrap -import threading -import time -from pathlib import Path - -import pytest -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from httpx import ASGITransport, AsyncClient - -from sim_adapter import make_sim_app -from xpyd.config import ProxyConfig -from xpyd.proxy import Proxy, RoundRobinSchedulingPolicy - -prefill_app = make_sim_app(mode='prefill') -decode_app = make_sim_app(mode='decode') - -_REPO_ROOT = Path(__file__).resolve().parents[2] -_TOKENIZER_PATH = str(_REPO_ROOT / "tokenizers" / "DeepSeek-R1") - - -def _free_port(): - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -# Start dedicated dummy nodes for YAML integration tests. -_YAML_PREFILL_PORT = _free_port() -_YAML_DECODE_PORT = _free_port() - - -def _run_server(app, port): - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") - uvicorn.Server(config).run() - - -threading.Thread( - target=_run_server, args=(prefill_app, _YAML_PREFILL_PORT), daemon=True -).start() -threading.Thread( - target=_run_server, args=(decode_app, _YAML_DECODE_PORT), daemon=True -).start() -time.sleep(2) - - -def _make_proxy_from_yaml(yaml_content: str, tmp_path: Path) -> Proxy: - """Write YAML to a temp file, parse it, and build a Proxy instance.""" - config_file = tmp_path / "test_config.yaml" - config_file.write_text(textwrap.dedent(yaml_content)) - - args = argparse.Namespace( - config=str(config_file), - model=None, - prefill=None, - decode=None, - port=8000, - generator_on_p_node=False, - roundrobin=False, - log_level="warning", - ) - config = ProxyConfig.from_args(args) - - return Proxy( - prefill_instances=config.prefill, - decode_instances=config.decode, - model=config.model, - scheduling_policy=RoundRobinSchedulingPolicy(), - generator_on_p_node=config.generator_on_p_node, - ) - - -def _make_app(proxy: Proxy) -> FastAPI: - app = FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - app.include_router(proxy.router) - return app - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -@pytest.fixture -async def yaml_client(tmp_path): - """Async HTTP client wired to a proxy started from YAML config.""" - yaml_content = f"""\ - model: {_TOKENIZER_PATH} - prefill: - - "127.0.0.1:{_YAML_PREFILL_PORT}" - decode: - - "127.0.0.1:{_YAML_DECODE_PORT}" - scheduling: roundrobin - """ - proxy = _make_proxy_from_yaml(yaml_content, tmp_path) - app = _make_app(proxy) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.fixture -async def yaml_topology_client(tmp_path): - """Client from YAML config using topology-style node definition.""" - yaml_content = f"""\ - model: {_TOKENIZER_PATH} - prefill: - nodes: - - "127.0.0.1:{_YAML_PREFILL_PORT}" - tp_size: 1 - dp_size: 1 - world_size_per_node: 1 - decode: - nodes: - - "127.0.0.1:{_YAML_DECODE_PORT}" - tp_size: 1 - dp_size: 1 - world_size_per_node: 1 - scheduling: roundrobin - """ - proxy = _make_proxy_from_yaml(yaml_content, tmp_path) - app = _make_app(proxy) - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as cli: - yield cli - - -@pytest.mark.anyio -async def test_yaml_config_health(yaml_client): - """Proxy started from YAML config should respond to /health.""" - resp = await yaml_client.get("/health") - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_yaml_config_non_streaming(yaml_client): - """Non-streaming completion through YAML-configured proxy.""" - resp = await yaml_client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "Hello", - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data - - -@pytest.mark.anyio -async def test_yaml_config_streaming(yaml_client): - """Streaming completion through YAML-configured proxy.""" - resp = await yaml_client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "Hello", - "max_tokens": 5, - "stream": True, - }, - ) - assert resp.status_code == 200 - body = resp.text - assert "data:" in body - - -@pytest.mark.anyio -async def test_yaml_config_chat_completion(yaml_client): - """Chat completion through YAML-configured proxy.""" - resp = await yaml_client.post( - "/v1/chat/completions", - json={ - "model": "test", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data - - -@pytest.mark.anyio -async def test_yaml_topology_health(yaml_topology_client): - """Proxy from topology-style YAML config responds to /health.""" - resp = await yaml_topology_client.get("/health") - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_yaml_topology_completion(yaml_topology_client): - """Completion through topology-style YAML-configured proxy.""" - resp = await yaml_topology_client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "Hello", - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data diff --git a/tests/stress/__init__.py b/tests/stress/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/stress/test_benchmark_e2e.py b/tests/stress/test_benchmark_e2e.py deleted file mode 100644 index 4a419de..0000000 --- a/tests/stress/test_benchmark_e2e.py +++ /dev/null @@ -1,395 +0,0 @@ -"""End-to-end benchmark: 1000 concurrent clients, 10000 requests, mixed lengths. - -Topology: 2 prefill + 16 decode + 1 proxy (same as test_benchmark_integration). -Excluded from CI via --ignore. Run manually: - - pytest tests/test_benchmark_e2e.py -v -s - -Uses pytest.mark.benchmark so it can also be collected via: - - pytest -m benchmark tests/test_benchmark_e2e.py -""" - -from __future__ import annotations - -import os -import random -import socket -import subprocess -import sys -import tempfile -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any - -import httpx -import pytest -import yaml - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -MODEL_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - -NUM_PREFILL = 2 -NUM_DECODE = 16 -TOTAL_REQUESTS = 10_000 -MAX_CONCURRENCY = 1_000 - - -def _free_port() -> int: - with socket.socket() as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _write_bench_config(model, prefill, decode, port): - """Write a temporary YAML config for benchmark proxy launch.""" - f = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) - yaml.dump({"model": model, "prefill": prefill, "decode": decode, "port": port}, f) - f.close() - return f.name - - -def _wait_port(port: int, timeout: float = 20.0) -> bool: - deadline = time.time() + timeout - while time.time() < deadline: - try: - with socket.create_connection(("127.0.0.1", port), timeout=1): - return True - except OSError: - time.sleep(0.5) - return False - - -def _random_content(length: int) -> str: - """Generate a random string of approximately *length* characters.""" - if length <= 0: - return "" - # Use words to produce roughly natural-looking content. - words = ["hello", "world", "bench", "test", "proxy", "stream", "token", "data"] - pieces: list[str] = [] - cur = 0 - while cur < length: - w = random.choice(words) - pieces.append(w) - cur += len(w) + 1 # +1 for the space - return " ".join(pieces)[:length] - - -def _build_payload(model: str, stream: bool) -> dict[str, Any]: - """Build a chat completion payload with random prompt length 0-10k chars.""" - prompt_len = random.randint(0, 10_000) - content = _random_content(prompt_len) - return { - "model": model, - "messages": [{"role": "user", "content": content}], - "max_tokens": random.randint(1, 64), - "stream": stream, - } - - -# --------------------------------------------------------------------------- -# Cluster fixture (module-scoped — start once, reuse across all tests) -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="module") -def cluster(): - """Spin up dummy nodes + proxy, yield connection info, tear down.""" - env = os.environ.copy() - # Speed up dummy nodes for benchmarking - env["PREFILL_DELAY_PER_TOKEN"] = "0" - env["DECODE_DELAY_PER_TOKEN"] = "0" - procs: list[subprocess.Popen] = [] - - prefill_ports = [_free_port() for _ in range(NUM_PREFILL)] - decode_ports = [_free_port() for _ in range(NUM_DECODE)] - proxy_port = _free_port() - - try: - for port in prefill_ports: - procs.append( - subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "sim_adapter:prefill_app", - "--host", - "127.0.0.1", - "--port", - str(port), - "--log-level", - "error", - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - ) - - for port in decode_ports: - procs.append( - subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "sim_adapter:decode_app", - "--host", - "127.0.0.1", - "--port", - str(port), - "--log-level", - "error", - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - ) - - for port in prefill_ports + decode_ports: - assert _wait_port(port), f"Node on port {port} failed to start" - - procs.append( - subprocess.Popen( - [ - sys.executable, - "-m", - "xpyd.proxy", - "proxy", - "--config", - _write_bench_config( - MODEL_PATH, - [f"127.0.0.1:{p}" for p in prefill_ports], - [f"127.0.0.1:{p}" for p in decode_ports], - proxy_port, - ), - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - ) - assert _wait_port(proxy_port, timeout=30), "Proxy failed to start" - - yield { - "proxy_port": proxy_port, - "model": MODEL_PATH, - } - - finally: - for p in procs: - p.terminate() - for p in procs: - try: - p.wait(timeout=5) - except subprocess.TimeoutExpired: - p.kill() - p.wait(timeout=5) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _send_non_streaming(base_url: str, payload: dict) -> dict: - """Send a non-streaming request, return summary dict.""" - t0 = time.monotonic() - with httpx.Client(base_url=base_url, timeout=60, trust_env=False) as c: - r = c.post("/v1/chat/completions", json=payload) - elapsed = time.monotonic() - t0 - return {"status": r.status_code, "elapsed": elapsed, "stream": False} - - -def _send_streaming(base_url: str, payload: dict) -> dict: - """Send a streaming request, consume all SSE chunks, return summary.""" - t0 = time.monotonic() - chunks = 0 - status = 0 - with httpx.Client(base_url=base_url, timeout=60, trust_env=False) as c: - with c.stream("POST", "/v1/chat/completions", json=payload) as r: - status = r.status_code - for line in r.iter_lines(): - if line.startswith("data: "): - chunks += 1 - elapsed = time.monotonic() - t0 - return {"status": status, "elapsed": elapsed, "stream": True, "chunks": chunks} - - -def _send_request(base_url: str, model: str, idx: int) -> dict: - """Build and send one request (randomly streaming or not).""" - stream = random.choice([True, False]) - payload = _build_payload(model, stream=stream) - try: - if stream: - return _send_streaming(base_url, payload) - return _send_non_streaming(base_url, payload) - except Exception as exc: - return {"status": -1, "error": str(exc), "stream": stream, "elapsed": 0} - - -# --------------------------------------------------------------------------- -# Benchmark tests -# --------------------------------------------------------------------------- - - -@pytest.mark.benchmark -@pytest.mark.benchmark -def test_benchmark_10k_mixed(cluster): - """Fire 10000 mixed (streaming + non-streaming) requests at 1000 concurrency. - - Validates that every request succeeds (HTTP 200) and prints a summary - with throughput and latency percentiles. - """ - base_url = f"http://127.0.0.1:{cluster['proxy_port']}" - results: list[dict] = [] - - with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool: - futures = [ - pool.submit(_send_request, base_url, cluster["model"], i) - for i in range(TOTAL_REQUESTS) - ] - for f in as_completed(futures): - results.append(f.result()) - - # ---- Assertions ---- - statuses = [r["status"] for r in results] - success = statuses.count(200) - failed = len(statuses) - success - errors = [r for r in results if r["status"] != 200] - - # Print summary before asserting so we always see stats - elapsed_all = sorted(r["elapsed"] for r in results if r["status"] == 200) - stream_count = sum(1 for r in results if r.get("stream")) - non_stream_count = len(results) - stream_count - - print("\n" + "=" * 60) - print("BENCHMARK SUMMARY") - print("=" * 60) - print(f"Total requests : {TOTAL_REQUESTS}") - print(f"Concurrency : {MAX_CONCURRENCY}") - print(f"Streaming : {stream_count}") - print(f"Non-streaming : {non_stream_count}") - print(f"Successful : {success}") - print(f"Failed : {failed}") - if elapsed_all: - print(f"Latency p50 : {elapsed_all[len(elapsed_all) // 2]:.3f}s") - print(f"Latency p90 : {elapsed_all[int(len(elapsed_all) * 0.9)]:.3f}s") - print(f"Latency p99 : {elapsed_all[int(len(elapsed_all) * 0.99)]:.3f}s") - print(f"Latency max : {elapsed_all[-1]:.3f}s") - print("=" * 60) - - if errors: - sample = errors[:5] - print(f"First {len(sample)} errors: {sample}") - - assert failed == 0, f"{failed}/{TOTAL_REQUESTS} requests failed" - - -@pytest.mark.benchmark -@pytest.mark.benchmark -def test_benchmark_streaming_only(cluster): - """1000 concurrent streaming requests to verify SSE under load.""" - base_url = f"http://127.0.0.1:{cluster['proxy_port']}" - count = 2000 - - def send(idx: int) -> dict: - payload = _build_payload(cluster["model"], stream=True) - return _send_streaming(base_url, payload) - - results: list[dict] = [] - with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool: - futures = [pool.submit(send, i) for i in range(count)] - for f in as_completed(futures): - results.append(f.result()) - - success = sum(1 for r in results if r["status"] == 200) - has_chunks = sum(1 for r in results if r.get("chunks", 0) >= 2) - - print(f"\nStreaming-only: {success}/{count} OK, {has_chunks} with >=2 chunks") - assert success == count, f"{count - success} streaming requests failed" - assert has_chunks == count, "Some streaming responses had fewer than 2 chunks" - - -@pytest.mark.benchmark -@pytest.mark.benchmark -def test_benchmark_burst_short_prompts(cluster): - """Burst of 5000 short-prompt requests (< 100 chars) at full concurrency.""" - base_url = f"http://127.0.0.1:{cluster['proxy_port']}" - count = 5000 - - def send(idx: int) -> dict: - payload = { - "model": cluster["model"], - "messages": [ - {"role": "user", "content": _random_content(random.randint(0, 100))} - ], - "max_tokens": 5, - "stream": False, - } - return _send_non_streaming(base_url, payload) - - results: list[dict] = [] - with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool: - futures = [pool.submit(send, i) for i in range(count)] - for f in as_completed(futures): - results.append(f.result()) - - success = sum(1 for r in results if r["status"] == 200) - elapsed = sorted(r["elapsed"] for r in results if r["status"] == 200) - if elapsed: - print( - f"\nShort burst: {success}/{count} OK, " - f"p50={elapsed[len(elapsed) // 2]:.3f}s, " - f"p99={elapsed[int(len(elapsed) * 0.99)]:.3f}s" - ) - assert success == count, f"{count - success} short-burst requests failed" - - -@pytest.mark.benchmark -@pytest.mark.benchmark -def test_benchmark_long_prompts(cluster): - """500 requests with long prompts (5k-10k chars) at moderate concurrency.""" - base_url = f"http://127.0.0.1:{cluster['proxy_port']}" - count = 500 - concurrency = 200 - - def send(idx: int) -> dict: - payload = { - "model": cluster["model"], - "messages": [ - { - "role": "user", - "content": _random_content(random.randint(5000, 10000)), - } - ], - "max_tokens": 32, - "stream": random.choice([True, False]), - } - if payload["stream"]: - return _send_streaming(base_url, payload) - return _send_non_streaming(base_url, payload) - - results: list[dict] = [] - with ThreadPoolExecutor(max_workers=concurrency) as pool: - futures = [pool.submit(send, i) for i in range(count)] - for f in as_completed(futures): - results.append(f.result()) - - success = sum(1 for r in results if r["status"] == 200) - elapsed = sorted(r["elapsed"] for r in results if r["status"] == 200) - if elapsed: - print( - f"\nLong prompts: {success}/{count} OK, " - f"p50={elapsed[len(elapsed) // 2]:.3f}s, " - f"p99={elapsed[int(len(elapsed) * 0.99)]:.3f}s" - ) - assert success == count, f"{count - success} long-prompt requests failed" diff --git a/tests/stress/test_benchmark_integration.py b/tests/stress/test_benchmark_integration.py deleted file mode 100644 index 2ba46c2..0000000 --- a/tests/stress/test_benchmark_integration.py +++ /dev/null @@ -1,338 +0,0 @@ -"""Integration test: proxy + dummy nodes end-to-end. - -Topology (matches benchmarks/run_benchmark.sh): - - 2 prefill nodes (dynamically allocated ports) - - 16 decode nodes (dynamically allocated ports) - - 1 proxy (dynamically allocated port) - -This test file is excluded from CI via --ignore in the workflow. -Run manually: pytest tests/test_benchmark_integration.py -v -""" - -from __future__ import annotations - -import os -import socket -import subprocess -import sys -import tempfile -import time - -import httpx -import pytest -import yaml - -_REPO_ROOT = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -MODEL_PATH = os.path.join(_REPO_ROOT, "tokenizers", "DeepSeek-R1") - -NUM_PREFILL = 2 -NUM_DECODE = 16 - - -def _free_port(): - """Allocate an ephemeral port.""" - with socket.socket() as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _wait_port(port: int, timeout: float = 20.0) -> bool: - """Wait until a port is accepting connections.""" - deadline = time.time() + timeout - while time.time() < deadline: - try: - with socket.create_connection(("127.0.0.1", port), timeout=1): - return True - except OSError: - time.sleep(0.5) - return False - - -@pytest.fixture(scope="module") -def cluster(): - """Start dummy nodes + proxy, yield, then tear down.""" - env = os.environ.copy() - procs = [] - - prefill_ports = [_free_port() for _ in range(NUM_PREFILL)] - decode_ports = [_free_port() for _ in range(NUM_DECODE)] - proxy_port = _free_port() - - try: - # Start prefill nodes - for port in prefill_ports: - p = subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "sim_adapter:prefill_app", - "--host", - "127.0.0.1", - "--port", - str(port), - "--log-level", - "error", - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - procs.append(p) - - # Start decode nodes - for port in decode_ports: - p = subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "sim_adapter:decode_app", - "--host", - "127.0.0.1", - "--port", - str(port), - "--log-level", - "error", - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - procs.append(p) - - # Wait for all nodes - for port in prefill_ports: - assert _wait_port(port), f"Prefill {port} didn't start" - for port in decode_ports: - assert _wait_port(port), f"Decode {port} didn't start" - - # Start proxy - prefill_args = [f"127.0.0.1:{p}" for p in prefill_ports] - decode_args = [f"127.0.0.1:{p}" for p in decode_ports] - - _cfg = { - "model": MODEL_PATH, - "prefill": prefill_args, - "decode": decode_args, - "port": proxy_port, - } - _cf = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) - yaml.dump(_cfg, _cf) - _cf.close() - proxy = subprocess.Popen( - [ - sys.executable, - "-m", - "xpyd.proxy", - "proxy", - "--config", - _cf.name, - ], - env=env, - cwd=_REPO_ROOT, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - procs.append(proxy) - assert _wait_port(proxy_port, timeout=30), "Proxy didn't start" - - yield { - "proxy_port": proxy_port, - "model": MODEL_PATH, - "prefill_ports": prefill_ports, - "decode_ports": decode_ports, - } - - finally: - # Teardown — always clean up, even if setup fails - for p in procs: - p.terminate() - for p in procs: - try: - p.wait(timeout=5) - except subprocess.TimeoutExpired: - p.kill() - p.wait(timeout=5) - - -CHAT_PAYLOAD = { - "model": "", - "messages": [{"role": "user", "content": "Hello world"}], - "max_tokens": 5, - "stream": False, -} - - -def test_models_endpoint(cluster): - """Proxy /v1/models returns OpenAI-compatible model listing.""" - with httpx.Client( - base_url=f"http://127.0.0.1:{cluster['proxy_port']}", - timeout=10, - trust_env=False, - ) as c: - r = c.get("/v1/models") - assert r.status_code == 200 - data = r.json() - assert data["object"] == "list" - assert len(data["data"]) > 0, "No models in /v1/models response" - for model in data["data"]: - assert "id" in model - assert model["object"] == "model" - - -def test_chat_completions(cluster): - """Non-streaming chat completions through proxy.""" - payload = {**CHAT_PAYLOAD, "model": cluster["model"]} - with httpx.Client( - base_url=f"http://127.0.0.1:{cluster['proxy_port']}", - timeout=30, - trust_env=False, - ) as c: - r = c.post("/v1/chat/completions", json=payload) - assert r.status_code == 200 - data = r.json() - assert "choices" in data - assert len(data["choices"]) > 0 - assert data["choices"][0]["message"]["content"] - - -def test_chat_completions_streaming(cluster): - """Streaming chat completions through proxy.""" - payload = {**CHAT_PAYLOAD, "model": cluster["model"], "stream": True} - with httpx.Client( - base_url=f"http://127.0.0.1:{cluster['proxy_port']}", - timeout=30, - trust_env=False, - ) as c: - r = c.post("/v1/chat/completions", json=payload) - assert r.status_code == 200 - assert "text/event-stream" in r.headers.get("content-type", "") - lines = r.text.strip().split("\n") - data_lines = [ln for ln in lines if ln.startswith("data: ")] - assert len(data_lines) >= 2 - assert data_lines[-1] == "data: [DONE]" - - -def test_status_topology(cluster): - """Proxy status should reflect correct topology.""" - with httpx.Client( - base_url=f"http://127.0.0.1:{cluster['proxy_port']}", - timeout=10, - trust_env=False, - ) as c: - r = c.get("/status") - assert r.status_code == 200 - data = r.json() - assert data["prefill_node_count"] == NUM_PREFILL - assert data["decode_node_count"] == NUM_DECODE - - -def test_concurrent_requests(cluster): - """Multiple concurrent requests should all succeed.""" - import concurrent.futures - - payload = {**CHAT_PAYLOAD, "model": cluster["model"]} - - def send_request(idx): - with httpx.Client( - base_url=f"http://127.0.0.1:{cluster['proxy_port']}", - timeout=30, - trust_env=False, - ) as c: - r = c.post("/v1/chat/completions", json=payload) - return r.status_code - - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool: - futures = [pool.submit(send_request, i) for i in range(20)] - results = [f.result() for f in concurrent.futures.as_completed(futures)] - - assert all(code == 200 for code in results), f"Some requests failed: {results}" - - -@pytest.mark.skipif( - os.environ.get("RUN_VLLM_BENCH") != "1", - reason="Set RUN_VLLM_BENCH=1 to run this heavy benchmark test", -) -def test_vllm_bench_serve(cluster): - """Run vllm bench serve with 1000 prompts through the proxy. - - This test is heavy (~5-10 min) and requires vllm. It is gated behind - the RUN_VLLM_BENCH=1 env var and skipped by default. - - Run manually: - RUN_VLLM_BENCH=1 \\ - pytest tests/test_benchmark_integration.py::test_vllm_bench_serve -v - - Note: Uses --tokenizer gpt2 because the local DeepSeek-R1 tokenizer - uses a custom class (TokenizersBackend) that vllm cannot load. - gpt2 is lightweight and sufficient for random-data benchmarks. - """ - import shutil - - vllm_bin = shutil.which("vllm") - if vllm_bin is None: - pytest.skip("vllm not installed") - - result = subprocess.run( - [ - vllm_bin, - "bench", - "serve", - "--host", - "127.0.0.1", - "--port", - str(cluster["proxy_port"]), - "--model", - cluster["model"], - "--tokenizer", - "gpt2", - "--dataset-name", - "random", - "--random-input-len", - "3000", - "--random-output-len", - "200", - "--num-prompts", - "1000", - "--burstiness", - "100", - "--request-rate", - "3.6", - "--endpoint", - "/v1/completions", - ], - capture_output=True, - text=True, - timeout=600, - ) - - print(result.stdout[-2000:] if len(result.stdout) > 2000 else result.stdout) - if result.stderr: - important = [ - line - for line in result.stderr.split("\n") - if "error" in line.lower() and "triton" not in line.lower() - ] - if important: - print("STDERR:", "\n".join(important[-5:])) - - assert result.returncode == 0, f"vllm bench serve failed: {result.stderr[-500:]}" - - successful = None - failed = None - for line in result.stdout.strip().split("\n"): - if "Successful requests:" in line: - successful = int(line.split(":")[1].strip()) - if "Failed requests:" in line: - failed = int(line.split(":")[1].strip()) - - assert successful is not None, "Could not parse 'Successful requests' from output" - assert failed is not None, "Could not parse 'Failed requests' from output" - assert successful == 1000, f"Expected 1000 successful, got {successful}" - assert failed == 0, f"Expected 0 failed, got {failed}" diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 9d70258..72c1264 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -1,6 +1,5 @@ """Tests for the /metrics Prometheus endpoint.""" -import pytest from xpyd.metrics import ( get_metrics, @@ -59,33 +58,3 @@ def test_track_request_lifecycle(self): output = get_metrics().decode() assert "proxy_active_requests 0.0" in output assert "proxy_request_duration_seconds" in output - - -@pytest.mark.anyio -async def test_metrics_endpoint_returns_prometheus_format(client): - """Integration test: GET /metrics returns valid Prometheus text.""" - resp = await client.get("/metrics") - assert resp.status_code == 200 - assert "text/plain" in resp.headers["content-type"] - body = resp.text - assert "proxy_requests_total" in body or "proxy_active_requests" in body - - -@pytest.mark.anyio -async def test_metrics_endpoint_after_completion_request(client): - """After a /v1/completions request, metrics should reflect it.""" - # Fire a completion request (will go through the proxy) - await client.post( - "/v1/completions", - json={ - "model": "test", - "prompt": "Hello", - "max_tokens": 5, - }, - ) - resp = await client.get("/metrics") - assert resp.status_code == 200 - body = resp.text - assert "proxy_requests_total" in body - # The /v1/completions counter should have been incremented - assert "/v1/completions" in body