diff --git a/CONFIGURATION.md b/CONFIGURATION.md index 11c6b36..dedad4a 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -243,3 +243,49 @@ When a request arrives, the proxy selects a model chain using the following orde 3. **Scenario routing** — fall back to the scenario chain (`default`, `background`, `think`, `complex`, `long_context`, `fast`). > **Trust model:** any client whose requests flow through the proxy can select from the configured `model_overrides` set without additional authentication. If you run the proxy as a shared service, treat `model_overrides` as a privileged allowlist. + +### Streaming Scenario Routing + +`enable_streaming_scenario_routing` controls whether streaming requests are evaluated by the full scenario router or routed directly to the `fast` scenario. + +> **Note for Claude Code `/review-code`, `/ultracode`, and multi-agent workflows** +> +> If you use Claude Code workflows that dispatch many subagents or produce many parallel tool calls, enable streaming scenario routing: +> +> ```json +> { +> "enable_streaming_scenario_routing": true +> } +> ``` +> +> Without this option, streaming requests are routed through the `fast` scenario even when the request is actually tool-heavy. This can route complex Claude Code workloads, such as `/review-code` with many `Agent` tool calls, to a fast model that may not handle parallel tool-call orchestration reliably. +> +> When enabled, streaming requests are evaluated by the same scenario router as non-streaming requests, allowing large or tool-heavy workloads to use `complex` or `long_context` models instead of always using the `fast` model. 
+ +Recommended setup for Claude Code review workflows: + +```json +{ + "enable_streaming_scenario_routing": true, + "models": { + "fast": { + "provider": "opencode-go", + "model_id": "deepseek-v4-flash", + "max_tokens": 4096 + }, + "complex": { + "provider": "opencode-go", + "model_id": "minimax-m3", + "max_tokens": 8192 + }, + "long_context": { + "provider": "opencode-go", + "model_id": "minimax-m3", + "max_tokens": 16384, + "context_threshold": 80000 + } + } +} +``` + +Use the `fast` scenario for short/simple requests. Use `complex` or `long_context` for code review, multi-agent dispatch, large diffs, many tools, or long-context Claude Code sessions. diff --git a/README.md b/README.md index ac97298..d19786f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ OpenCode Go gives you access to powerful open coding models for **$5/month** (th - **Transparent Proxy** — Claude Code sends Anthropic-format requests, proxy transforms to OpenAI/Responses/Gemini format and back - **Dual Provider Support** — Route models through OpenCode Go or OpenCode Zen based on your needs - **Model Routing** — Automatically routes to different models based on context (default, thinking, long context, background) +- **Streaming Scenario Routing** — Configurable routing for streaming requests; enables proper scenario selection for Claude Code multi-agent and review workflows (see [CONFIGURATION.md](CONFIGURATION.md#streaming-scenario-routing)) - **Fallback Chains** — If a model fails, automatically tries the next one in your configured chain - **Circuit Breaker** — Tracks model health and skips failing models to avoid latency spikes - **Real-time Streaming** — Full SSE streaming with live format transformation diff --git a/configs/config.example.json b/configs/config.example.json index ea3e02f..634bbe9 100644 --- a/configs/config.example.json +++ b/configs/config.example.json @@ -185,7 +185,8 @@ "opencode_go": { "base_url": "https://opencode.ai/zen/go/v1/chat/completions", 
"anthropic_base_url": "https://opencode.ai/zen/go/v1/messages", - "timeout_ms": 300000 + "timeout_ms": 300000, + "streaming_timeout_ms": 600000 }, "opencode_zen": { @@ -193,7 +194,8 @@ "anthropic_base_url": "https://opencode.ai/zen/v1/messages", "responses_base_url": "https://opencode.ai/zen/v1/responses", "gemini_base_url": "https://opencode.ai/zen/v1/models", - "timeout_ms": 300000 + "timeout_ms": 300000, + "streaming_timeout_ms": 600000 }, "logging": { diff --git a/internal/client/opencode.go b/internal/client/opencode.go index 756f1d0..d399efe 100644 --- a/internal/client/opencode.go +++ b/internal/client/opencode.go @@ -40,14 +40,8 @@ func (c *OpenCodeClient) nextAPIKey(keys []string) string { return keys[(old-1)%n] } -// NewOpenCodeClient creates a new OpenCode client. +// NewOpenCodeClient creates a client that relies on request contexts for timeouts. func NewOpenCodeClient(atomic *config.AtomicConfig) *OpenCodeClient { - cfg := atomic.Get() - timeout := time.Duration(cfg.OpenCodeGo.TimeoutMs) * time.Millisecond - if timeout == 0 { - timeout = 5 * time.Minute - } - transport := &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 20, @@ -60,12 +54,48 @@ func NewOpenCodeClient(atomic *config.AtomicConfig) *OpenCodeClient { return &OpenCodeClient{ atomic: atomic, httpClient: &http.Client{ - Timeout: timeout, + Timeout: 0, Transport: transport, }, } } +// RequestTimeout returns the provider timeout for a non-streaming attempt. +func (c *OpenCodeClient) RequestTimeout(model config.ModelConfig) time.Duration { + cfg := c.atomic.Get() + var timeoutMs int + if IsZen(model) { + timeoutMs = cfg.OpenCodeZen.TimeoutMs + } else { + timeoutMs = cfg.OpenCodeGo.TimeoutMs + } + if timeoutMs > 0 { + return time.Duration(timeoutMs) * time.Millisecond + } + return 5 * time.Minute +} + +// StreamingTimeout returns the provider timeout for a streaming attempt. 
+func (c *OpenCodeClient) StreamingTimeout(model config.ModelConfig) time.Duration { + cfg := c.atomic.Get() + var timeoutMs int + if IsZen(model) { + timeoutMs = cfg.OpenCodeZen.StreamingTimeoutMs + if timeoutMs <= 0 { + timeoutMs = cfg.OpenCodeZen.TimeoutMs + } + } else { + timeoutMs = cfg.OpenCodeGo.StreamingTimeoutMs + if timeoutMs <= 0 { + timeoutMs = cfg.OpenCodeGo.TimeoutMs + } + } + if timeoutMs > 0 { + return time.Duration(timeoutMs) * time.Millisecond + } + return 5 * time.Minute +} + // IsAnthropicModel returns true if the model requires the Anthropic endpoint. // This includes both Go models (minimax, all qwen) and Zen models (claude, qwen3.7-max). func IsAnthropicModel(modelID string) bool { diff --git a/internal/client/opencode_test.go b/internal/client/opencode_test.go index 003072f..bba1ffe 100644 --- a/internal/client/opencode_test.go +++ b/internal/client/opencode_test.go @@ -2,6 +2,7 @@ package client import ( "testing" + "time" "oc-go-cc/internal/config" ) @@ -461,3 +462,136 @@ func TestNextAPIKey_ConcurrentSafety(t *testing.T) { } } } + +func TestRequestTimeout_UsesConfiguredTimeout(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 120000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.RequestTimeout(model) + if timeout != 120*time.Second { + t.Errorf("RequestTimeout = %v, want 120s", timeout) + } +} + +func TestRequestTimeout_FallsBackToDefault(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.RequestTimeout(model) + if timeout != 5*time.Minute { + t.Errorf("RequestTimeout = %v, want 5m", timeout) + } +} + +func 
TestRequestTimeout_ZenProvider(t *testing.T) { + cfg := &config.Config{ + OpenCodeZen: config.OpenCodeZenConfig{ + TimeoutMs: 60000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeZen, ModelID: "claude-sonnet-4.5"} + timeout := c.RequestTimeout(model) + if timeout != 60*time.Second { + t.Errorf("RequestTimeout = %v, want 60s", timeout) + } +} + +func TestStreamingTimeout_UsesStreamingTimeoutMs(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 600000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 600*time.Second { + t.Errorf("StreamingTimeout = %v, want 600s", timeout) + } +} + +func TestStreamingTimeout_FallsBackToTimeoutMs(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 300*time.Second { + t.Errorf("StreamingTimeout = %v, want 300s (fallback to timeout_ms)", timeout) + } +} + +func TestStreamingTimeout_FallsBackToDefault(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 0, + StreamingTimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 5*time.Minute { + t.Errorf("StreamingTimeout = %v, want 5m", timeout) + } +} + +func TestStreamingTimeout_ZenProvider(t *testing.T) { + cfg := 
&config.Config{ + OpenCodeZen: config.OpenCodeZenConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 600000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeZen, ModelID: "claude-sonnet-4.5"} + timeout := c.StreamingTimeout(model) + if timeout != 600*time.Second { + t.Errorf("StreamingTimeout = %v, want 600s", timeout) + } +} + +func TestStreamingTimeout_SmallConfiguredValue(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 100, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 100*time.Millisecond { + t.Errorf("StreamingTimeout = %v, want 100ms", timeout) + } +} diff --git a/internal/config/atomic.go b/internal/config/atomic.go index 2123fb1..4968f88 100644 --- a/internal/config/atomic.go +++ b/internal/config/atomic.go @@ -6,8 +6,7 @@ import ( "sync/atomic" ) -// AtomicConfig provides thread-safe access to the configuration with support -// for hot reloading. It uses atomic.Pointer for lock-free reads. +// AtomicConfig provides thread-safe config access with hot reload support. type AtomicConfig struct { ptr atomic.Pointer[Config] path string @@ -22,15 +21,12 @@ func NewAtomicConfig(cfg *Config, path string) *AtomicConfig { return a } -// Get returns the current configuration pointer. This is safe for concurrent use. -// Callers must not modify the returned Config. +// Get returns the current config pointer. Callers must treat it as read-only. func (a *AtomicConfig) Get() *Config { return a.ptr.Load() } -// Reload reloads the configuration from disk and atomically swaps it in. -// If the reload fails, the old configuration is preserved and an error is returned. 
-// On successful reload, all registered callbacks are invoked. +// Reload loads the config from disk and swaps it in atomically. func (a *AtomicConfig) Reload() error { old := a.Get() cfg, err := LoadFromPath(a.path) @@ -38,27 +34,33 @@ func (a *AtomicConfig) Reload() error { return err } - // Warn about changes that require a server restart before swapping. + // Warn about settings that take effect differently on reload. if old != nil { if old.Host != cfg.Host || old.Port != cfg.Port { slog.Warn("host/port changed but requires server restart to take effect", "old_host", old.Host, "new_host", cfg.Host, "old_port", old.Port, "new_port", cfg.Port) } - if old.OpenCodeGo.TimeoutMs != cfg.OpenCodeGo.TimeoutMs { - slog.Warn("timeout_ms changed but requires server restart to take effect", - "old_timeout", old.OpenCodeGo.TimeoutMs, - "new_timeout", cfg.OpenCodeGo.TimeoutMs) + // Timeout changes apply on the next request. + if old.OpenCodeGo.TimeoutMs != cfg.OpenCodeGo.TimeoutMs || + old.OpenCodeGo.StreamingTimeoutMs != cfg.OpenCodeGo.StreamingTimeoutMs || + old.OpenCodeZen.TimeoutMs != cfg.OpenCodeZen.TimeoutMs || + old.OpenCodeZen.StreamingTimeoutMs != cfg.OpenCodeZen.StreamingTimeoutMs { + slog.Info("timeout config updated, takes effect immediately", + "go_timeout_ms", cfg.OpenCodeGo.TimeoutMs, + "go_streaming_timeout_ms", cfg.OpenCodeGo.StreamingTimeoutMs, + "zen_timeout_ms", cfg.OpenCodeZen.TimeoutMs, + "zen_streaming_timeout_ms", cfg.OpenCodeZen.StreamingTimeoutMs) } } - // Copy callbacks to avoid holding lock during invocation + // Copy callbacks before invoking them. a.mu.Lock() callbacks := make([]func(*Config), len(a.onReload)) copy(callbacks, a.onReload) a.mu.Unlock() - // Invoke callbacks BEFORE swapping — they may mutate cfg (e.g., port override). + // Callbacks run before the swap so they can adjust cfg. 
for _, fn := range callbacks { func() { defer func() { @@ -70,7 +72,6 @@ func (a *AtomicConfig) Reload() error { }() } - // Now cfg is fully prepared — safe for concurrent readers. a.ptr.Store(cfg) return nil diff --git a/internal/config/config.go b/internal/config/config.go index c5ae557..01023d4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -34,18 +34,20 @@ type ModelConfig struct { // OpenCodeGoConfig holds the upstream OpenCode Go API settings. type OpenCodeGoConfig struct { - BaseURL string `json:"base_url"` - AnthropicBaseURL string `json:"anthropic_base_url"` - TimeoutMs int `json:"timeout_ms"` + BaseURL string `json:"base_url"` + AnthropicBaseURL string `json:"anthropic_base_url"` + TimeoutMs int `json:"timeout_ms"` + StreamingTimeoutMs int `json:"streaming_timeout_ms,omitempty"` } // OpenCodeZenConfig holds the upstream OpenCode Zen API settings. type OpenCodeZenConfig struct { - BaseURL string `json:"base_url"` - AnthropicBaseURL string `json:"anthropic_base_url"` - ResponsesBaseURL string `json:"responses_base_url"` - GeminiBaseURL string `json:"gemini_base_url"` - TimeoutMs int `json:"timeout_ms"` + BaseURL string `json:"base_url"` + AnthropicBaseURL string `json:"anthropic_base_url"` + ResponsesBaseURL string `json:"responses_base_url"` + GeminiBaseURL string `json:"gemini_base_url"` + TimeoutMs int `json:"timeout_ms"` + StreamingTimeoutMs int `json:"streaming_timeout_ms,omitempty"` } // LoggingConfig controls application logging behavior. diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 207227a..7d87bcd 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -4,11 +4,12 @@ package handlers import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" - "strings" + "sync" "sync/atomic" "time" @@ -38,13 +39,17 @@ type MessagesHandler struct { metrics *metrics.Metrics } -// responseWriter wraps http.ResponseWriter to track if headers were written. 
+// responseWriter wraps http.ResponseWriter to track if headers were written +// and serialize concurrent writes (heartbeat + stream body copy). type responseWriter struct { http.ResponseWriter + mu sync.Mutex wroteHeader bool } func (w *responseWriter) WriteHeader(code int) { + w.mu.Lock() + defer w.mu.Unlock() if !w.wroteHeader { w.wroteHeader = true w.ResponseWriter.WriteHeader(code) @@ -52,12 +57,22 @@ func (w *responseWriter) WriteHeader(code int) { } func (w *responseWriter) Write(b []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() if !w.wroteHeader { - w.WriteHeader(http.StatusOK) + w.wroteHeader = true + w.ResponseWriter.WriteHeader(http.StatusOK) } return w.ResponseWriter.Write(b) } +// HasWrittenHeader returns true if the response header has been written. +func (w *responseWriter) HasWrittenHeader() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.wroteHeader +} + // Flush implements http.Flusher for SSE streaming support. func (w *responseWriter) Flush() { if f, ok := w.ResponseWriter.(http.Flusher); ok { @@ -65,6 +80,23 @@ func (w *responseWriter) Flush() { } } +// flushWriter wraps an http.ResponseWriter and calls Flush after every Write, +// ensuring that SSE data from raw passthrough streams is not buffered in the +// net/http bufio.Writer where it would appear hung until the buffer fills. +type flushWriter struct { + http.ResponseWriter +} + +func (w *flushWriter) Write(b []byte) (int, error) { + n, err := w.ResponseWriter.Write(b) + if err == nil { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + } + return n, err +} + // NewMessagesHandler creates a new messages handler. 
func NewMessagesHandler( openCodeClient *client.OpenCodeClient, @@ -96,14 +128,12 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) return } - // Generate or get request ID for correlation requestID := r.Header.Get("X-Request-ID") if requestID == "" { requestID = h.requestIDGen.Generate() } w.Header().Set("X-Request-ID", requestID) - // Rate limiting clientIP := middleware.GetClientIP(r) if !h.rateLimiter.Allow(clientIP) { h.metrics.RecordRateLimited() @@ -112,34 +142,29 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) return } - // Read the raw request body for debug logging var rawBody json.RawMessage if err := json.NewDecoder(r.Body).Decode(&rawBody); err != nil { h.sendError(w, http.StatusBadRequest, "invalid request body", err) return } - // Deduplicate - skip duplicate requests if _, ok := h.requestDedup.TryAcquire(rawBody); !ok { h.metrics.RecordDeduplicated() h.logger.Info("duplicate request skipped", "request_id", requestID) return } - // Parse into Anthropic request var anthropicReq types.MessageRequest if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { h.sendError(w, http.StatusBadRequest, "invalid request body", err) return } - // Validate request if err := anthropicReq.Validate(); err != nil { h.sendError(w, http.StatusBadRequest, err.Error(), nil) return } - // Record metrics isStreaming := anthropicReq.Stream != nil && *anthropicReq.Stream h.metrics.RecordRequest(isStreaming) @@ -151,7 +176,6 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) "max_tokens", anthropicReq.MaxTokens, ) - // Build message content for routing and token counting. var routerMessages []router.MessageContent var tokenMessages []token.MessageContent systemText := anthropicReq.SystemText() @@ -170,14 +194,12 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) }) } - // Count tokens. 
tokenCount, err := h.tokenCounter.CountMessages(systemText, tokenMessages) if err != nil { h.logger.Warn("failed to count tokens", "error", err) tokenCount = 0 } - // Route to appropriate model and build fallback chain. modelChain, routeResult, err := h.buildModelChain(anthropicReq.Model, routerMessages, tokenCount, isStreaming) if err != nil { h.sendError(w, http.StatusInternalServerError, "routing failed", err) @@ -192,22 +214,13 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) ) if isStreaming { - // Streaming: use ProxyStream for real-time SSE transformation h.handleStreaming(w, r, &anthropicReq, modelChain, rawBody) } else { - // Non-streaming: execute with fallback and return full response h.handleNonStreaming(w, r, &anthropicReq, modelChain, rawBody) } } -// buildModelChain resolves the request to a model chain (primary + fallbacks), -// honoring model_overrides (with a deduplicated scenario safety-net) and -// respecting the streaming-scenario-routing toggle. -// -// Precedence: -// 1. If requestedModel matches an entry in model_overrides, use that as the -// primary and append the scenario chain as a deduplicated safety net. -// 2. Otherwise, fall through to scenario-based routing via routeOnce. +// buildModelChain resolves the primary model and fallback chain. func (h *MessagesHandler) buildModelChain( requestedModel string, routerMessages []router.MessageContent, @@ -218,8 +231,6 @@ func (h *MessagesHandler) buildModelChain( if overrideResult, ok := h.modelRouter.RouteWithOverride(requestedModel); ok { scenarioResult, err := h.routeOnce(routerMessages, tokenCount, "", isStreaming) if err != nil { - // Override is valid; surface the scenario routing error rather - // than silently dropping the safety net. 
return overrideResult.GetModelChain(), overrideResult, err } chain := appendUniqueModels(overrideResult.GetModelChain(), scenarioResult.GetModelChain()) @@ -234,10 +245,7 @@ func (h *MessagesHandler) buildModelChain( return result.GetModelChain(), result, nil } -// routeOnce performs scenario-based routing, honoring the streaming-scenario-routing -// toggle. Pass requestedModel="" to force scenario routing (used for the override -// safety-net chain), or a non-empty value to let resolveRequestedModel kick in -// (only when respect_requested_model is enabled and no override matched). +// routeOnce runs the router once, respecting the streaming toggle. func (h *MessagesHandler) routeOnce( routerMessages []router.MessageContent, tokenCount int, @@ -245,15 +253,12 @@ func (h *MessagesHandler) routeOnce( isStreaming bool, ) (router.RouteResult, error) { if isStreaming && !h.modelRouter.IsStreamingScenarioRoutingEnabled() { - // Streaming: use faster models to minimize TTFT (time-to-first-token) return h.modelRouter.RouteForStreaming(routerMessages, tokenCount, requestedModel), nil } return h.modelRouter.Route(routerMessages, tokenCount, requestedModel) } -// appendUniqueModels appends models from extra to base, skipping any model_id -// already present in base. The first occurrence of a ModelID is kept; later -// duplicates are dropped. Order of the base chain is preserved. +// appendUniqueModels appends models from extra if their model_id is new. func appendUniqueModels(base, extra []config.ModelConfig) []config.ModelConfig { if len(extra) == 0 { return base @@ -272,7 +277,7 @@ func appendUniqueModels(base, extra []config.ModelConfig) []config.ModelConfig { return base } -// handleStreaming handles a streaming request with real-time SSE proxying. +// handleStreaming proxies a streaming request with per-model timeouts. 
func (h *MessagesHandler) handleStreaming( w http.ResponseWriter, r *http.Request, @@ -284,7 +289,6 @@ func (h *MessagesHandler) handleStreaming( rw := &responseWriter{ResponseWriter: w} - // Set SSE headers immediately w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -294,9 +298,9 @@ func (h *MessagesHandler) handleStreaming( f.Flush() } - // Start heartbeat var finished int32 heartbeatDone := make(chan struct{}) + heartbeatPaused := new(atomic.Int32) go func() { ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() @@ -307,6 +311,9 @@ func (h *MessagesHandler) handleStreaming( if atomic.LoadInt32(&finished) == 1 { return } + if heartbeatPaused.Load() == 1 { + continue + } _, _ = fmt.Fprintf(rw, ":keepalive\n\n") if f, ok := w.(http.Flusher); ok { f.Flush() @@ -326,27 +333,26 @@ func (h *MessagesHandler) handleStreaming( streamStart := time.Now() for _, model := range modelChain { - select { - case <-clientCtx.Done(): - h.logger.Info("client disconnected, stopping streaming fallbacks") + if err := clientCtx.Err(); err != nil { + h.logger.Info("client disconnected, stopping streaming fallbacks", "error", err) return - default: } h.logger.Info("attempting streaming model", "model", model.ModelID, "provider", model.Provider) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + timeout := h.client.StreamingTimeout(model) + attemptCtx, cancel := context.WithTimeout(clientCtx, timeout) - // Zen models use their own endpoint classification if client.IsZen(model) { endpointType := client.ClassifyEndpoint(model.ModelID) switch endpointType { case client.EndpointAnthropic: modelBody := replaceModelInRawBody(rawBody, model.ModelID) - if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model); err != nil { + err := h.handleAnthropicStreaming(attemptCtx, rw, modelBody, model.ModelID, model, heartbeatPaused) + if err != nil { cancel() - 
if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during anthropic stream") + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during anthropic stream", "error", clientCtx.Err()) return } h.logger.Warn("anthropic streaming failed", "model", model.ModelID, "error", err) @@ -359,10 +365,10 @@ func (h *MessagesHandler) handleStreaming( return case client.EndpointResponses: - if err := h.handleResponsesStreaming(ctx, rw, anthropicReq, model, clientCtx); err != nil { + if err := h.handleResponsesStreaming(attemptCtx, rw, anthropicReq, model, clientCtx); err != nil { cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during responses stream") + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during responses stream", "error", clientCtx.Err()) return } h.logger.Warn("responses streaming failed", "model", model.ModelID, "error", err) @@ -375,10 +381,10 @@ func (h *MessagesHandler) handleStreaming( return case client.EndpointGemini: - if err := h.handleGeminiStreaming(ctx, rw, anthropicReq, model, clientCtx); err != nil { + if err := h.handleGeminiStreaming(attemptCtx, rw, anthropicReq, model, clientCtx); err != nil { cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during gemini stream") + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during gemini stream", "error", clientCtx.Err()) return } h.logger.Warn("gemini streaming failed", "model", model.ModelID, "error", err) @@ -391,11 +397,26 @@ func (h *MessagesHandler) handleStreaming( return default: - // Fall through to OpenAI-compatible handling } + } else if client.IsAnthropicModel(model.ModelID) { + modelBody := replaceModelInRawBody(rawBody, model.ModelID) + err := h.handleAnthropicStreaming(attemptCtx, rw, modelBody, model.ModelID, model, heartbeatPaused) + if err != nil { + cancel() + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during 
anthropic stream", "error", clientCtx.Err()) + return + } + h.logger.Warn("anthropic streaming failed", "model", model.ModelID, "error", err) + continue + } + cancel() + latency := time.Since(streamStart) + h.metrics.RecordSuccess(model.ModelID, latency) + h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) + return } - // OpenAI-compatible models (both Go and Zen) openaiReq, err := h.requestTransformer.TransformRequest(anthropicReq, model) if err != nil { cancel() @@ -403,26 +424,29 @@ func (h *MessagesHandler) handleStreaming( continue } - streamBody, err := h.client.GetStreamingBody(ctx, model.ModelID, openaiReq, model) + streamBody, err := h.client.GetStreamingBody(attemptCtx, model.ModelID, openaiReq, model) if err != nil { cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during upstream request") + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during upstream request", "error", clientCtx.Err()) return } h.logger.Warn("streaming request failed", "model", model.ModelID, "error", err) continue } - if err := h.streamHandler.ProxyStream(rw, streamBody, model.ModelID, clientCtx); err != nil { + // Bind body read to attemptCtx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(attemptCtx, streamBody) + + if err := h.streamHandler.ProxyStream(rw, streamReader, model.ModelID, attemptCtx); err != nil { _ = streamBody.Close() cancel() if err == transformer.ErrClientDisconnected { h.logger.Info("client disconnected during stream") return } - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during stream (context canceled)") + if clientCtx.Err() != nil { + h.logger.Info("client disconnected during stream (context canceled)", "error", clientCtx.Err()) return } h.logger.Warn("stream proxy failed", "model", model.ModelID, "error", err) @@ -438,7 +462,7 @@ func (h *MessagesHandler) handleStreaming( } h.metrics.RecordFailure() - if !rw.wroteHeader { + if !rw.HasWrittenHeader() { h.sendError(w, http.StatusBadGateway, "all streaming models failed", nil) } else { h.sendStreamError(rw, "all upstream models failed") @@ -446,6 +470,8 @@ func (h *MessagesHandler) handleStreaming( } // handleResponsesStreaming handles streaming for OpenAI Responses endpoint. +// ctx is the per-attempt context (carries streaming_timeout_ms); clientCtx is the +// broader request context used only for client-disconnect signaling. func (h *MessagesHandler) handleResponsesStreaming( ctx context.Context, w http.ResponseWriter, @@ -463,7 +489,10 @@ func (h *MessagesHandler) handleResponsesStreaming( return err } - if err := h.streamHandler.ProxyResponsesStream(w, streamBody, model.ModelID, clientCtx); err != nil { + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. + streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + + if err := h.streamHandler.ProxyResponsesStream(w, streamReader, model.ModelID, ctx); err != nil { _ = streamBody.Close() return err } @@ -473,6 +502,8 @@ func (h *MessagesHandler) handleResponsesStreaming( } // handleGeminiStreaming handles streaming for Gemini endpoint. 
+// ctx is the per-attempt context (carries streaming_timeout_ms); clientCtx is the +// broader request context used only for client-disconnect signaling. func (h *MessagesHandler) handleGeminiStreaming( ctx context.Context, w http.ResponseWriter, @@ -490,7 +521,10 @@ func (h *MessagesHandler) handleGeminiStreaming( return err } - if err := h.streamHandler.ProxyGeminiStream(w, streamBody, model.ModelID, clientCtx); err != nil { + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. + streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + + if err := h.streamHandler.ProxyGeminiStream(w, streamReader, model.ModelID, ctx); err != nil { _ = streamBody.Close() return err } @@ -499,36 +533,66 @@ func (h *MessagesHandler) handleGeminiStreaming( return nil } -// replaceModelInRawBody replaces the model field in raw JSON body with the actual model ID. +// replaceModelInRawBody replaces the top-level model field in a raw JSON body. func replaceModelInRawBody(rawBody json.RawMessage, modelID string) json.RawMessage { - bodyStr := string(rawBody) - - if idx := strings.Index(bodyStr, `"model":"`); idx != -1 { - start := idx + len(`"model":"`) - if end := strings.Index(bodyStr[start:], `"`); end != -1 { - oldModel := bodyStr[start : start+end] - newBody := bodyStr[:start] + modelID + bodyStr[start+end:] - slog.Debug("replaced model in request body", - "old_model", oldModel, - "new_model", modelID, - "success", true) - return json.RawMessage(newBody) - } + var obj map[string]json.RawMessage + if err := json.Unmarshal(rawBody, &obj); err != nil { + slog.Warn("replaceModelInRawBody: unmarshal failed, using original body", + "error", err, + "body_preview", string(rawBody)[:min(len(rawBody), 200)]) + return rawBody + } + + oldModelRaw, ok := obj["model"] + if !ok { + slog.Warn("replaceModelInRawBody: model key not found, using original body", + "body_preview", string(rawBody)[:min(len(rawBody), 200)]) + return rawBody + } + + var oldModel string + if err := 
json.Unmarshal(oldModelRaw, &oldModel); err != nil { + slog.Warn("replaceModelInRawBody: model value not a string, using original body", + "error", err) + return rawBody + } + + modelJSON, err := json.Marshal(modelID) + if err != nil { + slog.Warn("replaceModelInRawBody: marshal modelID failed, using original body", + "error", err) + return rawBody } + obj["model"] = modelJSON - slog.Warn("could not find model field in request body, using original", - "body_preview", bodyStr[:min(len(bodyStr), 200)]) - return rawBody + newBody, err := json.Marshal(obj) + if err != nil { + slog.Warn("replaceModelInRawBody: remarshal failed, using original body", + "error", err) + return rawBody + } + + slog.Debug("replaced model in request body", + "old_model", oldModel, + "new_model", modelID, + "success", true) + return json.RawMessage(newBody) } -// handleAnthropicStreaming sends a raw Anthropic request to the Anthropic endpoint. +// handleAnthropicStreaming sends a raw Anthropic request to the Anthropic endpoint, +// pausing the heartbeat for the duration of the raw SSE copy so synthetic keepalive +// events are not interleaved with upstream events. func (h *MessagesHandler) handleAnthropicStreaming( ctx context.Context, w http.ResponseWriter, rawBody json.RawMessage, modelID string, model config.ModelConfig, + heartbeatPaused *atomic.Int32, ) error { + heartbeatPaused.Store(1) + defer heartbeatPaused.Store(0) + h.logger.Debug("sending anthropic streaming request", "model_id", modelID, "body_preview", string(rawBody)[:min(len(rawBody), 200)]) @@ -539,11 +603,18 @@ func (h *MessagesHandler) handleAnthropicStreaming( } defer func() { _ = resp.Body.Close() }() - _, err = io.Copy(w, resp.Body) + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. + // io.Copy does not honor ctx, so wrap the upstream body explicitly. 
+ bodyReader := transformer.NewCtxReader(ctx, resp.Body) + fw := &flushWriter{ResponseWriter: w} + _, err = io.Copy(fw, bodyReader) if err != nil { if ctx.Err() == context.Canceled { return transformer.ErrClientDisconnected } + if errors.Is(err, transformer.ErrStreamReadCanceled) { + return transformer.ErrClientDisconnected + } return fmt.Errorf("failed to copy response: %w", err) } @@ -570,7 +641,7 @@ func (h *MessagesHandler) sendStreamError(w http.ResponseWriter, message string) } } -// handleNonStreaming handles a non-streaming request with fallback. +// handleNonStreaming runs a non-streaming request with per-model timeouts. func (h *MessagesHandler) handleNonStreaming( w http.ResponseWriter, r *http.Request, @@ -578,37 +649,43 @@ func (h *MessagesHandler) handleNonStreaming( modelChain []config.ModelConfig, rawBody json.RawMessage, ) { - ctx := r.Context() + parentCtx := r.Context() startTime := time.Now() result, responseBody, err := h.fallbackHandler.ExecuteWithFallback( - ctx, + parentCtx, modelChain, func(ctx context.Context, model config.ModelConfig) ([]byte, error) { - // Zen models use their own endpoint classification + timeout := h.client.RequestTimeout(model) + attemptCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + if client.IsZen(model) { endpointType := client.ClassifyEndpoint(model.ModelID) switch endpointType { case client.EndpointAnthropic: - return h.executeAnthropicRequest(ctx, rawBody, model) + modelBody := replaceModelInRawBody(rawBody, model.ModelID) + return h.executeAnthropicRequest(attemptCtx, modelBody, model) case client.EndpointResponses: - return h.executeResponsesRequest(ctx, anthropicReq, model) + return h.executeResponsesRequest(attemptCtx, anthropicReq, model) case client.EndpointGemini: - return h.executeGeminiRequest(ctx, anthropicReq, model) + return h.executeGeminiRequest(attemptCtx, anthropicReq, model) default: - // Fall through to OpenAI-compatible handling } } else if 
client.IsAnthropicModel(model.ModelID) { - // Go provider Anthropic-native models (MiniMax, Qwen) - return h.executeAnthropicRequest(ctx, rawBody, model) + modelBody := replaceModelInRawBody(rawBody, model.ModelID) + return h.executeAnthropicRequest(attemptCtx, modelBody, model) } - // OpenAI-compatible models (both Go and Zen) - return h.executeOpenAIRequest(ctx, anthropicReq, model) + return h.executeOpenAIRequest(attemptCtx, anthropicReq, model) }, ) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + h.logger.Info("request context canceled during non-streaming fallback", "error", err) + return + } h.metrics.RecordFailure() h.sendError(w, http.StatusBadGateway, "all models failed", err) return @@ -750,7 +827,7 @@ func (h *MessagesHandler) sendError(w http.ResponseWriter, statusCode int, messa "error", err, ) - if rw, ok := w.(*responseWriter); ok && rw.wroteHeader { + if rw, ok := w.(*responseWriter); ok && rw.HasWrittenHeader() { return } diff --git a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index f3dcac1..f48d9c6 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -1,11 +1,26 @@ package handlers import ( + "context" + "encoding/json" + "fmt" + "io" "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" "testing" + "time" + "oc-go-cc/internal/client" "oc-go-cc/internal/config" + "oc-go-cc/internal/metrics" "oc-go-cc/internal/router" + "oc-go-cc/internal/token" + "oc-go-cc/internal/transformer" + "oc-go-cc/pkg/types" ) func TestAppendUniqueModels_DedupsByModelID(t *testing.T) { @@ -337,3 +352,1238 @@ func equalStrings(a, b []string) bool { } return true } + +// --------------------------------------------------------------------------- +// Phase 2 regression tests: replaceModelInRawBody (JSON-based replacement) +// --------------------------------------------------------------------------- + +func 
TestReplaceModelInRawBody_JSONBased(t *testing.T) { + raw := json.RawMessage(`{"model":"claude-opus-4-8","stream":true,"messages":[]}`) + got := string(replaceModelInRawBody(raw, "minimax-m3")) + + if !strings.Contains(got, `"minimax-m3"`) { + t.Fatalf("expected model replaced to minimax-m3, got: %s", got) + } + if !strings.Contains(got, `"stream":true`) { + t.Fatalf("expected other fields preserved, got: %s", got) + } + if strings.Contains(got, `"claude-opus-4-8"`) { + t.Fatalf("old model ID should be gone, got: %s", got) + } +} + +func TestReplaceModelInRawBody_HandlesWhitespace(t *testing.T) { + raw := json.RawMessage(`{ "model" : "claude-opus-4-8" , "stream" : true }`) + got := string(replaceModelInRawBody(raw, "minimax-m3")) + + if !strings.Contains(got, `"minimax-m3"`) { + t.Fatalf("expected model replaced despite whitespace, got: %s", got) + } +} + +func TestReplaceModelInRawBody_ReturnsOriginalWhenModelMissing(t *testing.T) { + raw := json.RawMessage(`{"stream":true,"messages":[]}`) + got := replaceModelInRawBody(raw, "minimax-m3") + + // Should return original raw bytes since there's no "model" key + var parsed map[string]interface{} + if err := json.Unmarshal(got, &parsed); err != nil { + t.Fatalf("result is invalid JSON: %v", err) + } + if _, ok := parsed["model"]; ok { + t.Fatalf("model key should not be present in result when absent from input") + } +} + +func TestReplaceModelInRawBody_ReturnsOriginalOnInvalidJSON(t *testing.T) { + raw := json.RawMessage(`{invalid}`) + got := replaceModelInRawBody(raw, "minimax-m3") + + if string(got) != `{invalid}` { + t.Fatalf("expected original body on invalid JSON, got: %s", got) + } +} + +func TestReplaceModelInRawBody_HandlesNestedObjects(t *testing.T) { + raw := json.RawMessage(`{ + "model": "claude-opus-4-8", + "messages": [{"role":"user","content":"hello"}], + "tools": [{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}], + "stream": true + }`) + got := 
string(replaceModelInRawBody(raw, "minimax-m3")) + + if !strings.Contains(got, `"minimax-m3"`) { + t.Fatalf("expected model replaced to minimax-m3 in complex body, got: %s", got) + } + if !strings.Contains(got, `"Bash"`) { + t.Fatalf("expected tool name Bash preserved, got: %s", got) + } + if !strings.Contains(got, `"input_schema"`) { + t.Fatalf("expected input_schema preserved, got: %s", got) + } +} + +// --------------------------------------------------------------------------- +// Phase 2 regression tests: handleStreaming Go Anthropic-native branch +// --------------------------------------------------------------------------- + +func TestHandleStreaming_GoAnthropicModel_SendsRawAnthropicBody(t *testing.T) { + // Spin up a fake upstream that records the request body + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}], + "tools": [{"name":"Bash","description":"Run a command","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + // Call handleStreaming with minimax-m3 (Go Anthropic-native) + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: 
"minimax-m3"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + // context is tied to the request lifetime; handleStreaming waits on it + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + + // Verify the upstream received raw Anthropic format (not OpenAI-transformed) + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + // Must have model = minimax-m3 + if got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + // Must have tools with input_schema (Anthropic format), NOT function (OpenAI format) + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } + if got, ok := tool0["name"]; !ok || got != "Bash" { + t.Fatalf("captured tool name = %v, want Bash", got) + } +} + +// TestHandleStreaming_GoAnthropicModel_FallsThroughOnError verifies that +// when the Go Anthropic-native model fails, the streaming handler falls +// through to the next model in the chain. 
+func TestHandleStreaming_GoAnthropicModel_FallsThroughOnError(t *testing.T) { + // Single upstream: fails on first request, succeeds on second. + // Both models in the chain use the same base URL. + callCount := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + // First call (minimax-m3) fails + w.WriteHeader(http.StatusInternalServerError) + return + } + // Second call (qwen3.5-plus) succeeds + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + // Chain: minimax-m3 fails (first call → 500), qwen3.5-plus succeeds (second call) + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "minimax-m3"}, + {Provider: "opencode-go", ModelID: "qwen3.5-plus"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := 
context.WithCancel(req.Context()) + defer cancel() + + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + + // Both models tried: minimax got 500, qwen3.5-plus got 200 + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 2 { + t.Fatalf("expected 2 upstream calls (1 fail + 1 success), got %d", finalCount) + } +} + +// newStreamingTestHandler creates a MessagesHandler for streaming tests, +// pointing both Go Anthropic and OpenAI endpoints at the given test server URL. +func newStreamingTestHandler(t *testing.T, upstreamURL string) *MessagesHandler { + t.Helper() + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstreamURL, + BaseURL: upstreamURL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + return &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } +} + +// --------------------------------------------------------------------------- +// End-to-end test: HandleMessages → routing → handleStreaming → upstream +// --------------------------------------------------------------------------- + +// TestHandleMessages_StreamingMinimaxM3_UsesAnthropicEndpoint verifies the +// full public API path: HandleMessages receives a streaming request for +// minimax-m3, routing selects it (via ModelOverrides), and the upstream +// receives the raw Anthropic body (NOT OpenAI-transformed). +func TestHandleMessages_StreamingMinimaxM3_UsesAnthropicEndpoint(t *testing.T) { + // 1. Set up fake upstream that records the request body. 
+ var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + // 2. Build config that forces routing to minimax-m3. + // ModelOverrides takes highest precedence in buildModelChain. + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + "fast": {Provider: "opencode-go", ModelID: "qwen3.6-plus"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + "fast": {{Provider: "opencode-go", ModelID: "qwen3.5-plus"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "minimax-m3": { + Provider: "opencode-go", + ModelID: "minimax-m3", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + + // 3. Build the full MessagesHandler with all real dependencies. + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + modelRouter, + nil, // fallbackHandler — not used in streaming path + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + // 4. Build the streaming request body requesting minimax-m3 with tools. 
+ requestBody := `{ + "model": "minimax-m3", + "stream": true, + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + // 5. Call HandleMessages — the full public entry point. + handler.HandleMessages(recorder, req) + + // 6. Verify upstream received raw Anthropic body. + if len(capturedBody) == 0 { + t.Fatal("upstream received no body — routing or streaming may have failed silently") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + // Model must be minimax-m3 + if got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + // Tools must be Anthropic format (input_schema), NOT OpenAI format (function) + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak — TransformRequest was called): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } + if got, ok := tool0["name"]; !ok || got != "Bash" { + t.Fatalf("captured tool name = %v, want Bash", got) + } + + t.Logf("end-to-end test PASSED: upstream received raw Anthropic 
body with model=minimax-m3 and input_schema") +} + +// --------------------------------------------------------------------------- +// Non-streaming regression tests: handleNonStreaming model replacement +// --------------------------------------------------------------------------- + +// TestHandleNonStreaming_GoAnthropicModel_ReplacesModelInBody verifies that +// the non-streaming path replaces the model in the request body for Go +// Anthropic-native models (minimax-m3) before forwarding to upstream. +// Without this fix, upstream would receive "claude-haiku-4-5-20251001" and +// reject it with "Model is not supported". +func TestHandleNonStreaming_GoAnthropicModel_ReplacesModelInBody(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + // Non-streaming: return a valid JSON response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "minimax-m3", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "claude-haiku-4-5-20251001": { + Provider: "opencode-go", + ModelID: "minimax-m3", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := 
client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30), + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + // Use a different client model to verify the model is replaced to + // minimax-m3 before sending upstream. + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + handler.HandleMessages(recorder, req) + + // Verify upstream received the request body with model replaced + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + // Must have model = minimax-m3 + if got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + // Must have tools with input_schema (Anthropic format), NOT function (OpenAI) + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field 
(OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } + if got, ok := tool0["name"]; !ok || got != "Bash" { + t.Fatalf("captured tool name = %v, want Bash", got) + } + + t.Logf("non-streaming Go Anthropic-native test PASSED: upstream received model=minimax-m3 with Anthropic tool format") +} + +// TestHandleNonStreaming_ZenAnthropicModel_ReplacesModelInBody verifies that +// the non-streaming path replaces the model in the request body for Zen +// Anthropic-native models (claude-* via opencode-zen) before forwarding upstream. +func TestHandleNonStreaming_ZenAnthropicModel_ReplacesModelInBody(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "claude-sonnet-4.5", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "claude-haiku-4-5-20251001": { + Provider: "opencode-zen", + ModelID: "claude-sonnet-4.5", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + OpenCodeZen: config.OpenCodeZenConfig{ + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + 
}, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30), + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + // Use a different client model to verify the model is replaced to + // claude-sonnet-4.5 before sending upstream. + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + handler.HandleMessages(recorder, req) + + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + // Must have model = claude-sonnet-4.5 (replaced from claude-haiku-4-5-20251001) + if got, ok := captured["model"]; !ok || got != "claude-sonnet-4.5" { + t.Fatalf("captured model = %v, want claude-sonnet-4.5", got) + } + + // Must have tools with input_schema (Anthropic format), NOT function (OpenAI) + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") 
+ } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } + + t.Logf("non-streaming Zen Anthropic test PASSED: upstream received model=claude-sonnet-4.5 with Anthropic tool format") +} + +func TestHandleStreaming_ConfigurableTimeout(t *testing.T) { + callCount := int32(0) + handlerCtx, handlerCancel := context.WithCancel(context.Background()) + defer handlerCancel() + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + <-handlerCtx.Done() + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + StreamingTimeoutMs: 100, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + start := time.Now() + handler.handleStreaming(recorder, 
req.WithContext(ctx), &anthropicReq, chain, rawBody) + elapsed := time.Since(start) + + handlerCancel() + + if elapsed > 10*time.Second { + t.Errorf("streaming attempt took %v, expected much less than 2 minutes", elapsed) + } + + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 1 { + t.Errorf("expected 1 upstream call (single model in chain), got %d", finalCount) + } +} + +func TestHandleStreaming_ClientContextCanceled_StopsFallback(t *testing.T) { + callCount := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + <-r.Context().Done() + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + + cancel() + + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + + time.Sleep(50 * time.Millisecond) + + 
finalCount := atomic.LoadInt32(&callCount) + if finalCount != 0 { + t.Errorf("expected 0 upstream calls (client canceled), got %d", finalCount) + } + + body := recorder.Body.String() + if strings.Contains(body, "all upstream models failed") { + t.Errorf("should not send 'all upstream models failed' event for client disconnect, got: %s", body) + } +} + +func TestHandleStreaming_ClientDisconnectsDuringStream_StopsFallback(t *testing.T) { + callCount := int32(0) + handlerCtx, handlerCancel := context.WithCancel(context.Background()) + defer handlerCancel() + firstModelReady := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + close(firstModelReady) + <-handlerCtx.Done() + return + } + t.Error("second model should not be attempted after client disconnect") + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: 
"opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + }() + + select { + case <-firstModelReady: + case <-time.After(5 * time.Second): + t.Fatal("first model did not start within 5s") + } + + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handleStreaming did not return after client disconnect") + } + + handlerCancel() + + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 1 { + t.Errorf("expected 1 upstream call, got %d", finalCount) + } + + body := recorder.Body.String() + if strings.Contains(body, "all upstream models failed") { + t.Errorf("should not send 'all upstream models failed' event for client disconnect, got: %s", body) + } +} + +func TestHandleStreaming_PerModelTimeoutFallback(t *testing.T) { + callCount := int32(0) + handlerCtx, handlerCancel := context.WithCancel(context.Background()) + defer handlerCancel() + firstModelReady := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + close(firstModelReady) + <-handlerCtx.Done() + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n") + _, _ = fmt.Fprintf(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n") + _, _ = fmt.Fprintf(w, "event: 
message_stop\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + StreamingTimeoutMs: 200, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + cancel() + }() + + select { + case <-firstModelReady: + case <-time.After(5 * time.Second): + t.Fatal("first model did not start within 5s") + } + + time.Sleep(500 * time.Millisecond) + + handlerCancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handleStreaming did not complete within 5s") + } + + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 2 { + t.Errorf("expected 2 upstream calls (1 timeout + 1 success), got %d", 
finalCount) + } +} + +func TestHandleNonStreaming_ParentContextCanceled_No502(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi-k2.6", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + m := metrics.New() + handler := NewMessagesHandler( + ocClient, + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + m, + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + + handler.HandleMessages(recorder, req) + + if recorder.Code == http.StatusBadGateway { + t.Errorf("should not return 502 for canceled context, got status %d", recorder.Code) + } + + 
snap := m.GetSnapshot() + if snap.RequestsFailed > 0 { + t.Errorf("failure count should be 0 for canceled context, got %d", snap.RequestsFailed) + } + + body := recorder.Body.String() + if strings.Contains(body, "all models failed") { + t.Errorf("should not contain 'all models failed' for client cancellation, got: %s", body) + } +} + +func TestHandleNonStreaming_ParentDeadlineExceeded_No502(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi-k2.6", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + m := metrics.New() + handler := NewMessagesHandler( + ocClient, + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + m, + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + 
req.Header.Set("Content-Type", "application/json") + + ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(-1*time.Second)) + defer cancel() + req = req.WithContext(ctx) + + handler.HandleMessages(recorder, req) + + if recorder.Code == http.StatusBadGateway { + t.Errorf("should not return 502 for deadline exceeded, got status %d", recorder.Code) + } + snap := m.GetSnapshot() + if snap.RequestsFailed > 0 { + t.Errorf("failure count should be 0 for deadline exceeded, got %d", snap.RequestsFailed) + } + + body := recorder.Body.String() + if strings.Contains(body, "all models failed") { + t.Errorf("should not contain 'all models failed' for deadline exceeded, got: %s", body) + } +} + +// TestResponseWriter_ConcurrentWrites verifies the mutex serializes writes, +// preventing data races when heartbeat and stream copy write concurrently. +func TestResponseWriter_ConcurrentWrites(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: recorder} + + var wg sync.WaitGroup + const goroutines = 10 + const writesPerGoroutine = 100 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerGoroutine; j++ { + rw.Write([]byte(fmt.Sprintf("goroutine-%d-write-%d\n", id, j))) + } + }(i) + } + wg.Wait() + + output := recorder.Body.String() + lines := strings.Split(strings.TrimSpace(output), "\n") + expectedLines := goroutines * writesPerGoroutine + if len(lines) != expectedLines { + t.Errorf("got %d lines, want %d (possible data loss from unsynchronized writes)", len(lines), expectedLines) + } +} + +// TestHandleStreaming_AnthropicRaw_NoKeepaliveInjection verifies that the +// heartbeat is disabled during Anthropic raw passthrough. The upstream sends +// SSE data slowly (blocking for > heartbeat interval) and the proxy must +// not inject keepalive comments into the raw stream. 
+func TestHandleStreaming_AnthropicRaw_NoKeepaliveInjection(t *testing.T) { + blockCh := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {\"type\":\"message_start\"}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + select { + case <-blockCh: + case <-time.After(10 * time.Second): + } + _, _ = fmt.Fprintf(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\"}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "minimax-m3"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, chain, rawBody) + }() + + time.Sleep(3500 * time.Millisecond) + close(blockCh) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handleStreaming did not return after unblocking upstream") + } + + body := recorder.Body.String() + + if !strings.Contains(body, "message_start") { + t.Error("output missing message_start event") + } + if !strings.Contains(body, "content_block_delta") { + t.Error("output missing content_block_delta event") + } + + if 
strings.Contains(body, ":keepalive") { + t.Errorf("keepalive comment leaked into Anthropic raw stream output (concurrent write bug):\n%s", body) + } +} diff --git a/internal/router/fallback.go b/internal/router/fallback.go index e741e41..58fa49e 100644 --- a/internal/router/fallback.go +++ b/internal/router/fallback.go @@ -165,8 +165,7 @@ func (h *FallbackHandler) getCircuitBreaker(modelID string) *CircuitBreaker { return cb } -// ExecuteWithFallback tries models in sequence until one succeeds. -// Respects circuit breaker state to skip models that are failing repeatedly. +// ExecuteWithFallback tries models in order until one succeeds, skipping models whose circuit breaker is open and returning early once ctx is done. func (h *FallbackHandler) ExecuteWithFallback( ctx context.Context, models []config.ModelConfig, @@ -175,9 +174,15 @@ func (h *FallbackHandler) ExecuteWithFallback( totalModels := len(models) for i, model := range models { + if err := ctx.Err(); err != nil { + h.logger.Info("request context canceled, stopping fallback attempts", + "error", err, + ) + return nil, nil, err + } + cb := h.getCircuitBreaker(model.ModelID) - // Skip models with open circuit breakers if !cb.AllowRequest() { h.logger.Info("circuit breaker open, skipping model", "model", model.ModelID, @@ -208,6 +213,14 @@ func (h *FallbackHandler) ExecuteWithFallback( }, body, nil } + if errCtx := ctx.Err(); errCtx != nil { + h.logger.Info("request context canceled after model attempt, stopping fallback", + "model", model.ModelID, + "error", errCtx, + ) + return nil, nil, errCtx + } + cb.RecordFailure() h.logger.Warn("model failed, trying fallback", "model", model.ModelID, diff --git a/internal/router/fallback_test.go b/internal/router/fallback_test.go new file mode 100644 index 0000000..ff6159d --- /dev/null +++ b/internal/router/fallback_test.go @@ -0,0 +1,282 @@ +package router + +import ( + "context" + "errors" + "log/slog" + "testing" + "time" + + "oc-go-cc/internal/config" +) + +func TestExecuteWithFallback_StopsOnCanceledContext(t *testing.T) { + logger := 
slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + return []byte("ok"), nil + }, + ) + + if callCount != 0 { + t.Errorf("executor called %d times, want 0 (canceled context must stop immediately)", callCount) + } + if err == nil { + t.Error("expected error for canceled context, got nil") + } + + states := handler.GetCircuitStates() + if len(states) > 0 { + t.Errorf("expected no circuit breakers created, got %d", len(states)) + } +} + +func TestExecuteWithFallback_StopsOnCanceledAfterFirstModel(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + cancel() + return nil, context.Canceled + } + return []byte("ok"), nil + }, + ) + + if callCount != 1 { + t.Errorf("executor called %d times, want 1 (should stop after parent context canceled)", callCount) + } + if err == nil { + t.Error("expected error for canceled context, got nil") + } + + states := handler.GetCircuitStates() + if _, ok := states["model-b"]; ok { + t.Error("model-b should not have a circuit breaker entry") + } +} + +func TestExecuteWithFallback_PerModelTimeoutFallback(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + parentCtx, 
parentCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer parentCancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + result, body, err := handler.ExecuteWithFallback(parentCtx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + return nil, context.DeadlineExceeded + } + return []byte("success-" + model.ModelID), nil + }, + ) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 2 { + t.Errorf("executor called %d times, want 2 (first timed out, second succeeds)", callCount) + } + if result.ModelID != "model-b" { + t.Errorf("result model = %s, want model-b", result.ModelID) + } + if string(body) != "success-model-b" { + t.Errorf("body = %s, want success-model-b", string(body)) + } +} + +func TestExecuteWithFallback_RealPerModelTimeout(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + parentCtx, parentCancel := context.WithCancel(context.Background()) + defer parentCancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + result, body, err := handler.ExecuteWithFallback(parentCtx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + attemptCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + <-attemptCtx.Done() + return nil, attemptCtx.Err() + } + return []byte("fallback-success"), nil + }, + ) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 2 { + t.Errorf("executor called %d times, want 2", callCount) + } + if result.ModelID != "model-b" { + t.Errorf("result model = %s, want model-b", result.ModelID) + } + if string(body) != "fallback-success" { + 
t.Errorf("body = %s, want fallback-success", string(body)) + } +} + +func TestExecuteWithFallback_CircuitBreakerDoesNotCountClientCancellation(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + cancel() + return nil, context.Canceled + }, + ) + + if callCount != 1 { + t.Errorf("executor called %d times, want 1", callCount) + } + if err == nil { + t.Error("expected error for canceled context") + } + + states := handler.GetCircuitStates() + if state, ok := states["model-a"]; ok { + if state == "open" { + t.Error("model-a circuit breaker should NOT be open for client cancellation") + } + } +} + +// TestExecuteWithFallback_RealModelFailurePenalizesCircuitBreaker verifies +// that a real upstream error (non-cancellation) DOES count as a model failure. 
+func TestExecuteWithFallback_RealModelFailurePenalizesCircuitBreaker(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx := context.Background() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + _, _, _ = handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, errors.New("upstream 500 internal server error") + }, + ) + + // model-a's circuit breaker should be open because of real failure + states := handler.GetCircuitStates() + state, ok := states["model-a"] + if !ok { + t.Fatal("model-a should have circuit breaker entry") + } + if state != "open" { + t.Errorf("model-a circuit breaker state = %s, want open", state) + } +} + +// TestExecuteWithFallback_ParentDeadlineExceededNotPenalized verifies +// context.DeadlineExceeded from parent context does not count as failure. +func TestExecuteWithFallback_ParentDeadlineExceededNotPenalized(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + time.Sleep(10 * time.Millisecond) // let parent timeout expire + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, nil + }, + ) + + if err == nil { + t.Error("expected error for deadline exceeded context") + } + + states := handler.GetCircuitStates() + if state, ok := states["model-a"]; ok && state == "open" { + t.Error("model-a circuit breaker should NOT be open for parent deadline exceeded") + } +} + +// TestExecuteWithFallback_AllModelsFailRecordsFailures verifies +// that multiple real model failures all record circuit breaker failures. 
+func TestExecuteWithFallback_AllModelsFailRecordsFailures(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 2, 30*time.Second) + + ctx := context.Background() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, errors.New("upstream error") + }, + ) + + if err == nil { + t.Error("expected error for all models failed") + } + + states := handler.GetCircuitStates() + if _, ok := states["model-a"]; !ok { + t.Error("model-a should have circuit breaker entry") + } + if _, ok := states["model-b"]; !ok { + t.Error("model-b should have circuit breaker entry") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 0c96909..464adad 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -69,13 +69,13 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { mux.HandleFunc("/v1/messages/count_tokens", healthHandler.HandleCountTokens) mux.HandleFunc("/health", healthHandler.HandleHealth) - // Create HTTP server. + // WriteTimeout stays disabled so long SSE streams are not cut off locally. addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) httpSrv := &http.Server{ Addr: addr, Handler: mux, ReadTimeout: 30 * time.Second, - WriteTimeout: 5 * time.Minute, + WriteTimeout: 0, IdleTimeout: 120 * time.Second, } diff --git a/internal/transformer/ctxio.go b/internal/transformer/ctxio.go new file mode 100644 index 0000000..24f4a05 --- /dev/null +++ b/internal/transformer/ctxio.go @@ -0,0 +1,76 @@ +// Package transformer includes ctxio: a context-bound reader wrapper. +package transformer + +import ( + "context" + "errors" + "io" +) + +// ErrStreamReadCanceled is returned by ctxReader.Read when its context is canceled +// or its deadline expires. 
+var ErrStreamReadCanceled = errors.New("stream read canceled by context") + +// ctxReader wraps an io.Reader and aborts Read when ctx is done. +// +// Transport header-phase limits (e.g. ResponseHeaderTimeout) end once headers +// arrive, and http.Client.Timeout, when set, is a single fixed limit covering the +// whole exchange, body read included. This wrapper ties body reads to the supplied +// ctx instead. It does not preempt a read that is already blocked in the +// transport — it returns ErrStreamReadCanceled on the next call. +type ctxReader struct { + ctx context.Context + r io.Reader +} + +// NewCtxReader wraps r so that its next Read returns ErrStreamReadCanceled when +// ctx is canceled or its deadline expires. +func NewCtxReader(ctx context.Context, r io.Reader) io.Reader { + if r == nil { + return nil + } + return &ctxReader{ctx: ctx, r: r} +} + +// NewCtxReadCloser is like NewCtxReader but preserves the io.Closer on the +// returned value. If rc is nil, returns nil. +func NewCtxReadCloser(ctx context.Context, rc io.ReadCloser) io.ReadCloser { + if rc == nil { + return nil + } + return &ctxReadCloser{ + ctxReader: ctxReader{ctx: ctx, r: rc}, + closer: rc, + } +} + +type ctxReadCloser struct { + ctxReader + closer io.Closer +} + +func (c *ctxReadCloser) Close() error { + return c.closer.Close() +} + +func (c *ctxReader) Read(p []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, ErrStreamReadCanceled + default: + } + + n, err := c.r.Read(p) + if n > 0 { + // Data still valid even if the deadline fired mid-read; the next + // Read will surface the cancellation. 
+ return n, err + } + + select { + case <-c.ctx.Done(): + return 0, ErrStreamReadCanceled + default: + return n, err + } +} diff --git a/internal/transformer/ctxio_test.go b/internal/transformer/ctxio_test.go new file mode 100644 index 0000000..b2b395b --- /dev/null +++ b/internal/transformer/ctxio_test.go @@ -0,0 +1,101 @@ +package transformer + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" + "time" +) + +func TestNewCtxReader_PassesThroughUncanceled(t *testing.T) { + ctx := context.Background() + in := strings.NewReader("hello world") + r := NewCtxReader(ctx, in) + + got, err := io.ReadAll(r.(io.Reader)) + if err != nil { + t.Fatalf("ReadAll err = %v, want nil", err) + } + if string(got) != "hello world" { + t.Fatalf("ReadAll = %q, want %q", got, "hello world") + } +} + +func TestNewCtxReader_AbortsOnCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already canceled before the first Read + + r := NewCtxReader(ctx, strings.NewReader("anything")) + buf := make([]byte, 16) + n, err := r.Read(buf) + if n != 0 { + t.Fatalf("Read returned n=%d, want 0", n) + } + if !errors.Is(err, ErrStreamReadCanceled) { + t.Fatalf("Read err = %v, want ErrStreamReadCanceled", err) + } +} + +func TestNewCtxReader_AbortsOnDeadlineExpiry(t *testing.T) { + // 1ns deadline fires almost immediately. 
+ ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(2 * time.Millisecond) // ensure deadline has passed + + r := NewCtxReader(ctx, strings.NewReader("anything")) + buf := make([]byte, 16) + _, err := r.Read(buf) + if !errors.Is(err, ErrStreamReadCanceled) { + t.Fatalf("Read err = %v, want ErrStreamReadCanceled", err) + } +} + +func TestNewCtxReader_NilReaderReturnsNil(t *testing.T) { + if got := NewCtxReader(context.Background(), nil); got != nil { + t.Fatalf("NewCtxReader(nil) = %v, want nil", got) + } +} + +func TestNewCtxReadCloser_ClosesUnderlying(t *testing.T) { + ctx := context.Background() + br := &bufferReadCloser{Reader: bytes.NewReader([]byte("ok"))} + rc := NewCtxReadCloser(ctx, br) + if rc == nil { + t.Fatal("NewCtxReadCloser returned nil") + } + + // Underlying body is exposed via the underlying *bytes.Reader; close + // should still flip the closer flag. + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll err = %v, want nil", err) + } + if string(got) != "ok" { + t.Fatalf("ReadAll = %q, want %q", got, "ok") + } + if err := rc.Close(); err != nil { + t.Fatalf("Close err = %v, want nil", err) + } + if !br.closed { + t.Fatal("underlying Close was not called") + } +} + +func TestNewCtxReadCloser_NilReturnsNil(t *testing.T) { + if got := NewCtxReadCloser(context.Background(), nil); got != nil { + t.Fatalf("NewCtxReadCloser(nil) = %v, want nil", got) + } +} + +type bufferReadCloser struct { + io.Reader + closed bool +} + +func (b *bufferReadCloser) Close() error { + b.closed = true + return nil +} diff --git a/internal/transformer/request.go b/internal/transformer/request.go index 404c552..cd634ab 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -569,21 +569,59 @@ func (t *RequestTransformer) transformAssistantMessage(blocks []types.ContentBlo return []types.ChatMessage{msg}, nil } -// transformTools converts Anthropic tools to OpenAI tools. 
+// transformTools converts Anthropic tools to OpenAI tool definitions, skipping blank names and normalizing input schemas. func (t *RequestTransformer) transformTools(tools []types.Tool) []types.ToolDef { var result []types.ToolDef for _, tool := range tools { - // InputSchema is already json.RawMessage, use it directly + name := strings.TrimSpace(tool.Name) + if name == "" { + continue + } + schema := tool.InputSchema - if len(schema) == 0 { - schema = []byte(`{"type":"object","properties":{}}`) + switch { + case len(schema) == 0, string(schema) == "null", string(schema) == "{}": + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + default: + var schemaObj map[string]interface{} + if err := json.Unmarshal(schema, &schemaObj); err != nil { + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + } else { + // Valid JSON " null " unmarshals to a nil map, which would panic + // on the field assignments below. + if schemaObj == nil { + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + } else { + // Force the type field to "object" — otherwise OpenAI rejects the + // tool. A schema like {"type":"string"} passes unmarshal but + // produces a 400 from the upstream OpenAI-compatible endpoint. + schemaType, _ := schemaObj["type"].(string) + if schemaType != "object" { + schemaObj["type"] = "object" + } + + // Force properties to be an object — wrong shapes like arrays + // or primitives also produce 400 errors upstream. 
+ if props, ok := schemaObj["properties"]; ok { + if _, valid := props.(map[string]interface{}); !valid { + schemaObj["properties"] = map[string]interface{}{} + } + } else { + schemaObj["properties"] = map[string]interface{}{} + } + + if fixed, err := json.Marshal(schemaObj); err == nil { + schema = fixed + } + } + } } result = append(result, types.ToolDef{ Type: "function", Function: types.FunctionDef{ - Name: tool.Name, + Name: name, Description: tool.Description, Parameters: json.RawMessage(schema), }, diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 7b68818..899b606 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -1523,3 +1523,208 @@ func TestTransformRequestStandardModelIgnoresThinkingAndEffort(t *testing.T) { t.Fatalf("expected Thinking to be nil for standard model, got %s", string(openaiReq.Thinking)) } } + +func TestTransformTools_SkipsEmptyName(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "", Description: "empty name", InputSchema: json.RawMessage(`{"type":"object"}`)}, + {Name: "Bash", Description: "valid tool", InputSchema: json.RawMessage(`{"type":"object"}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d (empty-name tool should be skipped)", got, want) + } + if got, want := result[0].Function.Name, "Bash"; got != want { + t.Fatalf("result[0].Name = %q, want %q", got, want) + } +} + +func TestTransformTools_SkipsWhitespaceOnlyName(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: " ", Description: "whitespace name", InputSchema: json.RawMessage(`{"type":"object"}`)}, + {Name: "Bash", Description: "valid tool", InputSchema: json.RawMessage(`{"type":"object"}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = 
%d, want %d (whitespace-name tool should be skipped)", got, want)
+	}
+}
+
+func TestTransformTools_FillsEmptySchema(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Bash", Description: "no schema", InputSchema: nil},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"type":"object"`) {
+		t.Fatalf("parameters missing type=object: %s", params)
+	}
+	if !strings.Contains(params, `"additionalProperties":false`) {
+		t.Fatalf("parameters missing additionalProperties=false: %s", params)
+	}
+}
+
+func TestTransformTools_FillsNullSchema(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Bash", Description: "null schema", InputSchema: json.RawMessage(`null`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"type":"object"`) {
+		t.Fatalf("null schema should become type=object: %s", params)
+	}
+}
+
+func TestTransformTools_FillsEmptyObjectSchema(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Bash", Description: "empty object schema", InputSchema: json.RawMessage(`{}`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"type":"object"`) {
+		t.Fatalf("empty object schema should get type=object: %s", params)
+	}
+	if !strings.Contains(params, `"additionalProperties":false`) {
+		t.Fatalf("empty object schema should get additionalProperties=false: %s", params)
+	}
+}
+
+func TestTransformTools_FillsMissingType(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Search", Description: "schema without type", InputSchema: json.RawMessage(`{"properties":{"query":{"type":"string"}}}`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"type":"object"`) {
+		t.Fatalf("schema missing type should get type=object: %s", params)
+	}
+	if !strings.Contains(params, `"query"`) {
+		t.Fatalf("existing properties should be preserved: %s", params)
+	}
+}
+
+func TestTransformTools_FillsMissingProperties(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "NoOp", Description: "schema without properties", InputSchema: json.RawMessage(`{"type":"object"}`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"properties":{}`) { // must be an empty object, not merely a present key
+		t.Fatalf("schema missing properties should get properties={}: %s", params)
+	}
+}
+
+func TestTransformTools_RecoversFromInvalidJSON(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Bash", Description: "malformed JSON", InputSchema: json.RawMessage(`{invalid`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d (malformed schema should be replaced, not skipped)", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"type":"object"`) {
+		t.Fatalf("malformed schema should be replaced with valid schema: %s", params)
+	}
+}
+
+func TestTransformTools_PreservesValidSchema(t *testing.T) {
+	transformer := NewRequestTransformer()
+	originalSchema := json.RawMessage(`{"type":"object","properties":{"cmd":{"type":"string","description":"The command"}},"required":["cmd"]}`)
+	tools := []types.Tool{
+		{Name: "Bash", Description: "run a command", InputSchema: originalSchema},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"cmd"`) {
+		t.Fatalf("valid schema properties should be preserved: %s", params)
+	}
+	if !strings.Contains(params, `"required"`) {
+		t.Fatalf("valid schema required should be preserved: %s", params)
+	}
+	if !strings.Contains(params, `"type":"string"`) {
+		t.Fatalf("valid schema nested type should be preserved: %s", params)
+	}
+}
+
+func TestTransformTools_PreservesAdditionalPropertiesWhenSet(t *testing.T) {
+	transformer := NewRequestTransformer()
+	tools := []types.Tool{
+		{Name: "Flexible", Description: "allows extra props", InputSchema: json.RawMessage(`{"type":"object","properties":{"a":{"type":"string"}},"additionalProperties":true}`)},
+	}
+
+	result := transformer.transformTools(tools)
+	if got, want := len(result), 1; got != want {
+		t.Fatalf("len(result) = %d, want %d", got, want)
+	}
+
+	params := string(result[0].Function.Parameters)
+	if !strings.Contains(params, `"additionalProperties":true`) {
+		t.Fatalf("existing additionalProperties should be preserved: %s", params)
+	}
+}
+
+// TestTransformTools_HandlesWhitespaceNullSchema guards against a panic on
+// valid JSON that unmarshals to a nil map (e.g. " null " with decorative
+// whitespace). The fix is to fall back to the default schema when schemaObj
+// is nil after Unmarshal.
+func TestTransformTools_HandlesWhitespaceNullSchema(t *testing.T) {
+	// A JSON null padded with whitespace must come back as the default object schema.
+	padded := types.Tool{Name: "Bash", Description: "decorative null", InputSchema: json.RawMessage(` null `)}
+	out := NewRequestTransformer().transformTools([]types.Tool{padded})
+
+	if n := len(out); n != 1 {
+		t.Fatalf("len(result) = %d, want %d (whitespace-null schema should fall back, not panic)", n, 1)
+	}
+
+	schema := string(out[0].Function.Parameters)
+	hasObjectType := strings.Contains(schema, `"type":"object"`)
+	if !hasObjectType {
+		t.Fatalf("whitespace-null schema should fall back to default object schema: %s", schema)
+	}
+	hasEmptyProps := strings.Contains(schema, `"properties":{}`)
+	if !hasEmptyProps {
+		t.Fatalf("whitespace-null schema should fall back to default properties: %s", schema)
+	}
+}
diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go
index 646ec44..ed20b68 100644
--- a/internal/transformer/stream.go
+++ b/internal/transformer/stream.go
@@ -32,7 +32,8 @@ func NewStreamHandler() *StreamHandler {
 
 // ProxyStream takes an OpenAI streaming response and writes Anthropic-format SSE to the writer.
 // It reads OpenAI ChatCompletionChunk SSE events and transforms them into Anthropic MessageEvent SSE events.
-// The clientCtx is used to detect client disconnection and abort early.
+// The streamCtx is the per-model attempt context (carries streaming_timeout_ms); the caller
+// should wrap openaiResp with NewCtxReadCloser so the body read also respects the deadline.
 //
 // CRITICAL: This function reads directly from resp.Body without buffering to minimize latency.
// Per deep research: "Don't use bufio.Scanner or bufio.Reader on the response body - it adds buffering" @@ -40,7 +41,7 @@ func (h *StreamHandler) ProxyStream( w http.ResponseWriter, openaiResp io.ReadCloser, originalModel string, - clientCtx context.Context, + streamCtx context.Context, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -80,9 +81,11 @@ func (h *StreamHandler) ProxyStream( readBuf := make([]byte, 4096) for { - // Check if client disconnected + // Stop once streamCtx is done: client disconnect or the per-model streaming_timeout_ms + // deadline (the caller is expected to wrap the body with NewCtxReadCloser as well). + // NOTE(review): a deadline expiry is also returned as ErrClientDisconnected; confirm callers expect that. select { - case <-clientCtx.Done(): + case <-streamCtx.Done(): return ErrClientDisconnected default: } @@ -406,11 +409,7 @@ func (h *StreamHandler) processSSELine( flusher.Flush() } - // Handle tool call deltas. - // OpenAI streams tool calls incrementally: the first chunk for a given - // tool call carries id + name (+ possibly empty arguments), subsequent - // chunks carry only incremental arguments. We must create exactly one - // content_block_start per tool call, then stream deltas for it. + // Handle streamed tool calls: a named chunk for an index not yet started opens a new block; later chunks stream its arguments. if len(choice.Delta.ToolCalls) > 0 { for _, tc := range choice.Delta.ToolCalls { oi := tc.Index // OpenAI tool_calls array index @@ -418,9 +417,7 @@ func (h *StreamHandler) processSSELine( blockIdx, exists := startedToolCalls[oi] if !exists { if tc.Function.Name == "" { - // Ghost chunk: this index was closed and recycled, but - // has no name/id. Ignore — the real tool call was - // already fully processed. + // Ghost chunk: an index not started yet that carries no function name; there is nothing to open, so skip it. continue } if *contentStarted || *reasoningStarted { @@ -433,11 +430,11 @@ func (h *StreamHandler) processSSELine( } *contentStarted = false *reasoningStarted = false + *contentIndex++ } - // First time seeing this logical tool call — start a new block. 
+ blockIdx = *contentIndex *contentIndex++ *toolUseCount++ - blockIdx = *contentIndex startedToolCalls[oi] = blockIdx toolID := tc.ID @@ -459,7 +456,6 @@ func (h *StreamHandler) processSSELine( } } - // Send argument delta (if any) — whether new or continuation. if tc.Function.Arguments != "" { delta := types.Delta{ Type: "input_json_delta", @@ -594,11 +590,13 @@ func generateID() string { } // ProxyResponsesStream takes an OpenAI Responses streaming response and writes Anthropic-format SSE. +// streamCtx is the per-model attempt context (carries streaming_timeout_ms); the caller should +// wrap responsesResp with NewCtxReadCloser so the body read also respects the deadline. func (h *StreamHandler) ProxyResponsesStream( w http.ResponseWriter, responsesResp io.ReadCloser, originalModel string, - clientCtx context.Context, + streamCtx context.Context, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -629,7 +627,7 @@ func (h *StreamHandler) ProxyResponsesStream( for { select { - case <-clientCtx.Done(): + case <-streamCtx.Done(): return ErrClientDisconnected default: } @@ -772,11 +770,13 @@ func (h *StreamHandler) processResponsesSSELine( } // ProxyGeminiStream takes a Gemini streaming response and writes Anthropic-format SSE. +// streamCtx is the per-model attempt context (carries streaming_timeout_ms); the caller should +// wrap geminiResp with NewCtxReadCloser so the body read also respects the deadline. 
func (h *StreamHandler) ProxyGeminiStream( w http.ResponseWriter, geminiResp io.ReadCloser, originalModel string, - clientCtx context.Context, + streamCtx context.Context, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -807,7 +807,7 @@ func (h *StreamHandler) ProxyGeminiStream( for { select { - case <-clientCtx.Done(): + case <-streamCtx.Done(): return ErrClientDisconnected default: } diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index fc4458d..7532ef0 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -593,9 +593,9 @@ func TestProxyStream_SingleToolCall(t *testing.T) { events := parseSSEEvents(t, w.buf.String()) - // Expected: message_start, tool_start(idx=1), 2x input_json_delta (3rd arg arrives + // Expected: message_start, tool_start(idx=0), 2x input_json_delta (3rd arg arrives // with finish_reason in same chunk, fast path returns before processing delta), - // tool_stop(idx=1), message_delta, message_stop = 7 + // tool_stop(idx=0), message_delta, message_stop = 7 if len(events) != 7 { t.Fatalf("expected 7 events, got %d: %+v", len(events), events) } @@ -941,9 +941,9 @@ func TestProxyStream_ToolCallAndFinishReasonInSameChunk(t *testing.T) { // Expected events: // 0: message_start - // 1: content_block_start (index 1, type tool_use) - // 2: content_block_delta (index 1, partial_json "{\"loc\":\"Beijing\"}") - // 3: content_block_stop (index 1) + // 1: content_block_start (index 0, type tool_use) + // 2: content_block_delta (index 0, partial_json "{\"loc\":\"Beijing\"}") + // 3: content_block_stop (index 0) // 4: message_delta (stop_reason: tool_use) // 5: message_stop if len(events) != 6 { @@ -956,8 +956,8 @@ func TestProxyStream_ToolCallAndFinishReasonInSameChunk(t *testing.T) { if events[2].Type != "content_block_delta" || events[2].Delta == nil || events[2].Delta.PartialJSON != `{"loc":"Beijing"}` { t.Errorf("event[2] = %+v, want content_block_delta", events[2]) } - if 
events[3].Type != "content_block_stop" || events[3].Index == nil || *events[3].Index != 1 {
-		t.Errorf("event[3] = %+v, want content_block_stop(1)", events[3])
+	if events[3].Type != "content_block_stop" || events[3].Index == nil || *events[3].Index != 0 {
+		t.Errorf("event[3] = %+v, want content_block_stop(0)", events[3])
 	}
 	if events[4].Type != "message_delta" || events[4].Delta == nil || events[4].Delta.StopReason != "tool_use" {
 		t.Errorf("event[4] = %+v, want message_delta(tool_use)", events[4])
@@ -1067,6 +1067,83 @@ func TestProxyStream_EOFFallbackStopReasonToolUse(t *testing.T) {
 	}
 }
 
+// TestProxyStream_ToolUseFirstContentBlock verifies that when the first
+// assistant output is a direct tool call (no preceding text or reasoning),
+// the tool_use block and its argument delta are emitted at index 0 per the
+// Anthropic SSE spec.
+func TestProxyStream_ToolUseFirstContentBlock(t *testing.T) {
+	handler := NewStreamHandler()
+	w := newMockResponseWriter()
+	// NOTE(review): OpenAI documents "tool_calls" as the finish_reason for
+	// tool calls; this fixture sends "tool_use" — confirm that is intended.
+	body := sseLines(
+		`{"choices":[{"delta":{"tool_calls":[{"index":0,"id":"toolu_abc","type":"function","function":{"name":"read_file","arguments":""}}]}}]}`,
+		`{"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"/tmp/x\"}"}}]}}]}`,
+		`{"choices":[{"delta":{},"finish_reason":"tool_use"}]}`,
+	)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil {
+		t.Fatalf("ProxyStream error: %v", err)
+	}
+
+	events := parseSSEEvents(t, w.buf.String())
+
+	// 0: message_start
+	// 1: content_block_start (index 0, type tool_use) — first content block
+	// 2: content_block_delta (index 0)
+	// 3: content_block_stop (index 0)
+	// 4: message_delta
+	// 5: message_stop
+	if len(events) != 6 {
+		t.Fatalf("expected 6 events, got %d: %+v", len(events), events)
+	}
+
+	if events[1].Type != "content_block_start" {
+		t.Fatalf("event[1].Type = %q, want content_block_start", events[1].Type)
+	}
+	if events[1].ContentBlock == nil || events[1].ContentBlock.Type != "tool_use" {
+		t.Fatalf("event[1].ContentBlock = %+v, want tool_use", events[1].ContentBlock)
+	}
+	// Dereference Index before printing: it is a pointer, so %v on the field
+	// itself would print an address instead of the block index.
+	if idx := events[1].Index; idx == nil {
+		t.Fatal("tool_use content_block_start index is missing, want 0")
+	} else if *idx != 0 {
+		t.Fatalf("tool_use content_block_start index = %d, want 0", *idx)
+	}
+
+	// The argument delta must target the same block the start event opened.
+	if events[2].Type != "content_block_delta" {
+		t.Fatalf("event[2].Type = %q, want content_block_delta", events[2].Type)
+	}
+	if idx := events[2].Index; idx == nil {
+		t.Fatal("tool_use content_block_delta index is missing, want 0")
+	} else if *idx != 0 {
+		t.Fatalf("tool_use content_block_delta index = %d, want 0", *idx)
+	}
+	if events[2].Delta == nil || events[2].Delta.PartialJSON != `{"path":"/tmp/x"}` {
+		t.Fatalf("event[2].Delta = %+v, want partial_json %s", events[2].Delta, `{"path":"/tmp/x"}`)
+	}
+
+	if events[3].Type != "content_block_stop" {
+		t.Fatalf("event[3].Type = %q, want content_block_stop", events[3].Type)
+	}
+	if idx := events[3].Index; idx == nil {
+		t.Fatal("tool_use content_block_stop index is missing, want 0")
+	} else if *idx != 0 {
+		t.Fatalf("tool_use content_block_stop index = %d, want 0", *idx)
+	}
+	if events[4].Type != "message_delta" || events[4].Delta == nil || events[4].Delta.StopReason != "tool_use" {
+		t.Errorf("event[4] = %+v, want message_delta(tool_use)", events[4])
+	}
+	if events[5].Type != "message_stop" {
+		t.Errorf("event[5].Type = %q, want message_stop", events[5].Type)
+	}
+}
+
 // helpers
 
 func mustJSON(t *testing.T, v any) string {