diff --git a/.gitignore b/.gitignore index b401ea4..3f9af01 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ configs/config.json .tmp/ oc-go-cc routatic-proxy +.trunk diff --git a/CLAUDE.md b/CLAUDE.md index f35858a..372702f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -52,3 +52,22 @@ For streaming, the router downgrades to fast models (Qwen3.6 Plus) for better TT - `internal/transformer/` — Request/response format conversion (Anthropic ↔ OpenAI). - `internal/router/fallback.go` — Circuit breaker per model (3 failures = 30s skip). - `configs/config.example.json` — Reference config with all options documented. + +## Skill routing + +When the user's request matches an available skill, invoke it via the Skill tool. When in doubt, invoke the skill. + +Key routing rules: +- Product ideas/brainstorming → invoke /office-hours +- Strategy/scope → invoke /plan-ceo-review +- Architecture → invoke /plan-eng-review +- Design system/plan review → invoke /design-consultation or /plan-design-review +- Full review pipeline → invoke /autoplan +- Bugs/errors → invoke /investigate +- QA/testing site behavior → invoke /qa or /qa-only +- Code review/diff check → invoke /review +- Visual polish → invoke /design-review +- Ship/deploy/PR → invoke /ship or /land-and-deploy +- Save progress → invoke /context-save +- Resume context → invoke /context-restore +- Author a backlog-ready spec/issue → invoke /spec diff --git a/README.md b/README.md index 0681b44..cea0678 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,11 @@ routatic-proxy --version Show version | [MODELS.md](MODELS.md) | Model capabilities, costs, and routing recommendations | | [CONTRIBUTING.md](CONTRIBUTING.md) | Development setup, architecture, how it works | | [TROUBLESHOOTING.md](TROUBLESHOOTING.md) | Common issues and debug mode | +| [docs/architecture.md](docs/architecture.md) | System design, request flow, module overview | +| [docs/reference-api.md](docs/reference-api.md) | HTTP API reference (endpoints, streaming, errors) | +| [docs/howto-add-model.md](docs/howto-add-model.md) | Adding new models (zero code changes) | +| [docs/howto-custom-routing.md](docs/howto-custom-routing.md) | Customizing scenario detection and model selection | +| [docs/howto-debug-routing.md](docs/howto-debug-routing.md) | Debugging routing issues and common problems | ## License diff --git a/cmd/routatic-proxy/main.go b/cmd/routatic-proxy/main.go index b2fc137..a53e421 100644 --- a/cmd/routatic-proxy/main.go +++ b/cmd/routatic-proxy/main.go @@ -243,11 +243,11 @@ func initCmd() *cobra.Command { return nil } - if err := os.MkdirAll(configDir, 0755); err != nil { + if err := os.MkdirAll(configDir, 0700); err != nil { return fmt.Errorf("failed to create config directory: %w", err) } - if err := os.WriteFile(configPath, []byte(getDefaultConfig()), 0644); err != nil { + if err := os.WriteFile(configPath, []byte(getDefaultConfig()), 0600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..98eb841 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,118 @@ +# Architecture + +routatic-proxy sits between Claude Code and upstream model providers, intercepting Anthropic API requests and routing them to the optimal model. The design prioritizes cost efficiency, reliability, and zero-config operation. + +## Request Flow + +``` +Claude Code + │ + ▼ +┌─────────────────────────────────────────────┐ +│ routatic-proxy │ +│ │ +│ 1. Parse Anthropic MessageRequest │ +│ 2. Count tokens (tiktoken cl100k_base) │ +│ 3. Analyze request facts │ +│ - Has images? New images? │ +│ - Complex patterns? Thinking patterns? │ +│ 4. Detect scenario │ +│ 5. Select model + fallback chain │ +│ 6. Transform to provider format │ +│ 7. Forward to upstream │ +│ 8. Transform response back to Anthropic │ +│ 9. Stream SSE to Claude Code │ +└─────────────────────────────────────────────┘ + │ + ▼ +OpenCode Go / OpenCode Zen +``` + +## Core Modules + +| Module | Purpose | +|--------|---------| +| `internal/handlers` | HTTP request handling, message parsing, orchestration | +| `internal/router` | Scenario detection, model selection, fallback chains | +| `internal/transformer` | Anthropic ↔ OpenAI/Responses/Gemini format conversion | +| `internal/client` | Upstream HTTP client, endpoint classification | +| `internal/config` | JSON config loading, hot reload, atomic access | +| `internal/provider` | Provider-based dispatch (new path, replaces legacy client) | +| `internal/token` | Token counting via tiktoken | +| `internal/daemon` | Background mode, PID management, autostart | + +## Scenario-Based Routing + +The router analyzes each request and assigns it a scenario. Each scenario maps to a primary model and a fallback chain, all config-driven. + +**Detection priority** (highest to lowest): + +1. **Long Context** — token count exceeds `context_threshold` (default 100K) → MiniMax (1M context) +2. **Vision** — request contains images → vision-capable model +3. **Complex** — architectural patterns, tool operations → GLM-5.1 +4. **Think** — reasoning keywords in system prompt → GLM-5 +5. **Background** — simple read-only ops, no tools → Qwen3.5 Plus +6. **Default** → Kimi K2.6 + +**Streaming override**: when `enable_streaming_scenario_routing` is false (default), streaming requests always route to the `fast` model (Qwen3.6 Plus) for better TTFT. + +## Request Transformation + +Claude Code sends Anthropic Messages API format. The proxy transforms to the provider's native format: + +| Provider | Format | Endpoint | +|----------|--------|----------| +| OpenCode Go (most models) | OpenAI Chat Completions | `/v1/chat/completions` | +| OpenCode Go (MiniMax, Qwen) | Anthropic Messages | `/v1/messages` | +| OpenCode Zen (Claude, Qwen) | Anthropic Messages | `/v1/messages` | +| OpenCode Zen (GPT models) | OpenAI Responses | `/v1/responses` | +| OpenCode Zen (Gemini) | Gemini | `/v1/models/{id}` | + +**Key transformation details:** + +- Anthropic `tool_use` ↔ OpenAI `function_calling` bidirectional translation +- Anthropic `thinking` blocks ↔ OpenAI `reasoning_content` preservation +- DeepSeek system message rewriting to prevent prefix cache invalidation +- Image blocks: base64 data URL for vision models, `[Image]` placeholder for non-vision +- `cache_control` stripping for non-DeepSeek models + +## Fallback & Circuit Breaker + +When a model fails, the proxy tries the next model in the chain. The circuit breaker prevents repeated calls to failing models: + +``` +Closed (normal) → 3 failures → Open (skip) → 30s timeout → Half-Open (test) → success → Closed + → failure → Open +``` + +- Only 5xx errors and network failures trigger the circuit breaker +- 4xx errors (bad request, rate limit) skip the breaker — retrying won't help +- Per-model tracking: each model has its own circuit breaker + +## Streaming Architecture + +Streaming uses a per-byte idle watchdog instead of a server-level write timeout: + +1. Server `WriteTimeout` is 0 (disabled) — long SSE streams must not be killed mid-flight +2. Each upstream read uses `http.ResponseController.SetReadDeadline` that resets on every successful byte +3. If no byte arrives within `stream_timeout_ms`, the connection is treated as stuck +4. Heartbeat comments (`:keepalive\n\n`) are sent every 3s to keep the connection alive +5. Client disconnects during stream are logged at Debug level — normal during tool execution + +## Configuration Hot Reload + +When `hot_reload: true`, a filesystem watcher monitors the config file. Changes are applied atomically via `AtomicConfig` — all readers see a consistent snapshot without locks. The HTTP server, model router, and provider registry all read from the same atomic pointer. + +## Polymorphic Field Handling + +Anthropic's `system` and `content` fields accept both strings and arrays. The `pkg/types` package uses `json.RawMessage` with accessor methods (`SystemText()`, `ContentBlocks()`) to handle both formats transparently. + +## Design Decisions + +**Why config-driven routing?** Adding a new model requires zero code changes — just add an entry to `config.json`. The scenario detector, fallback chains, and model metadata are all declarative. + +**Why not use Anthropic format everywhere?** Most upstream models only support OpenAI Chat Completions format. The proxy handles the translation so Claude Code doesn't need to know which provider it's talking to. + +**Why per-read idle timeout instead of WriteTimeout?** Claude Code's tool execution can pause streams for minutes. A server-level timeout would kill active streams. The per-byte watchdog only triggers when the upstream is truly stuck. + +**Why rewrite DeepSeek system messages?** DeepSeek internally reorders all system-role messages to the front. Claude Code injects system reminders mid-conversation, which would shift the message history and invalidate the prefix cache. Wrapping them in `` tags prevents reordering. diff --git a/docs/howto-add-model.md b/docs/howto-add-model.md new file mode 100644 index 0000000..bfb4981 --- /dev/null +++ b/docs/howto-add-model.md @@ -0,0 +1,187 @@ +# How to Add a New Model + +Adding a new model requires zero code changes. Everything is config-driven. + +## Step 1: Identify the Provider and Endpoint + +Determine which upstream provider the model uses and which endpoint format it accepts: + +| Provider | Endpoint | Format | +|----------|----------|--------| +| `opencode-go` | `/v1/chat/completions` | OpenAI Chat Completions (default) | +| `opencode-go` | `/v1/messages` | Anthropic Messages (MiniMax, Qwen) | +| `opencode-zen` | `/v1/chat/completions` | OpenAI Chat Completions | +| `opencode-zen` | `/v1/messages` | Anthropic Messages (Claude, Qwen) | +| `opencode-zen` | `/v1/responses` | OpenAI Responses (GPT models) | +| `opencode-zen` | `/v1/models/{id}` | Gemini | + +## Step 2: Add Model Metadata + +Edit `internal/config/model_registry.go` and add the model to `modelMetadata`: + +```go +"my-new-model": { + ContextWindow: 256000, + MaxOutputTokens: 8192, + Vision: false, + SupportsTools: true, +}, +``` + +This metadata is used by `ResolveModelConfig` to fill in defaults when the model is referenced in config. + +## Step 3: Add Endpoint Classification (Zen only) + +If the model uses Zen, add it to the appropriate classifier in `internal/client/opencode.go`: + +```go +// For Anthropic endpoint: +func isZenAnthropicModel(modelID string) bool { + // ... + if modelID == "my-new-model" { + return true + } + // ... +} + +// For Responses endpoint: +func isResponsesModel(modelID string) bool { + // ... + if modelID == "my-new-model" { + return true + } + // ... +} + +// For Gemini endpoint: +func isGeminiModel(modelID string) bool { + // ... + if modelID == "my-new-model" { + return true + } + // ... +} +``` + +If the model uses Go provider and requires the Anthropic endpoint (not Chat Completions), add it to `IsAnthropicModel`: + +```go +func IsAnthropicModel(modelID string) bool { + switch modelID { + // ... + case "my-new-model": + return true + // ... + } +} +``` + +## Step 4: Add to Config + +Add the model to your `config.json`: + +**As a scenario model:** + +```json +{ + "models": { + "default": { + "provider": "opencode-go", + "model_id": "my-new-model", + "temperature": 0.7, + "max_tokens": 4096 + } + } +} +``` + +**As a model override (for direct requests):** + +```json +{ + "model_overrides": { + "my-new-model": { + "provider": "opencode-go", + "model_id": "my-new-model", + "temperature": 0.7, + "max_tokens": 8192 + } + } +} +``` + +**As a fallback:** + +```json +{ + "fallbacks": { + "default": [ + { "provider": "opencode-go", "model_id": "my-new-model" } + ] + } +} +``` + +## Step 5: Test + +```bash +# Validate config +routatic-proxy validate + +# Test with a request +curl -X POST http://localhost:3456/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "my-new-model", + "max_tokens": 100, + "messages": [{"role": "user", "content": "Hello"}] + }' +``` + +## Model-Specific Considerations + +### Models requiring Anthropic endpoint + +Some models (MiniMax, Qwen on Go provider) only accept Anthropic Messages format, not OpenAI Chat Completions. These need `IsAnthropicModel` to return true. + +### Models with thinking/reasoning + +If the model supports thinking mode (DeepSeek, OpenAI o-series), configure: + +```json +{ + "thinking": { "type": "enabled" }, + "reasoning_effort": "high" +} +``` + +The proxy handles the Anthropic `thinking` ↔ OpenAI `reasoning_content` translation automatically. + +### Models with tool format issues + +If the model doesn't support Anthropic's `type: "custom"` tool shorthands, set: + +```json +{ + "anthropic_tools_disabled": true +} +``` + +This forces the request through the Chat Completions transform path. + +### Models with vision support + +Set `"vision": true` in the model metadata to enable image routing: + +```go +"my-vision-model": { + ContextWindow: 256000, + MaxOutputTokens: 8192, + Vision: true, + SupportsTools: true, +}, +``` + +### Temperature constraints + +Some models have hard temperature requirements (e.g., kimi-k2.7-code requires temperature=1). Add constraints in `constrainTemperature` in `internal/transformer/request.go`. diff --git a/docs/howto-custom-routing.md b/docs/howto-custom-routing.md new file mode 100644 index 0000000..0d3cee6 --- /dev/null +++ b/docs/howto-custom-routing.md @@ -0,0 +1,146 @@ +# How to Customize Model Routing + +routatic-proxy routes requests to different models based on request content. You can customize this behavior through configuration. + +## Understanding Scenarios + +Each request is classified into a scenario, which maps to a model: + +| Scenario | Trigger | Default Model | +|----------|---------|---------------| +| `default` | No special patterns detected | Kimi K2.6 | +| `complex` | Architectural keywords, tool operations | GLM-5.1 | +| `think` | Reasoning keywords in system prompt | GLM-5 | +| `background` | Simple read-only ops (ls, cat, "what is") | Qwen3.5 Plus | +| `long_context` | Token count > threshold (default 100K) | MiniMax M2.5 | +| `vision` | Request contains images | (must configure) | +| `fast` | Streaming requests (when scenario routing disabled) | Qwen3.6 Plus | + +## Override Scenario Models + +Change which model handles each scenario: + +```json +{ + "models": { + "default": { + "provider": "opencode-go", + "model_id": "kimi-k2.6", + "temperature": 0.7, + "max_tokens": 4096 + }, + "complex": { + "provider": "opencode-go", + "model_id": "glm-5.1", + "temperature": 0.7, + "max_tokens": 4096 + } + } +} +``` + +## Add Model Overrides + +Model overrides let specific model names bypass scenario routing: + +```json +{ + "model_overrides": { + "deepseek-v4-pro": { + "provider": "opencode-zen", + "model_id": "deepseek-v4-pro", + "temperature": 0.7, + "max_tokens": 8192, + "reasoning_effort": "max", + "thinking": { "type": "enabled" } + } + } +} +``` + +When Claude Code requests `deepseek-v4-pro`, it goes directly to that model regardless of scenario. + +## Customize Fallback Chains + +Define per-scenario fallback chains: + +```json +{ + "fallbacks": { + "default": [ + { "provider": "opencode-go", "model_id": "mimo-v2-pro" }, + { "provider": "opencode-go", "model_id": "qwen3.6-plus" } + ], + "complex": [ + { "provider": "opencode-go", "model_id": "glm-5" }, + { "provider": "opencode-go", "model_id": "kimi-k2.6" } + ], + "long_context": [ + { "provider": "opencode-go", "model_id": "minimax-m2.7" }, + { "provider": "opencode-go", "model_id": "kimi-k2.6" } + ] + } +} +``` + +If a model in the chain fails (5xx error, timeout), the next model is tried automatically. + +## Adjust Context Threshold + +The long-context threshold determines when the proxy switches to a 1M-context model: + +```json +{ + "models": { + "long_context": { + "provider": "opencode-go", + "model_id": "minimax-m2.5", + "context_threshold": 80000 + } + } +} +``` + +## Enable Streaming Scenario Routing + +By default, streaming requests bypass scenario routing and use the `fast` model. Enable scenario-based routing for streaming: + +```json +{ + "enable_streaming_scenario_routing": true +} +``` + +This is useful for multi-agent and review workflows where streaming requests need capability, not just speed. + +## Disable Requested Model Routing + +By default, the proxy respects the `model` field from Claude Code. Disable this to force scenario routing: + +```json +{ + "respect_requested_model": false} +``` + +## Custom Scenario Detection + +Scenario detection is keyword-based. To add custom patterns, edit `internal/router/scenarios.go`: + +- `hasComplexPattern()` — keywords that trigger the `complex` scenario +- `hasThinkingPattern()` — keywords that trigger the `think` scenario +- `hasBackgroundPattern()` — keywords that trigger the `background` scenario +- `hasVisualIntent()` — keywords that suggest image-related requests + +## Verify Routing + +Check which scenario was selected in the logs: + +``` +INFO routing request scenario=complex model=glm-5.1 provider=opencode-go tokens=1500 +``` + +Or use the validate command to check config: + +```bash +routatic-proxy validate +``` diff --git a/docs/howto-debug-routing.md b/docs/howto-debug-routing.md new file mode 100644 index 0000000..9c62665 --- /dev/null +++ b/docs/howto-debug-routing.md @@ -0,0 +1,117 @@ +# How to Debug Routing Issues + +When requests route to unexpected models or fail, here's how to diagnose. + +## Enable Debug Logging + +Set log level to debug in config: + +```json +{ + "logging": { + "level": "debug" + } +} +``` + +Or temporarily via environment: + +```bash +ROUTATIC_PROXY_LOG_LEVEL=debug routatic-proxy serve +``` + +Debug logs show: +- Request parsing and token counting +- Scenario detection with reasons +- Model selection and fallback attempts +- Upstream request/response details + +## Check Scenario Detection + +The log line `INFO routing request` shows the selected scenario and model: + +``` +INFO routing request scenario=complex model=glm-5.1 provider=opencode-go tokens=1500 +``` + +If the scenario is wrong, check the keyword patterns in `internal/router/scenarios.go`. + +## Check Circuit Breakers + +If a model is being skipped, the circuit breaker may be open: + +``` +INFO circuit breaker open, skipping model model=kimi-k2.6 attempt=2 total=3 +``` + +Circuit breakers open after 3 consecutive failures and recover after 30 seconds. Wait or restart the proxy to reset. + +## Check Model Configuration + +Validate your config: + +```bash +routatic-proxy validate +``` + +Common issues: +- Model ID typo in `models` or `model_overrides` +- Missing provider field (defaults to `opencode-go`) +- Wrong endpoint format for the model + +## Check Upstream Errors + +5xx errors from upstream trigger fallback: + +``` +WARN model failed, trying fallback model=kimi-k2.6 error="API error 502: ..." remaining=2 +``` + +4xx errors skip the circuit breaker (retrying won't help): + +``` +WARN non-retryable error (skipping circuit breaker), trying fallback model=kimi-k2.6 error="API error 400: ..." +``` + +## Check Token Counting + +If requests route to `long_context` unexpectedly, check the token count: + +```bash +curl -X POST http://localhost:3456/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{"system":"...","messages":[...]}' +``` + +## Check Streaming + +Streaming issues show as idle timeouts or client disconnects: + +``` +WARN upstream openai stream idle, trying next model model=qwen3.6-plus idle_timeout=5m0s +``` + +``` +DEBUG client disconnected during stream +``` + +The second is normal during Claude Code tool execution — the client pauses the stream while processing tool results. + +## Common Routing Scenarios + +**Request routes to default instead of complex:** +- Check if the keyword is in `hasComplexPattern()` — it only checks system and user messages +- Check if a tool keyword in `hasBackgroundPattern()` is blocking it + +**Request routes to fast instead of complex (streaming):** +- Streaming routes to `fast` by default when `enable_streaming_scenario_routing` is false +- Enable scenario routing: `"enable_streaming_scenario_routing": true` + +**Request routes to long_context unexpectedly:** +- Check token count — the default threshold is 100K tokens +- Image tokens add ~1500 per image — large images can push over the threshold +- Adjust threshold: `"context_threshold": 80000` in the long_context model config + +**Vision request routes to non-vision model:** +- Check `"vision": true` in the model metadata (`internal/config/model_registry.go`) +- Check that the model is configured in the `vision` scenario in config diff --git a/docs/reference-api.md b/docs/reference-api.md new file mode 100644 index 0000000..6a13dae --- /dev/null +++ b/docs/reference-api.md @@ -0,0 +1,153 @@ +# HTTP API Reference + +routatic-proxy exposes an Anthropic-compatible API. Claude Code connects to it as if it were the Anthropic API. + +## Endpoints + +### `POST /v1/messages` + +The primary endpoint. Accepts Anthropic Messages API requests and returns responses in the same format. + +**Request body** — standard Anthropic `MessageRequest`: + +```json +{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 4096, + "system": "You are a helpful assistant.", + "messages": [ + { + "role": "user", + "content": "Hello, world!" + } + ], + "stream": true, + "tools": [] +} +``` + +**Response** — Anthropic `MessageResponse` (non-streaming) or SSE stream (streaming). + +**Routing behavior:** + +- If `model` matches an entry in `model_overrides`, that model is used as primary with a scenario-derived safety net +- Otherwise, scenario-based routing selects the model based on request content and token count +- Set `respect_requested_model: false` in config to force scenario routing regardless of the `model` field + +**Headers:** + +| Header | Value | +|--------|-------| +| `X-Request-ID` | Unique request identifier (generated or forwarded from client) | +| `Content-Type` | `application/json` (non-streaming) or `text/event-stream` (streaming) | + +### `POST /v1/messages/count_tokens` + +Counts tokens for a message array without generating a response. + +**Request body:** + +```json +{ + "system": "System prompt text", + "messages": [ + { "role": "user", "content": "Hello" } + ] +} +``` + +**Response:** + +```json +{ + "input_tokens": 42 +} +``` + +### `GET /health` + +Returns server health status. + +**Response:** + +```json +{ + "status": "ok", + "version": "1.2.3", + "models_configured": 6, + "uptime": "2h30m" +} +``` + +### `GET /statusline` + +Returns compact status for TUI integration (statusline, tmux bar). + +**Response:** + +```json +{ + "status": "running", + "version": "1.2.3", + "uptime": "2h30m" +} +``` + +## Error Responses + +Errors follow Anthropic's error format: + +```json +{ + "type": "error", + "error": { + "type": "api_error", + "message": "description of what went wrong" + } +} +``` + +**HTTP status codes:** + +| Code | Meaning | +|------|---------| +| 400 | Invalid request body | +| 405 | Method not allowed (non-POST on /v1/messages) | +| 413 | Request body too large (>100MB) | +| 429 | Rate limited | +| 500 | Internal error (routing failed, transform error) | +| 502 | All upstream models failed | + +## Streaming + +Streaming responses use Server-Sent Events (SSE) with Anthropic's event format: + +``` +event: message_start +data: {"type":"message_start","message":{"id":"msg_...","type":"message","role":"assistant","content":[],"model":"...","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":42,"output_tokens":0}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":15}} + +event: message_stop +data: {"type":"message_stop"} +``` + +**Heartbeat**: keepalive comments (`:keepalive\n\n`) are sent every 3 seconds during streaming. + +## Rate Limiting + +The proxy applies per-IP rate limiting (default: 100 requests/minute). Rate-limited requests receive HTTP 429. + +## Request Deduplication + +Optional request deduplication (`request_dedup` in config) prevents processing identical concurrent requests. Deduplicated requests receive HTTP 200 with no body. diff --git a/go.mod b/go.mod index cf4d175..9504095 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/routatic/proxy -go 1.25 +go 1.25.0 require ( github.com/fsnotify/fsnotify v1.10.1 github.com/pkoukk/tiktoken-go v0.1.8 github.com/spf13/cobra v1.8.1 - golang.org/x/sys v0.13.0 + golang.org/x/sys v0.46.0 ) require ( diff --git a/go.sum b/go.sum index d063323..c99c83f 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= +golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/buildinfo/buildinfo.go b/internal/buildinfo/buildinfo.go new file mode 100644 index 0000000..f83074a --- /dev/null +++ b/internal/buildinfo/buildinfo.go @@ -0,0 +1,18 @@ +package buildinfo + +import "os" + +var Version = "dev" +var BuildTime = "unknown" + +func BinaryPath() string { + path, err := os.Executable() + if err != nil { + return "unknown" + } + return path +} + +func PID() int { + return os.Getpid() +} diff --git a/internal/config/config.go b/internal/config/config.go index 983f064..0bde30b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,7 +27,11 @@ type ModelConfig struct { WireFormat string `json:"wire_format,omitempty"` // "auto" (default), "openai", "anthropic", "responses", "gemini" Temperature float64 `json:"temperature"` MaxTokens int `json:"max_tokens"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + ContextWindow int `json:"context_window,omitempty"` + ContextMargin int `json:"context_margin,omitempty"` ContextThreshold int `json:"context_threshold"` + SupportsTools *bool `json:"supports_tools,omitempty"` ReasoningEffort string `json:"reasoning_effort"` Thinking json.RawMessage `json:"thinking,omitempty"` Vision bool `json:"vision"` diff --git a/internal/config/loader.go b/internal/config/loader.go index c8641d7..48534fe 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -232,6 +232,10 @@ func validate(cfg *Config) error { return err } + if err := validateSingleAPIKey(cfg.APIKey); err != nil { + return err + } + if err := validateModelOverrides(cfg.ModelOverrides); err != nil { return err } @@ -240,6 +244,25 @@ func validate(cfg *Config) error { return err } + if err := validateVisionModels(cfg); err != nil { + return err + } + + return nil +} + +// validateVisionModels checks that when a vision scenario is configured, +// the primary model supports vision. Vision scenarios are optional — +// only validate them when they appear in the models map. +func validateVisionModels(cfg *Config) error { + for _, scenario := range []string{"vision", "vision_complex", "vision_long_context"} { + if model, ok := cfg.Models[scenario]; ok && !model.Vision { + resolved := ResolveModelConfig(model) + if !resolved.Vision { + return fmt.Errorf("models[%q] does not support vision but is configured for vision scenario", scenario) + } + } + } return nil } @@ -268,6 +291,16 @@ func validateAnthropicToolsDisabled(cfg *Config) error { // validateAPIKeys ensures no api_keys entries contain unresolved ${VAR} placeholders. // Unresolved placeholders indicate the user did not set the corresponding env vars, // and the literal placeholder string would be sent as a bearer token. +func validateSingleAPIKey(key string) error { + if key == "" { + return nil + } + if envVarPattern.MatchString(key) { + return fmt.Errorf("api_key contains unresolved env var %q — set the corresponding environment variable or use api_keys", key) + } + return nil +} + func validateAPIKeys(keys []string) error { for i, key := range keys { if key == "" { diff --git a/internal/config/model_registry.go b/internal/config/model_registry.go new file mode 100644 index 0000000..8626448 --- /dev/null +++ b/internal/config/model_registry.go @@ -0,0 +1,56 @@ +package config + +const DefaultContextMargin = 8192 + +type ModelMetadata struct { + ContextWindow int + MaxOutputTokens int + Vision bool + SupportsTools bool +} + +var modelMetadata = map[string]ModelMetadata{ + "deepseek-v4-pro": {ContextWindow: 1000000, MaxOutputTokens: 8192, Vision: false, SupportsTools: true}, + "deepseek-v4-flash": {ContextWindow: 1000000, MaxOutputTokens: 4096, Vision: false, SupportsTools: true}, + "kimi-k2.6": {ContextWindow: 256000, MaxOutputTokens: 8192, Vision: true, SupportsTools: true}, + "kimi-k2.5": {ContextWindow: 256000, MaxOutputTokens: 8192, Vision: true, SupportsTools: true}, + "mimo-v2.5-pro": {ContextWindow: 1000000, MaxOutputTokens: 16384, Vision: false, SupportsTools: true}, + "mimo-v2.5": {ContextWindow: 1000000, MaxOutputTokens: 8192, Vision: false, SupportsTools: true}, + "minimax-m2.7": {ContextWindow: 200000, MaxOutputTokens: 8192, Vision: false, SupportsTools: true}, + "minimax-m2.5": {ContextWindow: 200000, MaxOutputTokens: 4096, Vision: false, SupportsTools: true}, + "qwen3.6-plus": {ContextWindow: 1000000, MaxOutputTokens: 8192, Vision: true, SupportsTools: true}, + "qwen3.5-plus": {ContextWindow: 1000000, MaxOutputTokens: 8192, Vision: true, SupportsTools: true}, + "glm-5.1": {ContextWindow: 200000, MaxOutputTokens: 8192, Vision: false, SupportsTools: true}, + "glm-5": {ContextWindow: 200000, MaxOutputTokens: 8192, Vision: false, SupportsTools: true}, +} + +func ResolveModelConfig(model ModelConfig) ModelConfig { + if meta, ok := modelMetadata[model.ModelID]; ok { + if model.ContextWindow == 0 { + model.ContextWindow = meta.ContextWindow + } + if model.MaxOutputTokens == 0 { + model.MaxOutputTokens = meta.MaxOutputTokens + } + if !model.Vision { + model.Vision = meta.Vision + } + if model.SupportsTools == nil { + v := meta.SupportsTools + model.SupportsTools = &v + } + } + if model.ContextMargin == 0 { + model.ContextMargin = DefaultContextMargin + } + if model.SupportsTools == nil { + v := true + model.SupportsTools = &v + } + return model +} + +func SupportsTools(model ModelConfig) bool { + model = ResolveModelConfig(model) + return model.SupportsTools == nil || *model.SupportsTools +} diff --git a/internal/daemon/background.go b/internal/daemon/background.go index d429438..1cd977c 100644 --- a/internal/daemon/background.go +++ b/internal/daemon/background.go @@ -24,7 +24,7 @@ func ForkIntoBackground(opts BackgroundOpts) error { return fmt.Errorf("cannot create config directory: %w", err) } if pid, err := GetPID(paths.PIDFile); err == nil { - if IsProcessRunning(pid) { + if IsProcessRunning(pid) && IsAppProcess(pid, AppName) { return fmt.Errorf("server is already running (PID %d)", pid) } _ = os.Remove(paths.PIDFile) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 8376c25..bde7a9b 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -69,6 +69,12 @@ func TestIsProcessRunning_CurrentProcess(t *testing.T) { } } +func TestIsAppProcess_CurrentTestProcessIsNotOCGoCC(t *testing.T) { + if IsAppProcess(os.Getpid(), AppName) { + t.Errorf("current test process should not be reported as %s", AppName) + } +} + func TestIsProcessRunning_NonexistentPID(t *testing.T) { // PID 1 is typically init — but on some systems it may not exist. // Use an almost-certainly-invalid PID instead. diff --git a/internal/daemon/paths.go b/internal/daemon/paths.go index 6079cb8..7cfda6b 100644 --- a/internal/daemon/paths.go +++ b/internal/daemon/paths.go @@ -7,6 +7,7 @@ import ( "os/exec" "path/filepath" "runtime" + "syscall" ) const ( @@ -71,8 +72,20 @@ func GetPID(pidPath string) (int, error) { } // WritePID writes the given PID to a file. -func WritePID(pidPath string, pid int) error { - return os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", pid)), 0644) +// Uses O_NOFOLLOW to atomically reject symlinks at open time, +// preventing symlink-traversal attacks (CWE-59). +func WritePID(pidPath string, pid int) (err error) { + f, err := os.OpenFile(pidPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC|syscall.O_NOFOLLOW, 0644) + if err != nil { + return fmt.Errorf("refusing to write PID file: %w", err) + } + defer func() { + if cerr := f.Close(); cerr != nil && err == nil { + err = cerr + } + }() + _, err = fmt.Fprintf(f, "%d", pid) + return err } // FindBinary returns the absolute path to the routatic-proxy binary. diff --git a/internal/daemon/process_unix.go b/internal/daemon/process_unix.go index 6f02400..29c48b9 100644 --- a/internal/daemon/process_unix.go +++ b/internal/daemon/process_unix.go @@ -5,6 +5,8 @@ package daemon import ( "fmt" "os" + "path/filepath" + "strings" "syscall" ) @@ -19,6 +21,16 @@ func IsProcessRunning(pid int) bool { return err == nil } +func IsAppProcess(pid int, appName string) bool { + exe, err := os.Readlink(filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")) + if err != nil { + return false + } + base := strings.ToLower(filepath.Base(exe)) + app := strings.ToLower(appName) + return base == app || strings.HasPrefix(base, app+"-") +} + // StopProcess sends SIGTERM to a process and waits for it to exit. func StopProcess(pid int) error { process, err := os.FindProcess(pid) diff --git a/internal/daemon/process_windows.go b/internal/daemon/process_windows.go index 8500a79..18e23b2 100644 --- a/internal/daemon/process_windows.go +++ b/internal/daemon/process_windows.go @@ -5,7 +5,11 @@ package daemon import ( "fmt" "os" + "path/filepath" + "strings" "syscall" + + "golang.org/x/sys/windows" ) const windowsSynchronize = 0x00100000 @@ -22,6 +26,24 @@ func IsProcessRunning(pid int) bool { return err == nil && event == syscall.WAIT_TIMEOUT } +func IsAppProcess(pid int, appName string) bool { + handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pid)) + if err != nil { + return false + } + defer func() { _ = windows.CloseHandle(handle) }() + + buf := make([]uint16, windows.MAX_PATH) + size := uint32(len(buf)) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &size); err != nil { + return false + } + + base := strings.TrimSuffix(strings.ToLower(filepath.Base(windows.UTF16ToString(buf[:size]))), ".exe") + app := strings.TrimSuffix(strings.ToLower(appName), ".exe") + return base == app || strings.HasPrefix(base, app+"-") +} + // StopProcess terminates a process on Windows. // Unlike the Unix implementation which sends SIGTERM for graceful shutdown, // this uses process.Kill() (TerminateProcess) which immediately terminates the diff --git a/internal/handlers/health.go b/internal/handlers/health.go index 7205fac..80a38ec 100644 --- a/internal/handlers/health.go +++ b/internal/handlers/health.go @@ -4,8 +4,10 @@ import ( "encoding/json" "net/http" + "github.com/routatic/proxy/internal/buildinfo" "github.com/routatic/proxy/internal/metrics" "github.com/routatic/proxy/internal/router" + "github.com/routatic/proxy/internal/status" "github.com/routatic/proxy/internal/token" "github.com/routatic/proxy/pkg/types" ) @@ -15,14 +17,16 @@ type HealthHandler struct { tokenCounter *token.Counter fallbackHandler *router.FallbackHandler metrics *metrics.Metrics + statusStore *status.Store } // NewHealthHandler creates a new health handler. -func NewHealthHandler(tokenCounter *token.Counter, fallbackHandler *router.FallbackHandler, metrics *metrics.Metrics) *HealthHandler { +func NewHealthHandler(tokenCounter *token.Counter, fallbackHandler *router.FallbackHandler, metrics *metrics.Metrics, statusStore *status.Store) *HealthHandler { return &HealthHandler{ tokenCounter: tokenCounter, fallbackHandler: fallbackHandler, metrics: metrics, + statusStore: statusStore, } } @@ -38,8 +42,12 @@ func (h *HealthHandler) HandleHealth(w http.ResponseWriter, r *http.Request) { } response := map[string]interface{}{ - "status": "ok", - "service": "routatic-proxy", + "status": "ok", + "service": "routatic-proxy", + "version": buildinfo.Version, + "build_time": buildinfo.BuildTime, + "pid": buildinfo.PID(), + "binary": buildinfo.BinaryPath(), "metrics": map[string]interface{}{ "requests_received": snapshot.RequestsReceived, "requests_success": snapshot.RequestsSuccess, @@ -61,6 +69,21 @@ func (h *HealthHandler) HandleHealth(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(response) } +func (h *HealthHandler) HandleStatusline(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + if h.statusStore == nil { + _ = json.NewEncoder(w).Encode(status.Snapshot{SchemaVersion: 1, Source: "empty", Stale: true}) + return + } + _ = json.NewEncoder(w).Encode(h.statusStore.Snapshot()) +} + // HandleCountTokens handles POST /v1/messages/count_tokens. func (h *HealthHandler) HandleCountTokens(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { diff --git a/internal/handlers/health_test.go b/internal/handlers/health_test.go index efad4e4..99bb4dc 100644 --- a/internal/handlers/health_test.go +++ b/internal/handlers/health_test.go @@ -39,6 +39,27 @@ func TestHandleCountTokensSupportsAnthropicContentBlocks(t *testing.T) { } } +func TestHandleHealthIncludesBuildInfo(t *testing.T) { + handler := newTestHealthHandler(t) + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + + handler.HandleHealth(recorder, req) + + if got, want := recorder.Code, http.StatusOK; got != want { + t.Fatalf("status = %d, want %d; body: %s", got, want, recorder.Body.String()) + } + var response map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &response); err != nil { + t.Fatalf("response is invalid JSON: %v", err) + } + for _, key := range []string{"version", "build_time", "pid", "binary"} { + if _, ok := response[key]; !ok { + t.Fatalf("health response missing %s: %s", key, recorder.Body.String()) + } + } +} + func TestHandleCountTokensIncludesSystemToolsAndThinking(t *testing.T) { handler := newTestHealthHandler(t) @@ -87,5 +108,5 @@ func newTestHealthHandler(t *testing.T) *HealthHandler { if err != nil { t.Fatalf("NewCounter() error = %v", err) } - return NewHealthHandler(counter, nil, metrics.New()) + return NewHealthHandler(counter, nil, metrics.New(), nil) } diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index c39740d..eebd9e2 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -3,6 +3,8 @@ package handlers import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -141,8 +143,12 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) return } - // Generate or get request ID for correlation + // Generate or get request ID for correlation. + // Cap externally-provided IDs at 256 bytes to prevent header abuse. requestID := r.Header.Get("X-Request-ID") + if len(requestID) > 256 { + requestID = requestID[:256] + } if requestID == "" { requestID = h.requestIDGen.Generate() } @@ -215,13 +221,16 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) blocks := msg.ContentBlocks() content := extractTextFromBlocks(blocks) mc := router.MessageContent{ - Role: msg.Role, - Content: content, + Role: msg.Role, + Content: content, + HasImage: blocksHaveImage(blocks), + ImageHashes: imageHashesFromBlocks(blocks), } routerMessages = append(routerMessages, mc) tokenMessages = append(tokenMessages, token.MessageContent{ - Role: msg.Role, - Content: content, + Role: msg.Role, + Content: content, + ExtraTokens: imageTokenEstimate(blocks), }) } @@ -233,7 +242,9 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) } // Route to appropriate model and build fallback chain. - modelChain, routeResult, err := h.buildModelChain(anthropicReq.Model, routerMessages, tokenCount, isStreaming) + facts := router.AnalyzeRequestFacts(routerMessages) + needsTools := len(anthropicReq.Tools) > 0 + modelChain, routeResult, err := h.buildModelChain(anthropicReq.Model, routerMessages, tokenCount, isStreaming, anthropicReq.MaxTokens, facts.NeedsVision, needsTools) if err != nil { h.sendError(w, http.StatusInternalServerError, "routing failed", err) return @@ -269,25 +280,43 @@ func (h *MessagesHandler) buildModelChain( routerMessages []router.MessageContent, tokenCount int, isStreaming bool, + requestedMaxTokens int, + needsVision bool, + needsTools bool, ) ([]config.ModelConfig, router.RouteResult, error) { + var chain []config.ModelConfig + var result router.RouteResult + if requestedModel != "" { if overrideResult, ok := h.modelRouter.RouteWithOverride(requestedModel); ok { scenarioResult, err := h.routeOnce(routerMessages, tokenCount, "", isStreaming) if err != nil { - // Override is valid; surface the scenario routing error rather - // than silently dropping the safety net. return overrideResult.GetModelChain(), overrideResult, err } - chain := appendUniqueModels(overrideResult.GetModelChain(), scenarioResult.GetModelChain()) - return chain, overrideResult, nil + chain = appendUniqueModels(overrideResult.GetModelChain(), scenarioResult.GetModelChain()) + result = overrideResult + } + } + + if chain == nil { + var err error + result, err = h.routeOnce(routerMessages, tokenCount, requestedModel, isStreaming) + if err != nil { + return nil, result, err } + chain = result.GetModelChain() } - result, err := h.routeOnce(routerMessages, tokenCount, requestedModel, isStreaming) + decision, err := router.FilterByCapacity(chain, tokenCount, requestedMaxTokens, needsVision, needsTools) if err != nil { return nil, result, err } - return result.GetModelChain(), result, nil + + for _, s := range decision.Skipped { + h.logger.Info("model skipped by capacity filter", "model", s.ModelID, "reason", s.Reason) + } + + return decision.Models, result, nil } // routeOnce performs scenario-based routing, honoring the streaming-scenario-routing @@ -1059,6 +1088,28 @@ func extractTextFromBlocks(blocks []types.ContentBlock) string { return content } +func blocksHaveImage(blocks []types.ContentBlock) bool { + for _, block := range blocks { + if block.Type == "image" && block.Source != nil { + return true + } + } + return false +} + +func imageHashesFromBlocks(blocks []types.ContentBlock) []string { + var hashes []string + for _, block := range blocks { + if block.Type != "image" || block.Source == nil { + continue + } + source := block.Source.Type + "\x00" + block.Source.MediaType + "\x00" + block.Source.Data + "\x00" + block.Source.URL + sum := sha256.Sum256([]byte(source)) + hashes = append(hashes, hex.EncodeToString(sum[:])) + } + return hashes +} + // sendError sends an error response in Anthropic format. func (h *MessagesHandler) sendError(w http.ResponseWriter, statusCode int, message string, err error) { h.logger.Error("request error", diff --git a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index 74fff39..2e7135c 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -147,7 +147,7 @@ func TestBuildModelChain_NoOverride_UsesScenarioRoute(t *testing.T) { } h := newTestMessagesHandler(t, cfg) - chain, result, err := h.buildModelChain("", nil, 100, false) + chain, result, err := h.buildModelChain("", nil, 100, false, 4096, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -184,7 +184,7 @@ func TestBuildModelChain_Override_AppendsScenarioChainDeduped(t *testing.T) { } h := newTestMessagesHandler(t, cfg) - chain, result, err := h.buildModelChain("kimi-k2.6", nil, 100, false) + chain, result, err := h.buildModelChain("kimi-k2.6", nil, 100, false, 4096, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -227,7 +227,7 @@ func TestBuildModelChain_Override_AppendsUniqueScenarioModels(t *testing.T) { } h := newTestMessagesHandler(t, cfg) - chain, result, err := h.buildModelChain("claude-sonnet-4.5", nil, 100, false) + chain, result, err := h.buildModelChain("claude-sonnet-4.5", nil, 100, false, 4096, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -265,7 +265,7 @@ func TestBuildModelChain_Override_NoMatchingFallbacksKey(t *testing.T) { } h := newTestMessagesHandler(t, cfg) - chain, _, err := h.buildModelChain("claude-sonnet-4.5", nil, 100, false) + chain, _, err := h.buildModelChain("claude-sonnet-4.5", nil, 100, false, 4096, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -299,7 +299,7 @@ func TestBuildModelChain_StreamingFlag_UsesStreamingRoute(t *testing.T) { h := newTestMessagesHandler(t, cfg) // Non-streaming: scenario is default - _, resultNonStream, _ := h.buildModelChain("claude-sonnet-4.5", nil, 100, false) + _, resultNonStream, _ := h.buildModelChain("claude-sonnet-4.5", nil, 100, false, 4096, false, false) if resultNonStream.Scenario != router.ScenarioOverride { t.Errorf("non-streaming scenario = %s, want %s", resultNonStream.Scenario, router.ScenarioOverride) } @@ -307,7 +307,7 @@ func TestBuildModelChain_StreamingFlag_UsesStreamingRoute(t *testing.T) { // Streaming: override still wins, but the safety-net uses fast route. // Chain: [claude-sonnet-4.5 (override), mimo-v2-pro (default fallback), // qwen3.6-plus (fast scenario primary), qwen3.5-plus (fast scenario fallback)] - chain, _, _ := h.buildModelChain("claude-sonnet-4.5", nil, 100, true) + chain, _, _ := h.buildModelChain("claude-sonnet-4.5", nil, 100, true, 4096, false, false) want := []string{"claude-sonnet-4.5", "mimo-v2-pro", "qwen3.6-plus", "qwen3.5-plus"} if got := chainIDs(chain); !equalStrings(got, want) { t.Errorf("streaming chain = %v, want %v (safety-net should use RouteForStreaming)", got, want) @@ -331,7 +331,7 @@ func TestBuildModelChain_UnknownModel_FallsThroughToScenarioRoute(t *testing.T) } h := newTestMessagesHandler(t, cfg) - chain, result, err := h.buildModelChain("completely-unknown", nil, 100, false) + chain, result, err := h.buildModelChain("completely-unknown", nil, 100, false, 4096, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/handlers/token_count.go b/internal/handlers/token_count.go index da6141e..8f2e26c 100644 --- a/internal/handlers/token_count.go +++ b/internal/handlers/token_count.go @@ -12,9 +12,11 @@ import ( func tokenMessagesFromAnthropic(messages []types.Message) []token.MessageContent { tokenMessages := make([]token.MessageContent, 0, len(messages)) for _, msg := range messages { + blocks := msg.ContentBlocks() tokenMessages = append(tokenMessages, token.MessageContent{ - Role: msg.Role, - Content: extractTokenTextFromBlocks(msg.ContentBlocks()), + Role: msg.Role, + Content: extractTokenTextFromBlocks(blocks), + ExtraTokens: imageTokenEstimate(blocks), }) } return tokenMessages @@ -72,3 +74,41 @@ func extractTokenTextFromBlocks(blocks []types.ContentBlock) string { } return content.String() } + +// imageTokenEstimate estimates extra tokens for image blocks. +// For base64 images, derives an estimate from encoded data length (file-size +// heuristic, not pixel-dimension based). For URL images where no data is +// available, returns a default estimate. This is used for routing decisions +// (scenario detection), not billing — Anthropic's actual token cost depends +// on image dimensions after resize. +func imageTokenEstimate(blocks []types.ContentBlock) int { + total := 0 + for _, block := range blocks { + if block.Type != "image" || block.Source == nil { + continue + } + if len(block.Source.Data) > 0 { + total += imageTokenEstimateFromBase64(len(block.Source.Data)) + } else { + total += 1500 + } + } + return total +} + +// imageTokenEstimateFromBase64 estimates token count from base64 image data length. +// Base64 encoding inflates size by ~4/3; raw bytes / 75 ≈ Anthropic image tokens. +func imageTokenEstimateFromBase64(base64Len int) int { + if base64Len == 0 { + return 1500 + } + rawBytes := base64Len * 3 / 4 + tokens := rawBytes / 75 + if tokens < 300 { + return 300 + } + if tokens > 4000 { + return 4000 + } + return tokens +} diff --git a/internal/handlers/token_count_test.go b/internal/handlers/token_count_test.go new file mode 100644 index 0000000..0d4359c --- /dev/null +++ b/internal/handlers/token_count_test.go @@ -0,0 +1,99 @@ +package handlers + +import ( + "testing" + + "github.com/routatic/proxy/pkg/types" +) + +func TestImageTokenEstimateFromBase64_Zero(t *testing.T) { + got := imageTokenEstimateFromBase64(0) + if got != 1500 { + t.Errorf("got %d, want 1500", got) + } +} + +func TestImageTokenEstimateFromBase64_Small(t *testing.T) { + // ~5KB base64 → ~3.7KB raw → ~50 tokens → clamped to 300 + got := imageTokenEstimateFromBase64(5000) + if got != 300 { + t.Errorf("got %d, want 300", got) + } +} + +func TestImageTokenEstimateFromBase64_Medium(t *testing.T) { + // ~150KB base64 → ~112KB raw → ~1500 tokens + got := imageTokenEstimateFromBase64(150000) + if got != 1500 { + t.Errorf("got %d, want 1500", got) + } +} + +func TestImageTokenEstimateFromBase64_Large(t *testing.T) { + // ~400KB base64 → ~300KB raw → ~4000 tokens (at clamp boundary) + got := imageTokenEstimateFromBase64(400000) + if got != 4000 { + t.Errorf("got %d, want 4000", got) + } +} + +func TestImageTokenEstimateFromBase64_Overflow(t *testing.T) { + // ~1MB base64 → clamped to 4000 + got := imageTokenEstimateFromBase64(1000000) + if got != 4000 { + t.Errorf("got %d, want 4000", got) + } +} + +func TestImageTokenEstimate_NoImageBlocks(t *testing.T) { + blocks := []types.ContentBlock{ + {Type: "text", Text: "hello"}, + {Type: "tool_use", Name: "test"}, + } + got := imageTokenEstimate(blocks) + if got != 0 { + t.Errorf("got %d, want 0", got) + } +} + +func TestImageTokenEstimate_Base64Image(t *testing.T) { + blocks := []types.ContentBlock{ + {Type: "text", Text: "hello"}, + {Type: "image", Source: &types.ImageSource{Data: "AAAA", Type: "base64", MediaType: "image/png"}}, + } + got := imageTokenEstimate(blocks) + if got != 300 { + t.Errorf("got %d, want 300", got) + } +} + +func TestImageTokenEstimate_URLImage(t *testing.T) { + blocks := []types.ContentBlock{ + {Type: "image", Source: &types.ImageSource{URL: "https://example.com/img.png"}}, + } + got := imageTokenEstimate(blocks) + if got != 1500 { + t.Errorf("got %d, want 1500", got) + } +} + +func TestImageTokenEstimate_MultipleImages(t *testing.T) { + blocks := []types.ContentBlock{ + {Type: "image", Source: &types.ImageSource{URL: "https://example.com/a.png"}}, + {Type: "image", Source: &types.ImageSource{URL: "https://example.com/b.png"}}, + } + got := imageTokenEstimate(blocks) + if got != 3000 { + t.Errorf("got %d, want 3000", got) + } +} + +func TestImageTokenEstimate_NilSource(t *testing.T) { + blocks := []types.ContentBlock{ + {Type: "image"}, + } + got := imageTokenEstimate(blocks) + if got != 0 { + t.Errorf("got %d, want 0", got) + } +} diff --git a/internal/router/capacity.go b/internal/router/capacity.go new file mode 100644 index 0000000..4492082 --- /dev/null +++ b/internal/router/capacity.go @@ -0,0 +1,89 @@ +package router + +import ( + "fmt" + + "github.com/routatic/proxy/internal/config" +) + +const minimumOutputTokens = 256 + +type SkippedModel struct { + ModelID string `json:"model_id"` + Reason string `json:"reason"` +} + +type CapacityDecision struct { + Models []config.ModelConfig + Skipped []SkippedModel + InputTokens int + RequestedMaxTokens int + SelectedMaxTokens int + ContextWindow int + ContextMargin int + NeedsVision bool + NeedsTools bool +} + +func FilterByCapacity(chain []config.ModelConfig, inputTokens int, requestedMaxTokens int, needsVision bool, needsTools bool) (CapacityDecision, error) { + decision := CapacityDecision{ + InputTokens: inputTokens, + RequestedMaxTokens: requestedMaxTokens, + NeedsVision: needsVision, + NeedsTools: needsTools, + } + + for _, raw := range chain { + model := config.ResolveModelConfig(raw) + if needsVision && !model.Vision { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "vision_not_supported"}) + continue + } + if needsTools && !config.SupportsTools(model) { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "tools_not_supported"}) + continue + } + + sentMax := clampOutputTokens(model, inputTokens, requestedMaxTokens) + if sentMax < minimumOutputTokens { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "context_window_exceeded"}) + continue + } + model.MaxTokens = sentMax + if len(decision.Models) == 0 { + decision.SelectedMaxTokens = sentMax + decision.ContextWindow = model.ContextWindow + decision.ContextMargin = model.ContextMargin + } + decision.Models = append(decision.Models, model) + } + + if len(decision.Models) == 0 { + return decision, fmt.Errorf("no eligible model for request capacity") + } + return decision, nil +} + +func clampOutputTokens(model config.ModelConfig, inputTokens int, requestedMaxTokens int) int { + if inputTokens < 0 { + inputTokens = 0 + } + limit := model.MaxTokens + if requestedMaxTokens > 0 && (limit == 0 || requestedMaxTokens < limit) { + limit = requestedMaxTokens + } + if model.MaxOutputTokens > 0 && (limit == 0 || model.MaxOutputTokens < limit) { + limit = model.MaxOutputTokens + } + if model.ContextWindow <= 0 { + return limit + } + remaining := model.ContextWindow - inputTokens - model.ContextMargin + if limit == 0 || remaining < limit { + if remaining < 0 { + return 0 + } + limit = remaining + } + return limit +} diff --git a/internal/router/capacity_test.go b/internal/router/capacity_test.go new file mode 100644 index 0000000..6dca402 --- /dev/null +++ b/internal/router/capacity_test.go @@ -0,0 +1,56 @@ +package router + +import ( + "testing" + + "github.com/routatic/proxy/internal/config" +) + +func TestFilterByCapacitySkipsPrimaryAndUsesEligibleFallback(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "glm-5.1", MaxTokens: 8192}, + {Provider: "opencode-go", ModelID: "deepseek-v4-pro", MaxTokens: 8192}, + } + + decision, err := FilterByCapacity(chain, 250000, 8192, false, false) + if err != nil { + t.Fatalf("FilterByCapacity() error = %v", err) + } + if got, want := decision.Models[0].ModelID, "deepseek-v4-pro"; got != want { + t.Fatalf("selected model = %s, want %s", got, want) + } + if len(decision.Skipped) != 1 || decision.Skipped[0].Reason != "context_window_exceeded" { + t.Fatalf("skipped = %+v, want context skip", decision.Skipped) + } +} + +func TestFilterByCapacityRejectsVisionFallbackToTextModel(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "deepseek-v4-pro", MaxTokens: 8192}, + } + + decision, err := FilterByCapacity(chain, 1000, 8192, true, false) + if err == nil { + t.Fatal("FilterByCapacity() error = nil, want error") + } + if len(decision.Models) != 0 { + t.Fatalf("eligible models = %+v, want none", decision.Models) + } + if len(decision.Skipped) != 1 || decision.Skipped[0].Reason != "vision_not_supported" { + t.Fatalf("skipped = %+v, want vision skip", decision.Skipped) + } +} + +func TestFilterByCapacityClampsMaxTokens(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6", MaxTokens: 16384}, + } + + decision, err := FilterByCapacity(chain, 240000, 16384, true, false) + if err != nil { + t.Fatalf("FilterByCapacity() error = %v", err) + } + if got, want := decision.Models[0].MaxTokens, 256000-240000-config.DefaultContextMargin; got != want { + t.Fatalf("max_tokens = %d, want %d", got, want) + } +} diff --git a/internal/router/model_router.go b/internal/router/model_router.go index de81f9c..a5122d7 100644 --- a/internal/router/model_router.go +++ b/internal/router/model_router.go @@ -38,9 +38,9 @@ type RouteResult struct { // resolveRequestedModel checks if the user-specified model should override // scenario-based routing. Returns the route result and true if it matched, // or zero value and false if scenario routing should proceed normally. -func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel string) (RouteResult, bool) { +func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel string, needsVision bool) (RouteResult, bool, error) { if !isRespectRequestedModel(cfg) || requestedModel == "" { - return RouteResult{}, false + return RouteResult{}, false, nil } // Look up the requested model in config to inherit its settings @@ -56,6 +56,10 @@ func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel s primary.MaxTokens = def.MaxTokens } } + primary = config.ResolveModelConfig(primary) + if needsVision && !primary.Vision { + return RouteResult{}, false, fmt.Errorf("requested model %s does not support vision", primary.ModelID) + } fallbacks := cfg.Fallbacks["default"] @@ -63,15 +67,18 @@ func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel s Primary: primary, Fallbacks: fallbacks, Scenario: ScenarioDefault, - }, true + }, true, nil } // Route determines which model to use for a request. // If respect_requested_model is enabled and requestedModel is provided, it overrides scenario-based routing. func (r *ModelRouter) Route(messages []MessageContent, tokenCount int, requestedModel string) (RouteResult, error) { cfg := r.atomic.Get() + facts := AnalyzeRequestFacts(messages) - if result, ok := r.resolveRequestedModel(cfg, requestedModel); ok { + if result, ok, err := r.resolveRequestedModel(cfg, requestedModel, facts.NeedsVision); err != nil { + return RouteResult{}, err + } else if ok { return result, nil } @@ -81,6 +88,9 @@ func (r *ModelRouter) Route(messages []MessageContent, tokenCount int, requested // Get primary model for scenario primary, ok := cfg.Models[string(result.Scenario)] if !ok { + if isVisionScenario(result.Scenario) { + return RouteResult{}, fmt.Errorf("vision scenario %s is not configured", result.Scenario) + } // Fall back to default if scenario model not configured primary, ok = cfg.Models["default"] if !ok { @@ -91,6 +101,9 @@ func (r *ModelRouter) Route(messages []MessageContent, tokenCount int, requested // Get fallbacks for scenario fallbacks := cfg.Fallbacks[string(result.Scenario)] if len(fallbacks) == 0 { + if isVisionScenario(result.Scenario) { + return RouteResult{}, fmt.Errorf("vision scenario %s has no configured vision fallbacks", result.Scenario) + } // Fall back to default fallbacks fallbacks = cfg.Fallbacks["default"] } @@ -151,7 +164,7 @@ func (rr *RouteResult) GetModelChain() []config.ModelConfig { func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount int, requestedModel string) (RouteResult, error) { cfg := r.atomic.Get() - if result, ok := r.resolveRequestedModel(cfg, requestedModel); ok { + if result, ok, err := r.resolveRequestedModel(cfg, requestedModel, false); err == nil && ok { return result, nil } @@ -161,6 +174,9 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in // Get primary model for scenario primary, ok := cfg.Models[string(result.Scenario)] if !ok { + if isVisionScenario(result.Scenario) { + return RouteResult{Scenario: result.Scenario}, fmt.Errorf("vision scenario %s is not configured", result.Scenario) + } // Fall back to fast scenario if not configured primary, ok = cfg.Models["fast"] if !ok { @@ -175,8 +191,12 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in // Get fallbacks for scenario fallbacks := cfg.Fallbacks[string(result.Scenario)] if len(fallbacks) == 0 { - // Fall back to fast fallbacks - fallbacks = cfg.Fallbacks["fast"] + if isVisionScenario(result.Scenario) { + fallbacks = nil + } else { + // Fall back to fast fallbacks + fallbacks = cfg.Fallbacks["fast"] + } } return RouteResult{ @@ -185,3 +205,7 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in Scenario: result.Scenario, }, nil } + +func isVisionScenario(s Scenario) bool { + return s == ScenarioVision || s == ScenarioVisionComplex || s == ScenarioVisionLongContext +} diff --git a/internal/router/model_router_test.go b/internal/router/model_router_test.go index 18bb170..fbfa981 100644 --- a/internal/router/model_router_test.go +++ b/internal/router/model_router_test.go @@ -229,7 +229,10 @@ func TestResolveRequestedModel_UsesFallbacks(t *testing.T) { router := NewModelRouter(newTestAtomicConfig(cfg)) - result, ok := router.resolveRequestedModel(cfg, "kimi-k2.6") + result, ok, err := router.resolveRequestedModel(cfg, "kimi-k2.6", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !ok { t.Fatal("expected resolveRequestedModel to match") } diff --git a/internal/router/scenarios.go b/internal/router/scenarios.go index 987f7f7..52beb78 100644 --- a/internal/router/scenarios.go +++ b/internal/router/scenarios.go @@ -11,13 +11,16 @@ import ( type Scenario string const ( - ScenarioDefault Scenario = "default" - ScenarioBackground Scenario = "background" - ScenarioThink Scenario = "think" - ScenarioComplex Scenario = "complex" - ScenarioLongContext Scenario = "long_context" - ScenarioFast Scenario = "fast" - ScenarioOverride Scenario = "override" + ScenarioDefault Scenario = "default" + ScenarioBackground Scenario = "background" + ScenarioThink Scenario = "think" + ScenarioComplex Scenario = "complex" + ScenarioLongContext Scenario = "long_context" + ScenarioFast Scenario = "fast" + ScenarioOverride Scenario = "override" + ScenarioVision Scenario = "vision" + ScenarioVisionComplex Scenario = "vision_complex" + ScenarioVisionLongContext Scenario = "vision_long_context" ) // ScenarioResult contains the detected scenario and token count. @@ -29,8 +32,19 @@ type ScenarioResult struct { // MessageContent represents a single message in a conversation. type MessageContent struct { - Role string - Content string + Role string + Content string + HasImage bool + ImageHashes []string +} + +type RequestFacts struct { + LatestUserText string + LatestUserHasImage bool + AnyHistoricalImage bool + LatestTextVisualIntent bool + LatestTextComplexIntent bool + NeedsVision bool } // DetectScenario analyzes a request to determine which model to use. @@ -43,9 +57,17 @@ type MessageContent struct { // // For streaming requests, consider using RouteForStreaming() to prefer faster models. func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Config) ScenarioResult { + facts := AnalyzeRequestFacts(messages) // 1. Check for long context first (most important) threshold := getLongContextThreshold(cfg) if tokenCount > threshold { + if facts.NeedsVision { + return ScenarioResult{ + Scenario: ScenarioVisionLongContext, + TokenCount: tokenCount, + Reason: fmt.Sprintf("image request token count %d exceeds threshold %d", tokenCount, threshold), + } + } return ScenarioResult{ Scenario: ScenarioLongContext, TokenCount: tokenCount, @@ -53,8 +75,24 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } } + if facts.NeedsVision { + if facts.LatestTextComplexIntent { + return ScenarioResult{ + Scenario: ScenarioVisionComplex, + TokenCount: tokenCount, + Reason: "complex image request detected", + } + } + return ScenarioResult{ + Scenario: ScenarioVision, + TokenCount: tokenCount, + Reason: "simple image request detected", + } + } + // 2. Check for complex tasks (architectural OR tool-related) - if hasComplexPattern(messages) { + latestUser := latestUserMessages(messages) + if hasComplexPattern(latestUser) { return ScenarioResult{ Scenario: ScenarioComplex, TokenCount: tokenCount, @@ -63,7 +101,7 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } // 3. Check for thinking/reasoning patterns - if hasThinkingPattern(messages) { + if hasThinkingPattern(latestUser) { return ScenarioResult{ Scenario: ScenarioThink, TokenCount: tokenCount, @@ -88,6 +126,78 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } } +func AnalyzeRequestFacts(messages []MessageContent) RequestFacts { + facts := RequestFacts{} + latestIdx := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + latestIdx = i + break + } + } + if latestIdx == -1 { + return facts + } + + latest := messages[latestIdx] + facts.LatestUserText = latest.Content + facts.LatestUserHasImage = latest.HasImage && imageHashesAreNewForLatest(messages, latestIdx) + facts.LatestTextVisualIntent = hasVisualIntent(latest.Content) + facts.LatestTextComplexIntent = hasComplexPattern([]MessageContent{latest}) || hasThinkingPattern([]MessageContent{latest}) + + for i, msg := range messages { + if i != latestIdx && msg.Role == "user" && msg.HasImage { + facts.AnyHistoricalImage = true + break + } + } + + facts.NeedsVision = facts.LatestUserHasImage || (facts.AnyHistoricalImage && facts.LatestTextVisualIntent) + return facts +} + +func imageHashesAreNewForLatest(messages []MessageContent, latestIdx int) bool { + latest := messages[latestIdx] + if len(latest.ImageHashes) == 0 { + return latest.HasImage + } + seen := map[string]bool{} + for i := 0; i < latestIdx; i++ { + for _, hash := range messages[i].ImageHashes { + seen[hash] = true + } + } + for _, hash := range latest.ImageHashes { + if !seen[hash] { + return true + } + } + return false +} + +func latestUserMessages(messages []MessageContent) []MessageContent { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return []MessageContent{messages[i]} + } + } + return nil +} + +func hasVisualIntent(content string) bool { + visualKeywords := []string{ + "image", "screenshot", "screen", "schermata", "immagine", "foto", + "allegato", "[image", "vedi", "visual", "ui", "layout", + } + lower := strings.ToLower(content) + for _, kw := range visualKeywords { + if strings.Contains(lower, kw) { + return true + } + } + return false +} + // hasComplexPattern looks for complex operations that need more capable models. // This includes tool-based operations (executing functions, writing/editing files, etc.) func hasComplexPattern(messages []MessageContent) bool { @@ -97,6 +207,7 @@ func hasComplexPattern(messages []MessageContent) bool { "complex", "difficult", "challenging", "optimize", "performance", "efficiency", "design pattern", "best practice", + "bug", "debug", "error", "exception", "stack trace", // Tool-related keywords indicate complex operations "execute", "run command", "bash", "shell", "implement", "build", "create", "add feature", @@ -196,11 +307,19 @@ func getLongContextThreshold(cfg *config.Config) int { // For streaming, we prioritize fast TTFT (time-to-first-token) over capability. // This may return a less capable model but one that streams faster. func RouteForStreaming(messages []MessageContent, tokenCount int, cfg *config.Config) ScenarioResult { + facts := AnalyzeRequestFacts(messages) // For streaming, use simpler models that have better TTFT // Complex models (GLM, Kimi) are too slow for streaming with many tools threshold := getLongContextThreshold(cfg) if tokenCount > threshold { + if facts.NeedsVision { + return ScenarioResult{ + Scenario: ScenarioVisionLongContext, + TokenCount: tokenCount, + Reason: fmt.Sprintf("high token count image request (%d > %d)", tokenCount, threshold), + } + } model := "long_context" if cfg != nil { if lc, ok := cfg.Models["long_context"]; ok && lc.ModelID != "" { @@ -214,7 +333,23 @@ func RouteForStreaming(messages []MessageContent, tokenCount int, cfg *config.Co } } - if hasComplexPattern(messages) || hasThinkingPattern(messages) { + if facts.NeedsVision { + if facts.LatestTextComplexIntent { + return ScenarioResult{ + Scenario: ScenarioVisionComplex, + TokenCount: tokenCount, + Reason: "complex image request detected", + } + } + return ScenarioResult{ + Scenario: ScenarioVision, + TokenCount: tokenCount, + Reason: "simple image request detected", + } + } + + latestUser := latestUserMessages(messages) + if hasComplexPattern(latestUser) || hasThinkingPattern(latestUser) { // Complex request but streaming - downgrade to faster model // GLM-5 and Kimi are too slow for streaming with complex prompts return ScenarioResult{ diff --git a/internal/router/scenarios_test.go b/internal/router/scenarios_test.go index 2edb6b2..e76f4f7 100644 --- a/internal/router/scenarios_test.go +++ b/internal/router/scenarios_test.go @@ -113,6 +113,118 @@ func TestDetectScenario_LongContextTakesPriority(t *testing.T) { } } +func TestDetectScenario_VisionSimpleRequest(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Describe this screen", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionComplexRequest(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and find the bug", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVisionComplex { + t.Errorf("Expected ScenarioVisionComplex, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionUsesLatestImageRequestComplexity(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this previous architecture"}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Cosa vedi?", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_ReturnsToTextRoutingAfterImageTurn(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and find the bug", HasImage: true}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Refactor this code"}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioComplex { + t.Errorf("Expected ScenarioComplex, got %s", result.Scenario) + } +} + +func TestDetectScenario_ReturnsToTextRoutingWhenLatestTurnHasHistoricalImageOnly(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "ci sei?", HasImage: true, ImageHashes: []string{"img1"}}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioDefault { + t.Errorf("Expected ScenarioDefault, got %s", result.Scenario) + } +} + +func TestDetectScenario_UsesHistoricalImageWhenLatestTextHasVisualIntent(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "cosa vedi nello screenshot?", HasImage: true, ImageHashes: []string{"img1"}}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_DebugWithoutVisualIntentStaysTextComplex(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "debug questo codice", HasImage: false}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioComplex { + t.Errorf("Expected ScenarioComplex, got %s", result.Scenario) + } +} + +func TestRouteForStreaming_ReturnsToFastAfterImageTurn(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Hello"}, + } + result := RouteForStreaming(messages, 100, mockConfig()) + if result.Scenario != ScenarioFast { + t.Errorf("Expected ScenarioFast, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionLongContextTakesPriorityOverVisionComplex(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and refactor the code", HasImage: true}, + } + result := DetectScenario(messages, 70000, mockConfig()) + if result.Scenario != ScenarioVisionLongContext { + t.Errorf("Expected ScenarioVisionLongContext, got %s", result.Scenario) + } +} + +func TestRouteForStreaming_VisionComplexKeepsVisionComplexScenario(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Find the bug in this screenshot", HasImage: true}, + } + result := RouteForStreaming(messages, 100, mockConfig()) + if result.Scenario != ScenarioVisionComplex { + t.Errorf("Expected ScenarioVisionComplex, got %s", result.Scenario) + } +} + func TestRouteForStreaming_RespectsConfiguredThreshold(t *testing.T) { messages := []MessageContent{ {Role: "user", Content: "Hello"}, diff --git a/internal/server/server.go b/internal/server/server.go index 2b96b53..f4cc039 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -18,6 +18,7 @@ import ( "github.com/routatic/proxy/internal/metrics" "github.com/routatic/proxy/internal/provider" "github.com/routatic/proxy/internal/router" + "github.com/routatic/proxy/internal/status" "github.com/routatic/proxy/internal/token" ) @@ -58,6 +59,9 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { _ = providerRegistry.Register(provider.NewOpenCodeGoProvider(atomic)) _ = providerRegistry.Register(provider.NewOpenCodeZenProvider(atomic)) + // Create status store for the statusline endpoint. + statusStore := status.NewStore(0) + // Create handlers. messagesHandler := handlers.NewMessagesHandler( openCodeClient, @@ -67,7 +71,7 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { tokenCounter, metrics, ) - healthHandler := handlers.NewHealthHandler(tokenCounter, fallbackHandler, metrics) + healthHandler := handlers.NewHealthHandler(tokenCounter, fallbackHandler, metrics, statusStore) // Setup router. mux := http.NewServeMux() @@ -76,6 +80,7 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { mux.HandleFunc("/v1/messages", messagesHandler.HandleMessages) mux.HandleFunc("/v1/messages/count_tokens", healthHandler.HandleCountTokens) mux.HandleFunc("/health", healthHandler.HandleHealth) + mux.HandleFunc("/statusline", healthHandler.HandleStatusline) // Create HTTP server. addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) diff --git a/internal/status/store.go b/internal/status/store.go new file mode 100644 index 0000000..e58c174 --- /dev/null +++ b/internal/status/store.go @@ -0,0 +1,114 @@ +package status + +import ( + "sync" + "time" + + "github.com/routatic/proxy/internal/buildinfo" + "github.com/routatic/proxy/internal/router" +) + +type Snapshot struct { + SchemaVersion int `json:"schema_version"` + UpdatedAt string `json:"updated_at"` + AgeMS int64 `json:"age_ms"` + Source string `json:"source"` + Stale bool `json:"stale"` + Proxy ProxySnapshot `json:"proxy"` + Request RequestSnapshot `json:"request"` + Routing RoutingSnapshot `json:"routing"` + Context ContextSnapshot `json:"context"` + Models ModelsSnapshot `json:"models"` +} + +type ProxySnapshot struct { + Version string `json:"version"` + PID int `json:"pid"` + Binary string `json:"binary"` +} + +type RequestSnapshot struct { + RequestID string `json:"request_id"` + Streaming bool `json:"streaming"` +} + +type RoutingSnapshot struct { + Scenario string `json:"scenario"` + ModelID string `json:"model_id"` +} + +type ContextSnapshot struct { + InputTokens int `json:"input_tokens"` + MaxTokens int `json:"max_tokens"` + Percent int `json:"pct"` +} + +type ModelsSnapshot struct { + SkippedFallbacks []router.SkippedModel `json:"skipped_fallbacks,omitempty"` +} + +type Store struct { + mu sync.RWMutex + seq uint64 + updated time.Time + snapshot Snapshot + ttl time.Duration +} + +func NewStore(ttl time.Duration) *Store { + if ttl <= 0 { + ttl = 10 * time.Second + } + return &Store{ttl: ttl} +} + +func (s *Store) Update(seq uint64, snap Snapshot) { + s.mu.Lock() + defer s.mu.Unlock() + if seq < s.seq { + return + } + now := time.Now().UTC() + s.seq = seq + s.updated = now + snap.SchemaVersion = 1 + snap.UpdatedAt = now.Format(time.RFC3339Nano) + snap.AgeMS = 0 + snap.Source = "proxy" + snap.Proxy = ProxySnapshot{ + Version: buildinfo.Version, + PID: buildinfo.PID(), + Binary: buildinfo.BinaryPath(), + } + s.snapshot = snap +} + +func (s *Store) Snapshot() Snapshot { + s.mu.RLock() + defer s.mu.RUnlock() + snap := s.snapshot + + // Deep-copy the SkippedFallbacks slice to avoid sharing the backing + // array with a concurrent Update(). + if len(snap.Models.SkippedFallbacks) > 0 { + skipped := make([]router.SkippedModel, len(snap.Models.SkippedFallbacks)) + copy(skipped, snap.Models.SkippedFallbacks) + snap.Models.SkippedFallbacks = skipped + } + + if s.updated.IsZero() { + snap.SchemaVersion = 1 + snap.Source = "empty" + snap.Stale = true + snap.Proxy = ProxySnapshot{ + Version: buildinfo.Version, + PID: buildinfo.PID(), + Binary: buildinfo.BinaryPath(), + } + return snap + } + age := time.Since(s.updated) + snap.AgeMS = age.Milliseconds() + snap.Stale = age > s.ttl + return snap +} diff --git a/internal/token/counter.go b/internal/token/counter.go index db90472..4c49900 100644 --- a/internal/token/counter.go +++ b/internal/token/counter.go @@ -52,8 +52,9 @@ func (c *Counter) CountTokens(text string) (int, error) { // MessageContent represents a single message in a conversation. type MessageContent struct { - Role string - Content string + Role string + Content string + ExtraTokens int } // CountMessages counts tokens in a message array. @@ -76,6 +77,7 @@ func (c *Counter) CountMessages(system string, messages []MessageContent) (int, return 0, err } total += msgTokens + 5 // Per-message overhead + total += msg.ExtraTokens } return total, nil diff --git a/internal/token/counter_test.go b/internal/token/counter_test.go index cce9337..609c2ca 100644 --- a/internal/token/counter_test.go +++ b/internal/token/counter_test.go @@ -3,6 +3,7 @@ package token import ( "os" "path/filepath" + "runtime" "testing" ) @@ -52,6 +53,9 @@ func TestDefaultCacheDir(t *testing.T) { } func TestDefaultCacheDir_HomeDirFallback(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows UserHomeDir does not depend only on HOME") + } // When UserHomeDir fails (HOME unset), fall back to temp dir. t.Setenv("TIKTOKEN_CACHE_DIR", "") t.Setenv("DATA_GYM_CACHE_DIR", "") diff --git a/internal/transformer/response_test.go b/internal/transformer/response_test.go index ae629fb..acbe5f2 100644 --- a/internal/transformer/response_test.go +++ b/internal/transformer/response_test.go @@ -1,6 +1,7 @@ package transformer import ( + "encoding/json" "testing" "github.com/routatic/proxy/pkg/types" @@ -202,6 +203,43 @@ func TestTransformResponseNoReasoningContent(t *testing.T) { } } +func TestTransformResponseExtractsTextFromContentParts(t *testing.T) { + transformer := NewResponseTransformer() + + resp := &types.ChatCompletionResponse{ + ID: "chatcmpl_parts", + Object: "chat.completion", + Created: 1234567890, + Model: "qwen3.6-plus", + Choices: []types.Choice{ + { + Index: 0, + Message: types.ChatMessage{ + Role: "assistant", + Content: json.RawMessage(`[{"type":"text","text":"Vedo uno screenshot."}]`), + }, + FinishReason: "stop", + }, + }, + Usage: types.UsageInfo{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + } + + anthropicResp, err := transformer.TransformResponse(resp, "qwen3.6-plus") + if err != nil { + t.Fatalf("TransformResponse() error = %v", err) + } + if got, want := len(anthropicResp.Content), 1; got != want { + t.Fatalf("len(Content) = %d, want %d", got, want) + } + if got, want := anthropicResp.Content[0].Text, "Vedo uno screenshot."; got != want { + t.Fatalf("Content[0].Text = %q, want %q", got, want) + } +} + func TestTransformResponseWithCacheTokens(t *testing.T) { transformer := NewResponseTransformer() diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go index 9f756ba..b581459 100644 --- a/internal/transformer/stream.go +++ b/internal/transformer/stream.go @@ -53,6 +53,111 @@ func NewStreamHandler() *StreamHandler { } } +// EmitMessageResponse synthesizes an Anthropic-format SSE stream from a non-streaming +// MessageResponse. This is used for vision scenarios where the upstream model does not +// support streaming — the proxy fetches the full response, then emits it as SSE events +// so the client's streaming contract is preserved. +func (h *StreamHandler) EmitMessageResponse(w http.ResponseWriter, resp *types.MessageResponse) error { + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("streaming not supported by response writer") + } + if resp == nil { + return fmt.Errorf("nil message response") + } + msgStart := types.MessageEvent{ + Type: "message_start", + Message: resp, + } + if err := writeSSEEvent(w, msgStart); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + + for i, block := range resp.Content { + idx := i + startBlock := block + switch block.Type { + case "text": + startBlock.Text = "" + case "thinking": + startBlock.Thinking = "" + case "tool_use": + startBlock.Input = json.RawMessage(`{}`) + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &startBlock, + }); err != nil { + return ErrClientDisconnected + } + switch block.Type { + case "text": + if block.Text != "" { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "text_delta", Text: block.Text}, + }); err != nil { + return ErrClientDisconnected + } + } + case "thinking": + if block.Thinking != "" { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "thinking_delta", Thinking: block.Thinking}, + }); err != nil { + return ErrClientDisconnected + } + } + case "tool_use": + if len(block.Input) > 0 { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "input_json_delta", PartialJSON: string(block.Input)}, + }); err != nil { + return ErrClientDisconnected + } + } + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_stop", + Index: &idx, + }); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + } + + stopReason := resp.StopReason + if stopReason == "" { + stopReason = "end_turn" + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "message_delta", + Delta: &types.Delta{ + StopReason: stopReason, + }, + Usage: &types.Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + CacheCreationInputTokens: resp.Usage.CacheCreationInputTokens, + CacheReadInputTokens: resp.Usage.CacheReadInputTokens, + }, + }); err != nil { + return ErrClientDisconnected + } + if err := writeSSEEvent(w, types.MessageEvent{Type: "message_stop"}); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + return nil +} + // ProxyStream takes an OpenAI streaming response and writes Anthropic-format SSE to the writer. // It reads OpenAI ChatCompletionChunk SSE events and transforms them into Anthropic MessageEvent SSE events. // The streamCtx is the per-model attempt context (carries streaming_timeout_ms); the caller diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index f0a78bf..fd143b2 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -9,6 +9,7 @@ import ( "net/http" "strings" "testing" + "time" "github.com/routatic/proxy/pkg/types" ) @@ -63,6 +64,42 @@ func parseSSEEvents(t *testing.T, raw string) []types.MessageEvent { return events } +func TestEmitMessageResponse_SynthesizesAnthropicSSE(t *testing.T) { + handler := NewStreamHandler() + w := newMockResponseWriter() + resp := &types.MessageResponse{ + ID: "msg_test", + Type: "message", + Role: "assistant", + Model: "qwen3.6-plus", + StopReason: "end_turn", + Content: []types.ContentBlock{ + {Type: "text", Text: "Vedo uno screenshot."}, + }, + Usage: types.Usage{InputTokens: 10, OutputTokens: 4}, + } + + if err := handler.EmitMessageResponse(w, resp); err != nil { + t.Fatalf("EmitMessageResponse error: %v", err) + } + events := parseSSEEvents(t, w.buf.String()) + if len(events) != 6 { + t.Fatalf("events = %d, want 6: %+v", len(events), events) + } + if events[0].Type != "message_start" { + t.Fatalf("event[0] = %s, want message_start", events[0].Type) + } + if events[2].Type != "content_block_delta" || events[2].Delta.Type != "text_delta" { + t.Fatalf("event[2] = %+v, want text_delta", events[2]) + } + if got, want := events[2].Delta.Text, "Vedo uno screenshot."; got != want { + t.Fatalf("text delta = %q, want %q", got, want) + } + if events[4].Type != "message_delta" || events[5].Type != "message_stop" { + t.Fatalf("tail events = %+v %+v, want message_delta/message_stop", events[4], events[5]) + } +} + func TestProxyStream_ReasoningContentFastPath(t *testing.T) { handler := NewStreamHandler() w := newMockResponseWriter() @@ -209,6 +246,36 @@ func TestProxyStream_TextOnlyStillWorks(t *testing.T) { } } +func TestProxyStream_ContentArrayTextDelta(t *testing.T) { + handler := NewStreamHandler() + w := newMockResponseWriter() + body := sseLines( + `{"choices":[{"delta":{"content":[{"type":"text","text":"Vedo uno screenshot."}]}}]}`, + `{"choices":[{"delta":{},"finish_reason":"stop"}]}`, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx, 5*time.Second, cancel); err != nil { + t.Fatalf("ProxyStream error: %v", err) + } + + events := parseSSEEvents(t, w.buf.String()) + if len(events) != 6 { + t.Fatalf("expected 6 events, got %d: %+v", len(events), events) + } + if events[1].Type != "content_block_start" || events[1].ContentBlock == nil || events[1].ContentBlock.Type != "text" { + t.Errorf("event[1] = %+v, want content_block_start(text)", events[1]) + } + if events[2].Type != "content_block_delta" || events[2].Delta.Type != "text_delta" { + t.Errorf("event[2] = %+v, want content_block_delta(text_delta)", events[2]) + } + if got, want := events[2].Delta.Text, "Vedo uno screenshot."; got != want { + t.Errorf("event[2].Delta.Text = %q, want %q", got, want) + } +} + func TestProxyStream_UsageOnlyChunk(t *testing.T) { handler := NewStreamHandler() w := newMockResponseWriter() diff --git a/pkg/types/anthropic.go b/pkg/types/anthropic.go index fed8404..4d72e35 100644 --- a/pkg/types/anthropic.go +++ b/pkg/types/anthropic.go @@ -236,6 +236,7 @@ type ImageSource struct { Type string `json:"type"` MediaType string `json:"media_type"` Data string `json:"data"` + URL string `json:"url,omitempty"` } // Tool represents a tool definition for function calling.