diff --git a/docs/DESIGN.md b/docs/DESIGN.md index 5964010..27a4f02 100644 --- a/docs/DESIGN.md +++ b/docs/DESIGN.md @@ -131,10 +131,88 @@ Real models produce variable-length outputs. The simulator mimics this: - `ignore_eos: true` → always output full max_tokens - Works in both streaming and non-streaming modes -## OpenAI API Compliance +## API Compliance + +xPyD-sim targets two levels of API compatibility: + +### Level 1: OpenAI API Spec (Required) + +All endpoints must: +1. Accept ALL parameters defined in the OpenAI API spec without errors +2. Produce responses that match the spec format — content can be dummy, but structure must be correct +3. Validate parameter ranges per spec (e.g., temperature in [0,2], top_p in (0,1]) and return 400 on invalid values +4. Support all parameter behaviors that affect response format (not just accept and ignore) + +Specific requirements: + +| Feature | Endpoint | Behavior | +|---|---|---| +| response_format (json_object) | /v1/chat/completions | Return valid JSON string as content | +| response_format (json_schema) | /v1/chat/completions | Return JSON conforming to provided schema | +| max_completion_tokens | /v1/chat/completions | Fallback when max_tokens not set (already implemented) | +| encoding_format: base64 | /v1/embeddings | Return base64-encoded float vector (little-endian float32) | +| Parameter range validation | All | temperature [0,2], top_p (0,1], frequency_penalty [-2,2], presence_penalty [-2,2] | + +### Level 2: vLLM Backend Extensions (Required) + +xPyD-sim must also accept vLLM-specific parameters so it can serve as a drop-in replacement when testing xPyD-proxy against vLLM backends. These parameters should be accepted without error; behavior can be simulated where practical. 
+ +#### Sampling Parameters (accept, simulate where noted) + +| Parameter | Type | Behavior | +|---|---|---| +| best_of | int | Accept; generate n candidates and return best (or simulate: just return n=1) | +| use_beam_search | bool | Accept; ignore (sim doesn't do real search) | +| top_k | int | Accept; ignore | +| min_p | float | Accept; ignore | +| repetition_penalty | float | Accept; ignore | +| length_penalty | float | Accept; ignore | +| stop_token_ids | list[int] | Accept; ignore (sim uses stop strings) | +| include_stop_str_in_output | bool | Accept; ignore | +| min_tokens | int | Accept; ignore | +| skip_special_tokens | bool | Accept; ignore | +| spaces_between_special_tokens | bool | Accept; ignore | +| truncate_prompt_tokens | int | Accept; ignore | +| prompt_logprobs | int | Accept; return null (sim doesn't track prompt logprobs) | +| allowed_token_ids | list[int] | Accept; ignore | +| bad_words | list[str] | Accept; ignore | + +#### Extra Parameters (accept and ignore) + +| Parameter | Type | Notes | +|---|---|---| +| echo | bool | Already implemented for completions; accept for chat too | +| add_generation_prompt | bool | Accept; ignore | +| continue_final_message | bool | Accept; ignore | +| add_special_tokens | bool | Accept; ignore | +| documents | list[dict] | Accept; ignore (RAG) | +| chat_template | str | Accept; ignore | +| chat_template_kwargs | dict | Accept; ignore | +| mm_processor_kwargs | dict | Accept; ignore | +| structured_outputs | dict | Accept; ignore (use response_format instead) | +| priority | int | Accept; ignore | +| request_id | str | Accept; ignore | +| return_tokens_as_token_ids | bool | Accept; ignore | +| return_token_ids | bool | Accept; ignore | +| cache_salt | str | Accept; ignore | +| kv_transfer_params | dict | Accept; ignore | +| vllm_xargs | dict | Accept; ignore | +| repetition_detection | dict | Accept; ignore | +| reasoning_effort | str | Accept; ignore | +| thinking_token_budget | int | Accept; ignore | +| 
include_reasoning | bool | Accept; ignore | +| prompt_embeds | bytes | Accept; ignore | + +#### Response Extensions (include in responses) + +| Field | Where | Behavior | +|---|---|---| +| stop_reason | choices[].stop_reason | null (sim never stops on token IDs) | +| service_tier | response.service_tier | null | +| kv_transfer_params | response.kv_transfer_params | null | + +### Legacy Notes -- Accept ALL OpenAI API parameters without errors -- Response JSON format must exactly match OpenAI spec - Streaming: first chat chunk delta must include `role: "assistant"` - Streaming: final chunk includes `usage` when `stream_options.include_usage` is set - All responses include `system_fingerprint` @@ -560,3 +638,34 @@ scheduling: | TC13.10 | /debug/batch shows correct state | All fields accurate | | TC13.11 | Request log captures batch events | All events logged correctly | | TC13.12 | E2E with proxy: PD disaggregation | Full flow works, TTFT/TPOT correct | + +### 14. OpenAI Spec Compliance — Response Format +| ID | Test | Expected | +|---|---|---| +| TC14.1 | response_format: json_object | Content is valid JSON string | +| TC14.2 | response_format: json_schema with schema | Content conforms to provided JSON schema | +| TC14.3 | response_format in streaming | Streamed content assembles into valid JSON | + +### 15. OpenAI Spec Compliance — Parameter Validation +| ID | Test | Expected | +|---|---|---| +| TC15.1 | temperature=3.0 | HTTP 400, clear error message | +| TC15.2 | top_p=-0.5 | HTTP 400 | +| TC15.3 | frequency_penalty=5.0 | HTTP 400 | +| TC15.4 | presence_penalty=-3.0 | HTTP 400 | +| TC15.5 | n=0 or n=-1 | HTTP 400 | +| TC15.6 | best_of < n | HTTP 400 | + +### 16. OpenAI Spec Compliance — Embedding base64 +| ID | Test | Expected | +|---|---|---| +| TC16.1 | encoding_format=float | Returns list of floats (current behavior) | +| TC16.2 | encoding_format=base64 | Returns base64-encoded float vector | + +### 17. 
vLLM Backend Compatibility +| ID | Test | Expected | +|---|---|---| +| TC17.1 | All vLLM sampling params accepted | No 422/400 error | +| TC17.2 | All vLLM extra params accepted | No 422/400 error | +| TC17.3 | Response includes stop_reason field | null in choices | +| TC17.4 | Response includes service_tier field | null (field present) | diff --git a/docs/GAP_ANALYSIS.md b/docs/GAP_ANALYSIS.md new file mode 100644 index 0000000..3cce73a --- /dev/null +++ b/docs/GAP_ANALYSIS.md @@ -0,0 +1,36 @@ +# xPyD-sim Gap Analysis + +Generated: 2026-04-06 + +## OpenAI Spec Gaps (Must Fix) + +| # | Feature | Current State | Required | Difficulty | +|---|---|---|---|---| +| 1 | Parameter range validation (temperature, top_p, frequency_penalty, presence_penalty) | No validation — any value accepted silently | Return HTTP 400 for out-of-range values (temperature [0,2], top_p (0,1], frequency_penalty [-2,2], presence_penalty [-2,2]) | Easy | +| 2 | `n` validation (n≤0) | No validation | Return HTTP 400 for n≤0 | Easy | +| 3 | `response_format: json_object` | Field accepted but ignored — content is plain dummy text | Return valid JSON string as content | Medium | +| 4 | `response_format: json_schema` | Field accepted but ignored | Return JSON conforming to provided schema | Complex | +| 5 | `response_format` in streaming | Not handled | Streamed content must assemble into valid JSON | Medium | +| 6 | `encoding_format: base64` for embeddings | Field accepted but always returns float array | Return base64-encoded float vector when `encoding_format=base64` | Easy | +| 7 | `best_of < n` validation | `best_of` exists on CompletionRequest but no cross-field validation | Return HTTP 400 when best_of < n | Easy | + +## vLLM Backend Gaps (Must Add) + +| # | Feature | Current State | Required | Difficulty | +|---|---|---|---|---| +| 1 | Accept vLLM sampling params on ChatCompletionRequest | `ChatCompletionRequest` has no `extra="allow"` — unknown fields cause 422 | Add `model_config = {"extra": 
"allow"}` or explicit Optional fields for all vLLM sampling params (top_k, min_p, repetition_penalty, use_beam_search, etc.) | Easy | +| 2 | Accept vLLM sampling params on CompletionRequest | `CompletionRequest` has no `extra="allow"` — unknown fields cause 422 | Add `model_config = {"extra": "allow"}` or explicit Optional fields | Easy | +| 3 | Accept vLLM extra params (chat_template, documents, add_generation_prompt, priority, request_id, etc.) | Not accepted — 422 error | Accept without error on all request models | Easy | +| 4 | `best_of` on ChatCompletionRequest | Only defined on CompletionRequest | Add `best_of` field to ChatCompletionRequest | Easy | +| 5 | `echo` on ChatCompletionRequest | Only defined on CompletionRequest | Accept on chat endpoint too | Easy | +| 6 | `stop_reason` in response choices | Not present in Choice/CompletionChoice models | Add `stop_reason: Optional[str] = None` to Choice, CompletionChoice, StreamChoice, CompletionStreamChoice | Easy | +| 7 | `service_tier` in response objects | Not present in response models | Add `service_tier: Optional[str] = None` to ChatCompletionResponse, CompletionResponse, ChatCompletionChunk, CompletionChunk | Easy | +| 8 | `kv_transfer_params` in response objects | Not present | Add `kv_transfer_params: Optional[dict] = None` to response models | Easy | +| 9 | `prompt_logprobs` support | Not present — would 422 | Accept and return null in response | Easy | + +## Summary + +- **OpenAI Spec Gaps**: 7 items (4 Easy, 2 Medium, 1 Complex) +- **vLLM Backend Gaps**: 9 items (all Easy) +- **Highest risk**: `response_format: json_schema` — requires parsing JSON Schema and generating conforming dummy data +- **Quick wins**: Parameter validation, `extra="allow"`, response field additions — can all be done in one PR diff --git a/tests/test_api_compliance.py b/tests/test_api_compliance.py new file mode 100644 index 0000000..4e94853 --- /dev/null +++ b/tests/test_api_compliance.py @@ -0,0 +1,366 @@ +"""Tests for API 
compliance — TC14 through TC17.""" + +from __future__ import annotations + +import base64 +import json +import struct + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + +from xpyd_sim.server import ServerConfig, create_app + + +@pytest.fixture +def config(): + return ServerConfig( + mode="dual", + model_name="test-model", + prefill_delay_ms=0, + kv_transfer_delay_ms=0, + decode_delay_per_token_ms=0, + eos_min_ratio=0.5, + max_model_len=4096, + ) + + +@pytest.fixture +def client(config): + app = create_app(config) + transport = ASGITransport(app=app) + return AsyncClient( + transport=transport, base_url="http://test", + ) + + +# === TC14: response_format === + + +@pytest.mark.anyio +async def test_tc14_1_json_object(client): + """TC14.1: json_object returns valid JSON.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "response_format": {"type": "json_object"}, + }, + ) + assert resp.status_code == 200 + data = resp.json() + content = data["choices"][0]["message"]["content"] + parsed = json.loads(content) + assert isinstance(parsed, dict) + + +@pytest.mark.anyio +async def test_tc14_2_json_schema(client): + """TC14.2: json_schema returns JSON conforming to schema.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "response_format": { + "type": "json_schema", + "json_schema": {"schema": schema}, + }, + }, + ) + assert resp.status_code == 200 + data = resp.json() + content = data["choices"][0]["message"]["content"] + parsed = json.loads(content) + assert "name" in parsed + assert "age" in parsed + assert isinstance(parsed["age"], int) + + +@pytest.mark.anyio +async def test_tc14_3_streaming_json(client): + """TC14.3: streaming 
response_format assembles to valid JSON.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "response_format": {"type": "json_object"}, + "stream": True, + }, + ) + assert resp.status_code == 200 + content_parts = [] + async for line in resp.aiter_lines(): + if line.startswith("data: ") and line != "data: [DONE]": + chunk = json.loads(line[6:]) + delta = chunk["choices"][0].get("delta", {}) + c = delta.get("content") + if c: + content_parts.append(c) + full = "".join(content_parts) + parsed = json.loads(full) + assert isinstance(parsed, dict) + + +# === TC15: Parameter validation === + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "field,value", + [ + ("temperature", 3.0), + ("top_p", -0.5), + ("frequency_penalty", 5.0), + ("presence_penalty", -3.0), + ("n", 0), + ], +) +async def test_tc15_param_validation(client, field, value): + """TC15.1-15.5: out-of-range params return 400.""" + payload = { + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + field: value, + } + resp = await client.post("/v1/chat/completions", json=payload) + assert resp.status_code == 400 + + +@pytest.mark.anyio +async def test_tc15_6_best_of_lt_n_completions(client): + """TC15.6: best_of < n on completions returns 400.""" + resp = await client.post( + "/v1/completions", + json={ + "model": "test-model", + "prompt": "hello", + "n": 3, + "best_of": 1, + }, + ) + assert resp.status_code == 400 + + +@pytest.mark.anyio +async def test_tc15_best_of_lt_n_chat(client): + """best_of < n on chat completions returns 400.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "n": 3, + "best_of": 1, + }, + ) + assert resp.status_code == 400 + + +# === TC16: Embedding base64 === + + +@pytest.mark.anyio +async def test_tc16_1_float_format(client): + """TC16.1: encoding_format=float returns 
list[float].""" + resp = await client.post( + "/v1/embeddings", + json={ + "model": "test-model", + "input": "hello", + "encoding_format": "float", + }, + ) + assert resp.status_code == 200 + emb = resp.json()["data"][0]["embedding"] + assert isinstance(emb, list) + assert all(isinstance(x, float) for x in emb) + + +@pytest.mark.anyio +async def test_tc16_2_base64_format(config, client): + """TC16.2: encoding_format=base64 returns decodable floats.""" + resp = await client.post( + "/v1/embeddings", + json={ + "model": "test-model", + "input": "hello", + "encoding_format": "base64", + }, + ) + assert resp.status_code == 200 + emb = resp.json()["data"][0]["embedding"] + assert isinstance(emb, str) + raw = base64.b64decode(emb) + num_floats = len(raw) // 4 + assert num_floats == config.embedding_dim + values = struct.unpack(f"<{num_floats}f", raw) + assert len(values) == config.embedding_dim + + +# === TC17: vLLM compatibility === + + +@pytest.mark.anyio +async def test_tc17_1_vllm_sampling_params(client): + """TC17.1: vLLM sampling params accepted without error.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "top_k": 50, + "min_p": 0.1, + "repetition_penalty": 1.2, + }, + ) + assert resp.status_code == 200 + + +@pytest.mark.anyio +async def test_tc17_2_vllm_extra_params(client): + """TC17.2: vLLM extra params accepted without error.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "add_generation_prompt": True, + "priority": 1, + "request_id": "test-123", + }, + ) + assert resp.status_code == 200 + + +@pytest.mark.anyio +async def test_tc17_3_stop_reason_field(client): + """TC17.3: response includes stop_reason field.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + }, + ) + assert 
resp.status_code == 200 + choice = resp.json()["choices"][0] + assert "stop_reason" in choice + + +@pytest.mark.anyio +async def test_tc17_4_service_tier_field(client): + """TC17.4: response includes service_tier field.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "service_tier" in data + + +# --- TC14 with scheduling enabled --- + + + +@pytest_asyncio.fixture +async def scheduled_client(): + config = ServerConfig( + mode="dual", + model_name="test-model", + prefill_delay_ms=1, + kv_transfer_delay_ms=0, + decode_delay_per_token_ms=1, + eos_min_ratio=1.0, + max_model_len=4096, + scheduling_enabled=True, + max_num_batched_tokens=4096, + max_num_seqs=64, + ) + app = create_app(config) + if config._scheduler: + await config._scheduler.start() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + if config._scheduler: + await config._scheduler.stop() + + +@pytest.mark.anyio +async def test_tc14_1_json_object_scheduled(scheduled_client): + """TC14.1 with scheduling: json_object returns valid JSON.""" + resp = await scheduled_client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hello"}], + "response_format": {"type": "json_object"}, + }, + ) + assert resp.status_code == 200 + content = resp.json()["choices"][0]["message"]["content"] + parsed = json.loads(content) + assert isinstance(parsed, dict) + + +@pytest.mark.anyio +async def test_tc14_2_json_schema_scheduled(scheduled_client): + """TC14.2 with scheduling: json_schema returns conforming JSON.""" + resp = await scheduled_client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hello"}], + "response_format": { + "type": "json_schema", + "json_schema": { + 
"schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + }, + }, + ) + assert resp.status_code == 200 + content = resp.json()["choices"][0]["message"]["content"] + parsed = json.loads(content) + assert "name" in parsed + assert "age" in parsed + + +@pytest.mark.anyio +async def test_response_format_json_skips_stop_seq(client): + """Stop sequences must not truncate JSON content.""" + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hello"}], + "response_format": {"type": "json_object"}, + "stop": ["}"], + }, + ) + assert resp.status_code == 200 + content = resp.json()["choices"][0]["message"]["content"] + # Content should be valid JSON despite stop=["}"] + parsed = json.loads(content) + assert isinstance(parsed, dict) diff --git a/xpyd_sim/common/models.py b/xpyd_sim/common/models.py index 9419acd..e3dcec0 100644 --- a/xpyd_sim/common/models.py +++ b/xpyd_sim/common/models.py @@ -34,6 +34,10 @@ class ChatCompletionRequest(BaseModel): parallel_tool_calls: Optional[bool] = None stream_options: Optional[dict] = None ignore_eos: Optional[bool] = None + best_of: Optional[int] = None + echo: Optional[bool] = False + + model_config = {"extra": "allow"} class CompletionRequest(BaseModel): @@ -56,6 +60,8 @@ class CompletionRequest(BaseModel): stream_options: Optional[dict] = None ignore_eos: Optional[bool] = None + model_config = {"extra": "allow"} + class UsageInfo(BaseModel): prompt_tokens: int = 0 @@ -85,6 +91,7 @@ class Choice(BaseModel): message: ChoiceMessage = Field(default_factory=ChoiceMessage) finish_reason: Optional[str] = "stop" logprobs: Optional[Any] = None + stop_reason: Optional[Any] = None class ChatCompletionResponse(BaseModel): @@ -95,6 +102,8 @@ class ChatCompletionResponse(BaseModel): choices: list[Choice] = [] usage: UsageInfo = Field(default_factory=UsageInfo) system_fingerprint: Optional[str] = None + 
service_tier: Optional[str] = None + kv_transfer_params: Optional[dict] = None class CompletionChoice(BaseModel): @@ -102,6 +111,7 @@ class CompletionChoice(BaseModel): text: str = "" finish_reason: Optional[str] = "stop" logprobs: Optional[Any] = None + stop_reason: Optional[Any] = None class CompletionResponse(BaseModel): @@ -112,6 +122,8 @@ class CompletionResponse(BaseModel): choices: list[CompletionChoice] = [] usage: UsageInfo = Field(default_factory=UsageInfo) system_fingerprint: Optional[str] = None + service_tier: Optional[str] = None + kv_transfer_params: Optional[dict] = None class DeltaMessage(BaseModel): @@ -125,6 +137,7 @@ class StreamChoice(BaseModel): delta: DeltaMessage = Field(default_factory=DeltaMessage) finish_reason: Optional[str] = None logprobs: Optional[Any] = None + stop_reason: Optional[Any] = None class ChatCompletionChunk(BaseModel): @@ -135,6 +148,7 @@ class ChatCompletionChunk(BaseModel): choices: list[StreamChoice] = [] system_fingerprint: Optional[str] = None usage: Optional[UsageInfo] = None + service_tier: Optional[str] = None class CompletionStreamChoice(BaseModel): @@ -142,6 +156,7 @@ class CompletionStreamChoice(BaseModel): text: str = "" finish_reason: Optional[str] = None logprobs: Optional[Any] = None + stop_reason: Optional[Any] = None class CompletionChunk(BaseModel): @@ -152,12 +167,13 @@ class CompletionChunk(BaseModel): choices: list[CompletionStreamChoice] = [] system_fingerprint: Optional[str] = None usage: Optional[UsageInfo] = None + service_tier: Optional[str] = None class EmbeddingData(BaseModel): object: str = "embedding" index: int = 0 - embedding: list[float] = [] + embedding: list[float] | str = [] class EmbeddingRequest(BaseModel): diff --git a/xpyd_sim/common/tools.py b/xpyd_sim/common/tools.py index 3370b26..171239e 100644 --- a/xpyd_sim/common/tools.py +++ b/xpyd_sim/common/tools.py @@ -6,28 +6,39 @@ from typing import Any -def generate_dummy_args(schema: dict) -> dict: - """Generate dummy arguments 
matching a JSON schema.""" - if not schema or schema.get("type") != "object": +def generate_dummy_from_schema(schema: dict): + """Generate dummy values matching a JSON schema. + + Handles: string, integer, number, boolean, array, object. + Supports enum (picks first value) and items (array elements). + Does not handle $ref / anyOf / oneOf. + """ + if not schema: return {} - props = schema.get("properties", {}) - result = {} - for key, prop in props.items(): - ptype = prop.get("type", "string") - if ptype == "string": - enum = prop.get("enum") - result[key] = enum[0] if enum else "dummy_value" - elif ptype == "integer": - result[key] = 42 - elif ptype == "number": - result[key] = 3.14 - elif ptype == "boolean": - result[key] = True - elif ptype == "array": - result[key] = [] - elif ptype == "object": - result[key] = generate_dummy_args(prop) - return result + schema_type = schema.get("type", "object") + enum = schema.get("enum") + if enum: + return enum[0] + if schema_type == "string": + return "dummy_value" + if schema_type == "integer": + return 42 + if schema_type == "number": + return 3.14 + if schema_type == "boolean": + return True + if schema_type == "array": + items = schema.get("items") + if items: + return [generate_dummy_from_schema(items)] + return [] + if schema_type == "object": + props = schema.get("properties", {}) + result = {} + for key, prop in props.items(): + result[key] = generate_dummy_from_schema(prop) + return result + return None def build_tool_calls( @@ -56,7 +67,7 @@ def build_tool_calls( func = t.get("function", {}) fname = func.get("name", "unknown") params = func.get("parameters", {}) - args = generate_dummy_args(params) + args = generate_dummy_from_schema(params) result.append({ "id": f"call_{uuid.uuid4().hex[:24]}", "type": "function", diff --git a/xpyd_sim/server.py b/xpyd_sim/server.py index bc21be7..3b81bed 100644 --- a/xpyd_sim/server.py +++ b/xpyd_sim/server.py @@ -3,9 +3,11 @@ from __future__ import annotations import asyncio 
+import base64 import json import math import random +import struct import time from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -48,7 +50,11 @@ ToolCallFunction, UsageInfo, ) -from xpyd_sim.common.tools import build_tool_calls, should_generate_tool_calls +from xpyd_sim.common.tools import ( + build_tool_calls, + generate_dummy_from_schema, + should_generate_tool_calls, +) from xpyd_sim.observability import Metrics, RequestLogger, WarmupTracker from xpyd_sim.profile import LatencyProfile from xpyd_sim.scheduler import InferenceRequest, Scheduler, SchedulingConfig @@ -56,6 +62,44 @@ SYSTEM_FINGERPRINT = "fp_xpyd_sim" +def _validate_common_params( + temperature: float | None, + top_p: float | None, + frequency_penalty: float | None, + presence_penalty: float | None, + n: int | None, +) -> str | None: + """Validate common OpenAI parameter ranges. Returns error message or None.""" + if temperature is not None and not (0 <= temperature <= 2): + return f"temperature must be between 0 and 2, got {temperature}" + if top_p is not None and not (0 < top_p <= 1): + return f"top_p must be between 0 (exclusive) and 1, got {top_p}" + if frequency_penalty is not None and not (-2 <= frequency_penalty <= 2): + return f"frequency_penalty must be between -2 and 2, got {frequency_penalty}" + if presence_penalty is not None and not (-2 <= presence_penalty <= 2): + return f"presence_penalty must be between -2 and 2, got {presence_penalty}" + if n is not None and n < 1: + return f"n must be a positive integer, got {n}" + return None + + +def _generate_response_content( + response_format: dict | None, + num_tokens: int, +) -> str | None: + """Generate content based on response_format. 
Returns None if no special format.""" + if not response_format: + return None + fmt_type = response_format.get("type") + if fmt_type == "json_object": + return json.dumps({"result": render_dummy_text(num_tokens)}) + if fmt_type == "json_schema": + json_schema = response_format.get("json_schema", {}) + schema = json_schema.get("schema", {}) + return json.dumps(generate_dummy_from_schema(schema)) + return None + + def _should_use_tool_calls(req: ChatCompletionRequest) -> bool: """Decide once per request whether to generate tool calls.""" tools = getattr(req, "tools", None) @@ -389,7 +433,13 @@ async def embeddings(request: Request): data = [] for i, text in enumerate(inputs): vec = [random.gauss(0, 1) for _ in range(config.embedding_dim)] - data.append(EmbeddingData(index=i, embedding=vec)) + # Support base64 encoding format + if req.encoding_format == "base64": + packed = struct.pack(f'<{len(vec)}f', *vec) + embedding_value = base64.b64encode(packed).decode('ascii') + else: + embedding_value = vec + data.append(EmbeddingData(index=i, embedding=embedding_value)) return EmbeddingResponse( data=data, @@ -437,6 +487,34 @@ async def chat_completions(request: Request): content={"error": {"message": str(e), "type": "invalid_request_error"}}, ) + # Validate parameter ranges (OpenAI spec) + param_err = _validate_common_params( + req.temperature, req.top_p, req.frequency_penalty, + req.presence_penalty, req.n, + ) + if param_err: + return JSONResponse( + status_code=400, + content={"error": {"message": param_err, "type": "invalid_request_error"}}, + ) + + best_of = getattr(req, 'best_of', None) + if best_of is not None: + n_val = req.n or 1 + if best_of < n_val: + return JSONResponse( + status_code=400, + content={ + "error": { + "message": ( + f"best_of must be >= n, got " + f"best_of={best_of} n={n_val}" + ), + "type": "invalid_request_error", + } + }, + ) + prompt_tokens = count_prompt_tokens(messages=req.messages) max_tokens = 
get_effective_max_tokens(req.max_completion_tokens, req.max_tokens) n = req.n or 1 @@ -520,11 +598,21 @@ async def chat_completions(request: Request): num_tokens, finish_reason = _compute_output_length( max_tokens, config.eos_min_ratio, ignore_eos ) - text = render_dummy_text(num_tokens) - text, stopped = _check_stop_sequences(text, req.stop) - if stopped: - finish_reason = "stop" - num_tokens = max(1, len(text.split())) + # Check for response_format (json_object / json_schema) + formatted_content = _generate_response_content( + req.response_format, num_tokens, + ) + if formatted_content is not None: + text = formatted_content + else: + text = render_dummy_text(num_tokens) + if formatted_content is None: + text, stopped = _check_stop_sequences( + text, req.stop, + ) + if stopped: + finish_reason = "stop" + num_tokens = max(1, len(text.split())) total_completion += num_tokens max_choice_tokens = max(max_choice_tokens, num_tokens) lp_data = None @@ -602,6 +690,32 @@ async def completions(request: Request): content={"error": {"message": str(e), "type": "invalid_request_error"}}, ) + # Validate parameter ranges (OpenAI spec) + param_err = _validate_common_params( + req.temperature, req.top_p, req.frequency_penalty, + req.presence_penalty, req.n, + ) + if param_err: + return JSONResponse( + status_code=400, + content={"error": {"message": param_err, "type": "invalid_request_error"}}, + ) + if req.best_of is not None: + n_val = req.n or 1 + if req.best_of < n_val: + return JSONResponse( + status_code=400, + content={ + "error": { + "message": ( + f"best_of must be >= n, got " + f"best_of={req.best_of} n={n_val}" + ), + "type": "invalid_request_error", + } + }, + ) + prompt_tokens = count_prompt_tokens(prompt=req.prompt) max_tokens = get_effective_max_tokens(req.max_tokens) n = req.n or 1 @@ -792,13 +906,23 @@ async def _non_stream_chat_scheduled( ) ) else: - text = render_dummy_text(inf_req.generated_tokens) - text, stopped = _check_stop_sequences(text, req.stop) + 
formatted = _generate_response_content( + req.response_format, + inf_req.generated_tokens, + ) + if formatted is not None: + text = formatted + else: + text = render_dummy_text(inf_req.generated_tokens) finish_reason = inf_req.finish_reason num_tokens = inf_req.generated_tokens - if stopped: - finish_reason = "stop" - num_tokens = max(1, len(text.split())) + if formatted is None: + text, stopped = _check_stop_sequences( + text, req.stop, + ) + if stopped: + finish_reason = "stop" + num_tokens = max(1, len(text.split())) total_completion += num_tokens lp_data = None @@ -954,10 +1078,19 @@ async def _stream_chat_scheduled( token_count += 1 # Build full text, apply stop truncation, then stream - text = render_dummy_text(token_count) + formatted = _generate_response_content( + req.response_format, token_count, + ) + if formatted is not None: + text = formatted + else: + text = render_dummy_text(token_count) finish_reason_override = None - if req.stop: - text, was_stopped = _check_stop_sequences(text, req.stop) + # Skip stop sequences when response_format is JSON + if req.stop and formatted is None: + text, was_stopped = _check_stop_sequences( + text, req.stop, + ) if was_stopped: finish_reason_override = "stop" @@ -965,8 +1098,14 @@ async def _stream_chat_scheduled( for i, token in enumerate(tokens): token_text = (" " + token) if i > 0 else token chunk_lp = None - if req.logprobs and req.top_logprobs and req.top_logprobs > 0: - chunk_lp = generate_chat_logprobs([token_text], req.top_logprobs) + if ( + req.logprobs + and req.top_logprobs + and req.top_logprobs > 0 + ): + chunk_lp = generate_chat_logprobs( + [token_text], req.top_logprobs, + ) chunk = ChatCompletionChunk( id=req_id, created=created, @@ -1336,7 +1475,14 @@ async def _stream_chat( num_tokens, finish_reason = _compute_output_length( max_tokens, config.eos_min_ratio, ignore_eos ) - text = render_dummy_text(num_tokens) + # Check for response_format (json_object / json_schema) + formatted_content = 
_generate_response_content( + req.response_format, num_tokens, + ) + if formatted_content is not None: + text = formatted_content + else: + text = render_dummy_text(num_tokens) # First chunk: role chunk = ChatCompletionChunk( @@ -1356,7 +1502,8 @@ async def _stream_chat( # Apply stop sequence truncation before streaming (ensures identical # output to non-streaming mode). - if req.stop: + # Skip stop sequences when response_format produces JSON content. + if req.stop and formatted_content is None: text, stopped = _check_stop_sequences(text, req.stop) if stopped: finish_reason = "stop"