diff --git a/CLAUDE.md b/CLAUDE.md index b3e62c6..59a1130 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,14 +20,16 @@ Run a single test: `go test ./internal/router/ -v` **Purpose:** oc-go-cc is a proxy server that sits between Claude Code and OpenCode Go. It intercepts Anthropic API requests, transforms them to OpenAI Chat Completions format, forwards them to OpenCode Go, and transforms responses back to Anthropic SSE. -**Model routing is config-driven, not code-driven.** Models are defined in `~/.config/oc-go-cc/config.json` — adding a new model does not require code changes (except for `IsAnthropicModel()` if the new model uses the Anthropic endpoint). The router in `internal/router/` selects models by matching request content against scenario patterns defined in `scenarios.go`. +**Model routing is config-driven, not code-driven.** All models are defined in `~/.config/oc-go-cc/config.json` — adding a new model requires no code changes. Go provider models are transformed to OpenAI Chat Completions format automatically. Zen models use endpoint classification via `ClassifyEndpoint()`. The router in `internal/router/` selects models by matching request content against scenario patterns defined in `scenarios.go`. + +If a model's upstream doesn't support Anthropic tool format (`type: "custom"` server-tool shorthands), set `"anthropic_tools_disabled": true` in the model config to force it through the Chat Completions transform path instead of the raw Anthropic endpoint. **Two API endpoints:** - OpenAI endpoint (`/v1/chat/completions`) — used by most models (GLM, Kimi, MiMo, Qwen) - Anthropic endpoint (`/v1/messages`) — used only by MiniMax models -`internal/client/opencode.go` routes by model ID via `IsAnthropicModel()`. +`internal/client/opencode.go` routes Go provider models to Chat Completions; Zen models are classified by `ClassifyEndpoint()`. If a model's upstream doesn't support Anthropic tool format, set `anthropic_tools_disabled: true` in config. **Scenario detection priority** (`internal/router/scenarios.go`): @@ -41,6 +43,8 @@ For streaming, the router downgrades to fast models (Qwen3.6 Plus) for better TT **Polymorphic field handling:** Anthropic's `system` and `content` fields accept both strings and arrays. `pkg/types/` uses `json.RawMessage` with accessor methods (`SystemText()`, `ContentBlocks()`) to handle both formats. +**Long-running stream policy:** The proxy never kills a stream that is actively producing bytes. The server-level `WriteTimeout` is set to 0; instead each upstream read uses a per-`Read` deadline via `http.ResponseController.SetReadDeadline` that is renewed on every successful byte. If the gap between bytes exceeds `OpenCodeGo.stream_timeout_ms` (or `OpenCodeZen.stream_timeout_ms`), the connection is treated as stuck and the request is routed to the next fallback model. Defaults to `timeout_ms` when unset. Client disconnects during a stream are logged at `Debug` level — this is normal during Claude Code tool execution and is not a failure signal. + ## Key Files - `cmd/oc-go-cc/main.go` — CLI entry point (cobra). Default config template is generated here. diff --git a/cmd/oc-go-cc/main.go b/cmd/oc-go-cc/main.go index 059287d..bde8070 100644 --- a/cmd/oc-go-cc/main.go +++ b/cmd/oc-go-cc/main.go @@ -538,7 +538,7 @@ func getDefaultConfig() string { "port": 3456, "hot_reload": false, "enable_streaming_scenario_routing": false, - "respect_requested_model": false, + "respect_requested_model": true, "models": { "background": { "provider": "opencode-go", diff --git a/internal/client/opencode.go b/internal/client/opencode.go index 756f1d0..45d3571 100644 --- a/internal/client/opencode.go +++ b/internal/client/opencode.go @@ -42,12 +42,6 @@ func (c *OpenCodeClient) nextAPIKey(keys []string) string { // NewOpenCodeClient creates a new OpenCode client. func NewOpenCodeClient(atomic *config.AtomicConfig) *OpenCodeClient { - cfg := atomic.Get() - timeout := time.Duration(cfg.OpenCodeGo.TimeoutMs) * time.Millisecond - if timeout == 0 { - timeout = 5 * time.Minute - } - transport := &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 20, @@ -60,21 +54,51 @@ func NewOpenCodeClient(atomic *config.AtomicConfig) *OpenCodeClient { return &OpenCodeClient{ atomic: atomic, httpClient: &http.Client{ - Timeout: timeout, Transport: transport, }, } } +// StreamIdleTimeout returns the maximum gap between bytes on an active stream +// for a model. The stream lives as long as data keeps flowing; only an idle +// period longer than this value is treated as a stuck connection and aborted. +// Go provider models use OpenCodeGo.StreamTimeoutMs; Zen models use +// OpenCodeZen.StreamTimeoutMs. Falls back to 5 minutes if the config is +// unavailable or the value is zero. +func (c *OpenCodeClient) StreamIdleTimeout(modelConfig config.ModelConfig) time.Duration { + const fallback = 5 * time.Minute + if c == nil || c.atomic == nil { + return fallback + } + cfg := c.atomic.Get() + var ms int + if IsZen(modelConfig) { + ms = cfg.OpenCodeZen.StreamTimeoutMs + } else { + ms = cfg.OpenCodeGo.StreamTimeoutMs + } + if ms <= 0 { + ms = cfg.OpenCodeGo.TimeoutMs + } + if ms <= 0 { + return fallback + } + return time.Duration(ms) * time.Millisecond +} + // IsAnthropicModel returns true if the model requires the Anthropic endpoint. -// This includes both Go models (minimax, all qwen) and Zen models (claude, qwen3.7-max). +// Most Go provider models use the Chat Completions transform path for broader +// compatibility (tool format, message roles, etc.). Exceptions are models whose +// upstream backends don't support the OpenAI Chat Completions format and only +// accept Anthropic Messages format. +// +// Only Zen models use the raw Anthropic endpoint via ClassifyEndpoint. func IsAnthropicModel(modelID string) bool { switch modelID { - case "minimax-m2.5", "minimax-m2.7", "minimax-m3", - "qwen3.5-plus", "qwen3.6-plus", "qwen3.7-plus", "qwen3.7-max": + case "qwen3.7-max": // OpenCode Go backend doesn't support oa-compat for this model return true default: - return isZenAnthropicModel(modelID) + return false } } diff --git a/internal/client/opencode_test.go b/internal/client/opencode_test.go index 003072f..32d9b1a 100644 --- a/internal/client/opencode_test.go +++ b/internal/client/opencode_test.go @@ -2,6 +2,7 @@ package client import ( "testing" + "time" "oc-go-cc/internal/config" ) @@ -13,19 +14,19 @@ func TestIsAnthropicModelOnlyRoutesNativeAnthropicModels(t *testing.T) { want bool }{ { - name: "minimax m2.5 uses anthropic endpoint", + name: "minimax m2.5 uses openai endpoint on Go provider", modelID: "minimax-m2.5", - want: true, + want: false, }, { - name: "minimax m2.7 uses anthropic endpoint", + name: "minimax m2.7 uses openai endpoint on Go provider", modelID: "minimax-m2.7", - want: true, + want: false, }, { - name: "minimax m3 uses anthropic endpoint", + name: "minimax m3 uses openai endpoint on Go provider", modelID: "minimax-m3", - want: true, + want: false, }, { name: "deepseek pro uses openai endpoint", @@ -63,44 +64,44 @@ func TestIsAnthropicModelOnlyRoutesNativeAnthropicModels(t *testing.T) { want: false, }, { - name: "qwen3.5-plus uses anthropic endpoint", + name: "qwen3.5-plus uses openai endpoint on Go provider", modelID: "qwen3.5-plus", - want: true, + want: false, }, { - name: "qwen3.6-plus uses anthropic endpoint", + name: "qwen3.6-plus uses openai endpoint on Go provider", modelID: "qwen3.6-plus", - want: true, + want: false, }, { - name: "qwen3.7-plus uses anthropic endpoint", + name: "qwen3.7-plus uses openai endpoint on Go provider", modelID: "qwen3.7-plus", - want: true, + want: false, }, { - name: "qwen3.7-max uses anthropic endpoint", + name: "qwen3.7-max uses anthropic endpoint (no oa-compat support)", modelID: "qwen3.7-max", want: true, }, { - name: "claude-sonnet-4-5 uses anthropic endpoint", + name: "claude models use openai endpoint on Go provider", modelID: "claude-sonnet-4-5", - want: true, + want: false, }, { - name: "claude-opus-4-7 uses anthropic endpoint", + name: "claude-opus-4-7 uses openai endpoint on Go provider", modelID: "claude-opus-4-7", - want: true, + want: false, }, { - name: "claude-haiku-4-5 uses anthropic endpoint", + name: "claude-haiku-4-5 uses openai endpoint on Go provider", modelID: "claude-haiku-4-5", - want: true, + want: false, }, { - name: "claude-3-5-haiku uses anthropic endpoint", + name: "claude-3-5-haiku uses openai endpoint on Go provider", modelID: "claude-3-5-haiku", - want: true, + want: false, }, } @@ -461,3 +462,52 @@ func TestNextAPIKey_ConcurrentSafety(t *testing.T) { } } } + +func TestStreamIdleTimeout(t *testing.T) { + tests := []struct { + name string + goMs int + zenMs int + provider string + wantDur time.Duration + }{ + { + name: "Go provider uses OpenCodeGo.StreamTimeoutMs", + goMs: 120000, // 2 min + provider: "opencode-go", + wantDur: 120 * time.Second, + }, + { + name: "Zen provider uses OpenCodeZen.StreamTimeoutMs", + goMs: 100000, + zenMs: 600000, // 10 min + provider: "opencode-zen", + wantDur: 10 * time.Minute, + }, + { + name: "falls back to OpenCodeGo.TimeoutMs when StreamTimeoutMs is zero", + goMs: 300000, // 5 min + provider: "opencode-go", + wantDur: 5 * time.Minute, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{TimeoutMs: tt.goMs, StreamTimeoutMs: tt.goMs}, + OpenCodeZen: config.OpenCodeZenConfig{TimeoutMs: tt.zenMs, StreamTimeoutMs: tt.zenMs}, + } + // Fallback test: zero out StreamTimeoutMs for that provider. + if tt.name == "falls back to OpenCodeGo.TimeoutMs when StreamTimeoutMs is zero" { + cfg.OpenCodeGo.StreamTimeoutMs = 0 + } + atomic := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + c := &OpenCodeClient{atomic: atomic} + mc := config.ModelConfig{Provider: tt.provider, ModelID: "test-model"} + got := c.StreamIdleTimeout(mc) + if got != tt.wantDur { + t.Errorf("StreamIdleTimeout() = %v, want %v", got, tt.wantDur) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index c5ae557..655aa0d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,7 +11,7 @@ type Config struct { Port int `json:"port"` HotReload bool `json:"hot_reload"` EnableStreamingScenarioRouting bool `json:"enable_streaming_scenario_routing"` - RespectRequestedModel bool `json:"respect_requested_model"` + RespectRequestedModel *bool `json:"respect_requested_model,omitempty"` Models map[string]ModelConfig `json:"models"` Fallbacks map[string][]ModelConfig `json:"fallbacks"` ModelOverrides map[string]ModelConfig `json:"model_overrides"` @@ -22,14 +22,15 @@ type Config struct { // ModelConfig defines routing rules for a specific model. type ModelConfig struct { - Provider string `json:"provider"` - ModelID string `json:"model_id"` - Temperature float64 `json:"temperature"` - MaxTokens int `json:"max_tokens"` - ContextThreshold int `json:"context_threshold"` - ReasoningEffort string `json:"reasoning_effort"` - Thinking json.RawMessage `json:"thinking,omitempty"` - Vision bool `json:"vision"` + Provider string `json:"provider"` + ModelID string `json:"model_id"` + Temperature float64 `json:"temperature"` + MaxTokens int `json:"max_tokens"` + ContextThreshold int `json:"context_threshold"` + ReasoningEffort string `json:"reasoning_effort"` + Thinking json.RawMessage `json:"thinking,omitempty"` + Vision bool `json:"vision"` + AnthropicToolsDisabled bool `json:"anthropic_tools_disabled"` } // OpenCodeGoConfig holds the upstream OpenCode Go API settings. @@ -37,6 +38,7 @@ type OpenCodeGoConfig struct { BaseURL string `json:"base_url"` AnthropicBaseURL string `json:"anthropic_base_url"` TimeoutMs int `json:"timeout_ms"` + StreamTimeoutMs int `json:"stream_timeout_ms"` } // OpenCodeZenConfig holds the upstream OpenCode Zen API settings. @@ -46,6 +48,7 @@ type OpenCodeZenConfig struct { ResponsesBaseURL string `json:"responses_base_url"` GeminiBaseURL string `json:"gemini_base_url"` TimeoutMs int `json:"timeout_ms"` + StreamTimeoutMs int `json:"stream_timeout_ms"` } // LoggingConfig controls application logging behavior. diff --git a/internal/config/loader.go b/internal/config/loader.go index 594b393..15eb144 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -146,6 +146,9 @@ func applyDefaults(cfg *Config) { if cfg.OpenCodeGo.TimeoutMs == 0 { cfg.OpenCodeGo.TimeoutMs = defaultTimeoutMs } + if cfg.OpenCodeGo.StreamTimeoutMs == 0 { + cfg.OpenCodeGo.StreamTimeoutMs = cfg.OpenCodeGo.TimeoutMs + } if cfg.OpenCodeZen.BaseURL == "" { cfg.OpenCodeZen.BaseURL = defaultZenBaseURL } @@ -161,6 +164,9 @@ func applyDefaults(cfg *Config) { if cfg.OpenCodeZen.TimeoutMs == 0 { cfg.OpenCodeZen.TimeoutMs = defaultTimeoutMs } + if cfg.OpenCodeZen.StreamTimeoutMs == 0 { + cfg.OpenCodeZen.StreamTimeoutMs = cfg.OpenCodeZen.TimeoutMs + } if cfg.Logging.Level == "" { cfg.Logging.Level = defaultLogLevel } @@ -185,6 +191,33 @@ func validate(cfg *Config) error { if err := validateModelOverrides(cfg.ModelOverrides); err != nil { return err } + + if err := validateAnthropicToolsDisabled(cfg); err != nil { + return err + } + + return nil +} + +// validateAnthropicToolsDisabled checks that models with anthropic_tools_disabled +// set are configured correctly. This field only applies to models that route to +// the Anthropic endpoint; enabling it on an OpenAI Chat Completions model has no +// effect and likely indicates a misconfiguration. +func validateAnthropicToolsDisabled(cfg *Config) error { + for key, mc := range cfg.Models { + if mc.AnthropicToolsDisabled { + // Models in cfg.Models are selectable by scenario routing. The flag + // is only meaningful on models that go through the Anthropic endpoint. + // Log a warning since the config system can't resolve the endpoint + // without the client package. + fmt.Fprintf(os.Stderr, "WARNING: config: models[%q] has anthropic_tools_disabled=true — this is only effective on models routing to the Anthropic endpoint\n", key) + } + } + for key, mc := range cfg.ModelOverrides { + if mc.AnthropicToolsDisabled { + fmt.Fprintf(os.Stderr, "WARNING: config: model_overrides[%q] has anthropic_tools_disabled=true — this is only effective on models routing to the Anthropic endpoint\n", key) + } + } return nil } diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 25a4b5d..c75eb2d 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -276,6 +276,10 @@ func TestDefaults(t *testing.T) { if cfg.OpenCodeGo.TimeoutMs != defaultTimeoutMs { t.Errorf("OpenCodeGo.TimeoutMs = %d, want %d", cfg.OpenCodeGo.TimeoutMs, defaultTimeoutMs) } + if cfg.OpenCodeGo.StreamTimeoutMs != defaultTimeoutMs { + t.Errorf("OpenCodeGo.StreamTimeoutMs = %d, want %d (should default to TimeoutMs when unset)", + cfg.OpenCodeGo.StreamTimeoutMs, defaultTimeoutMs) + } if cfg.OpenCodeZen.BaseURL != defaultZenBaseURL { t.Errorf("OpenCodeZen.BaseURL = %q, want %q", cfg.OpenCodeZen.BaseURL, defaultZenBaseURL) } @@ -291,6 +295,10 @@ func TestDefaults(t *testing.T) { if cfg.OpenCodeZen.TimeoutMs != defaultTimeoutMs { t.Errorf("OpenCodeZen.TimeoutMs = %d, want %d", cfg.OpenCodeZen.TimeoutMs, defaultTimeoutMs) } + if cfg.OpenCodeZen.StreamTimeoutMs != defaultTimeoutMs { + t.Errorf("OpenCodeZen.StreamTimeoutMs = %d, want %d (should default to TimeoutMs when unset)", + cfg.OpenCodeZen.StreamTimeoutMs, defaultTimeoutMs) + } if cfg.Logging.Level != defaultLogLevel { t.Errorf("LogLevel = %q, want %q", cfg.Logging.Level, defaultLogLevel) } diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 207227a..127f5c6 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -4,11 +4,12 @@ package handlers import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" - "strings" + "sync" "sync/atomic" "time" @@ -39,12 +40,19 @@ type MessagesHandler struct { } // responseWriter wraps http.ResponseWriter to track if headers were written. +// It is safe for concurrent use: Write, WriteHeader, and Flush are serialized +// via an internal mutex so that concurrent goroutines (e.g. heartbeat and +// stream proxy) don't interleave SSE frames. type responseWriter struct { http.ResponseWriter - wroteHeader bool + mu sync.Mutex + wroteHeader bool + ssePayloadWritten bool } func (w *responseWriter) WriteHeader(code int) { + w.mu.Lock() + defer w.mu.Unlock() if !w.wroteHeader { w.wroteHeader = true w.ResponseWriter.WriteHeader(code) @@ -52,14 +60,45 @@ func (w *responseWriter) WriteHeader(code int) { } func (w *responseWriter) Write(b []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() if !w.wroteHeader { - w.WriteHeader(http.StatusOK) + w.wroteHeader = true + w.ResponseWriter.WriteHeader(http.StatusOK) + } + if len(b) > 0 { + w.ssePayloadWritten = true } return w.ResponseWriter.Write(b) } +// headerWritten returns true if headers have been written to the response. +// Safe for concurrent use. +func (w *responseWriter) headerWritten() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.wroteHeader +} + // Flush implements http.Flusher for SSE streaming support. +// The mutex is held across the flush call to ensure Write, WriteHeader, and +// Flush remain serialized. Without this, a concurrent Flush and Write on the +// underlying http.ResponseWriter's *bufio.Writer would be a data race. func (w *responseWriter) Flush() { + w.mu.Lock() + defer w.mu.Unlock() + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// WriteKeepalive writes a keepalive comment frame (":keepalive\n\n") to the +// response. Unlike Write, it does NOT set ssePayloadWritten — keepalives are +// not real SSE events and should not block fallback logic on idle timeout. +func (w *responseWriter) WriteKeepalive() { + w.mu.Lock() + defer w.mu.Unlock() + _, _ = fmt.Fprintf(w.ResponseWriter, ":keepalive\n\n") if f, ok := w.ResponseWriter.(http.Flusher); ok { f.Flush() } @@ -83,7 +122,7 @@ func NewMessagesHandler( tokenCounter: tokenCounter, logger: slog.Default(), rateLimiter: middleware.NewRateLimiter(100, time.Minute), - requestDedup: middleware.NewRequestDeduplicator(500 * time.Millisecond), + requestDedup: nil, requestIDGen: middleware.NewRequestIDGenerator(), metrics: metrics, } @@ -112,18 +151,28 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) return } - // Read the raw request body for debug logging + // Read the raw request body with a size limit to prevent memory exhaustion. + const maxBodySize = 104857600 // 100 MB + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var rawBody json.RawMessage if err := json.NewDecoder(r.Body).Decode(&rawBody); err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + h.sendError(w, http.StatusRequestEntityTooLarge, "request body too large", err) + return + } h.sendError(w, http.StatusBadRequest, "invalid request body", err) return } - // Deduplicate - skip duplicate requests - if _, ok := h.requestDedup.TryAcquire(rawBody); !ok { - h.metrics.RecordDeduplicated() - h.logger.Info("duplicate request skipped", "request_id", requestID) - return + // Deduplicate - skip duplicate requests. Skip when the deduplicator is + // not configured (nil requestDedup) — it is an optional component. + if h.requestDedup != nil { + if _, ok := h.requestDedup.TryAcquire(rawBody); !ok { + h.metrics.RecordDeduplicated() + h.logger.Info("duplicate request skipped", "request_id", requestID) + return + } } // Parse into Anthropic request @@ -246,7 +295,7 @@ func (h *MessagesHandler) routeOnce( ) (router.RouteResult, error) { if isStreaming && !h.modelRouter.IsStreamingScenarioRoutingEnabled() { // Streaming: use faster models to minimize TTFT (time-to-first-token) - return h.modelRouter.RouteForStreaming(routerMessages, tokenCount, requestedModel), nil + return h.modelRouter.RouteForStreaming(routerMessages, tokenCount, requestedModel) } return h.modelRouter.Route(routerMessages, tokenCount, requestedModel) } @@ -290,9 +339,7 @@ func (h *MessagesHandler) handleStreaming( w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rw.WriteHeader(http.StatusOK) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } + rw.Flush() // Start heartbeat var finished int32 @@ -307,10 +354,7 @@ func (h *MessagesHandler) handleStreaming( if atomic.LoadInt32(&finished) == 1 { return } - _, _ = fmt.Fprintf(rw, ":keepalive\n\n") - if f, ok := w.(http.Flusher); ok { - f.Flush() - } + rw.WriteKeepalive() case <-heartbeatDone: return case <-clientCtx.Done(): @@ -328,66 +372,98 @@ func (h *MessagesHandler) handleStreaming( for _, model := range modelChain { select { case <-clientCtx.Done(): - h.logger.Info("client disconnected, stopping streaming fallbacks") + h.logger.Debug("client disconnected, stopping streaming fallbacks") return default: } h.logger.Info("attempting streaming model", "model", model.ModelID, "provider", model.Provider) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + // Upstream context: no total deadline. The stream lives as long as + // data keeps flowing. Per-Read idle deadline is enforced in stream.go + // via http.ResponseController, so a stuck stream is still caught. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-clientCtx.Done() + cancel() + }() + idleTimeout := h.client.StreamIdleTimeout(model) + + // recordStreamSuccess records a successful stream completion and + // marks the model attempt as done. + recordStreamSuccess := func(model config.ModelConfig) { + cancel() + latency := time.Since(streamStart) + h.metrics.RecordSuccess(model.ModelID, latency) + h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) + } + + // handleStreamError checks the error from a streaming attempt and + // decides whether to retry the next model or abort. It returns true + // if the caller should continue (fallback to next model), or false + // if it should return. + handleStreamError := func(err error, model config.ModelConfig, action string) bool { + cancel() + if clientCtx.Err() == context.Canceled { + h.logger.Debug("client disconnected during " + action + " stream") + return false // abort + } + if err == transformer.ErrStreamIdle { + h.logger.Warn("upstream "+action+" stream idle, trying next model", + "model", model.ModelID, "idle_timeout", idleTimeout) + if rw.ssePayloadWritten { + h.sendStreamError(rw, "stream idle after SSE payload started") + h.metrics.RecordFailure() + return false // abort + } + return true // continue to next model + } + h.logger.Warn(action+" streaming failed", "model", model.ModelID, "error", err) + if rw.ssePayloadWritten { + h.sendStreamError(rw, fmt.Sprintf("all upstream models failed after SSE payload started: %v", err)) + h.metrics.RecordFailure() + return false // abort — cannot fallback after SSE payload started + } + return true // continue to next model + } // Zen models use their own endpoint classification if client.IsZen(model) { endpointType := client.ClassifyEndpoint(model.ModelID) switch endpointType { case client.EndpointAnthropic: - modelBody := replaceModelInRawBody(rawBody, model.ModelID) - if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model); err != nil { - cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during anthropic stream") - return + if model.AnthropicToolsDisabled { + // Fall through to OpenAI-compatible transform path below. + } else { + modelBody := replaceModelInRawBody(rawBody, model.ModelID) + if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model, idleTimeout, cancel, clientCtx); err != nil { + if !handleStreamError(err, model, "anthropic") { + return + } + continue } - h.logger.Warn("anthropic streaming failed", "model", model.ModelID, "error", err) - continue + recordStreamSuccess(model) + return } - cancel() - latency := time.Since(streamStart) - h.metrics.RecordSuccess(model.ModelID, latency) - h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) - return case client.EndpointResponses: - if err := h.handleResponsesStreaming(ctx, rw, anthropicReq, model, clientCtx); err != nil { - cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during responses stream") + if err := h.handleResponsesStreaming(ctx, rw, anthropicReq, model, clientCtx, idleTimeout, cancel); err != nil { + if !handleStreamError(err, model, "responses") { return } - h.logger.Warn("responses streaming failed", "model", model.ModelID, "error", err) continue } - cancel() - latency := time.Since(streamStart) - h.metrics.RecordSuccess(model.ModelID, latency) - h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) + recordStreamSuccess(model) return case client.EndpointGemini: - if err := h.handleGeminiStreaming(ctx, rw, anthropicReq, model, clientCtx); err != nil { - cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during gemini stream") + if err := h.handleGeminiStreaming(ctx, rw, anthropicReq, model, clientCtx, idleTimeout, cancel); err != nil { + if !handleStreamError(err, model, "gemini") { return } - h.logger.Warn("gemini streaming failed", "model", model.ModelID, "error", err) continue } - cancel() - latency := time.Since(streamStart) - h.metrics.RecordSuccess(model.ModelID, latency) - h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) + recordStreamSuccess(model) return default: @@ -395,6 +471,27 @@ func (h *MessagesHandler) handleStreaming( } } + // Go provider Anthropic-native models (qwen3.7-max) that require raw + // Anthropic format rather than the OpenAI Chat Completions transform. + if !client.IsZen(model) && client.IsAnthropicModel(model.ModelID) { + modelBody := replaceModelInRawBody(rawBody, model.ModelID) + if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model, idleTimeout, cancel, clientCtx); err != nil { + if !handleStreamError(err, model, "anthropic") { + return + } + // For non-idle errors after SSE payload started, send a more + // specific error message since this is the last attempt. + if err != transformer.ErrStreamIdle && rw.ssePayloadWritten { + h.sendStreamError(rw, fmt.Sprintf("all upstream models failed after SSE payload started: %v", err)) + h.metrics.RecordFailure() + return + } + continue + } + recordStreamSuccess(model) + return + } + // OpenAI-compatible models (both Go and Zen) openaiReq, err := h.requestTransformer.TransformRequest(anthropicReq, model) if err != nil { @@ -407,37 +504,36 @@ func (h *MessagesHandler) handleStreaming( if err != nil { cancel() if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during upstream request") + h.logger.Debug("client disconnected during upstream request") return } h.logger.Warn("streaming request failed", "model", model.ModelID, "error", err) continue } - if err := h.streamHandler.ProxyStream(rw, streamBody, model.ModelID, clientCtx); err != nil { + if err := h.streamHandler.ProxyStream(rw, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() - cancel() if err == transformer.ErrClientDisconnected { - h.logger.Info("client disconnected during stream") + h.logger.Debug("client disconnected during stream") return } - if clientCtx.Err() == context.Canceled { - h.logger.Info("client disconnected during stream (context canceled)") + if !handleStreamError(err, model, "openai") { return } - h.logger.Warn("stream proxy failed", "model", model.ModelID, "error", err) continue } _ = streamBody.Close() - cancel() - latency := time.Since(streamStart) - h.metrics.RecordSuccess(model.ModelID, latency) - h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) + recordStreamSuccess(model) return } h.metrics.RecordFailure() + if rw.ssePayloadWritten { + // SSE payload was already sent — do not attempt further writes + // beyond the error event. The client has a partial stream. + return + } if !rw.wroteHeader { h.sendError(w, http.StatusBadGateway, "all streaming models failed", nil) } else { @@ -452,6 +548,8 @@ func (h *MessagesHandler) handleResponsesStreaming( anthropicReq *types.MessageRequest, model config.ModelConfig, clientCtx context.Context, + idleTimeout time.Duration, + cancel context.CancelFunc, ) error { req, err := h.requestTransformer.TransformToResponses(anthropicReq, model) if err != nil { @@ -463,7 +561,7 @@ func (h *MessagesHandler) handleResponsesStreaming( return err } - if err := h.streamHandler.ProxyResponsesStream(w, streamBody, model.ModelID, clientCtx); err != nil { + if err := h.streamHandler.ProxyResponsesStream(w, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() return err } @@ -479,6 +577,8 @@ func (h *MessagesHandler) handleGeminiStreaming( anthropicReq *types.MessageRequest, model config.ModelConfig, clientCtx context.Context, + idleTimeout time.Duration, + cancel context.CancelFunc, ) error { req, err := h.requestTransformer.TransformToGemini(anthropicReq, model) if err != nil { @@ -490,7 +590,7 @@ func (h *MessagesHandler) handleGeminiStreaming( return err } - if err := h.streamHandler.ProxyGeminiStream(w, streamBody, model.ModelID, clientCtx); err != nil { + if err := h.streamHandler.ProxyGeminiStream(w, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() return err } @@ -499,26 +599,75 @@ func (h *MessagesHandler) handleGeminiStreaming( return nil } -// replaceModelInRawBody replaces the model field in raw JSON body with the actual model ID. -func replaceModelInRawBody(rawBody json.RawMessage, modelID string) json.RawMessage { - bodyStr := string(rawBody) - - if idx := strings.Index(bodyStr, `"model":"`); idx != -1 { - start := idx + len(`"model":"`) - if end := strings.Index(bodyStr[start:], `"`); end != -1 { - oldModel := bodyStr[start : start+end] - newBody := bodyStr[:start] + modelID + bodyStr[start+end:] - slog.Debug("replaced model in request body", - "old_model", oldModel, - "new_model", modelID, - "success", true) - return json.RawMessage(newBody) +// sanitizeAnthropicBody removes the "type" field from tools whose value is +// "custom" (server-tool shorthands used by Claude Code for MCP tools that some +// upstream models don't understand). The upstream treats the tool as absent +// when type is missing rather than rejecting type:"custom". +// Returns the original body unchanged if no tools array is present or if no +// tool has type:"custom". +func sanitizeAnthropicBody(rawBody json.RawMessage) json.RawMessage { + var body map[string]any + if err := json.Unmarshal(rawBody, &body); err != nil { + return rawBody + } + + tools, ok := body["tools"].([]any) + if !ok || len(tools) == 0 { + return rawBody + } + + modified := false + for _, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + continue + } + if toolType, ok := toolMap["type"].(string); ok && toolType == "custom" { + delete(toolMap, "type") + modified = true } } - slog.Warn("could not find model field in request body, using original", - "body_preview", bodyStr[:min(len(bodyStr), 200)]) - return rawBody + if !modified { + return rawBody + } + + result, err := json.Marshal(body) + if err != nil { + return rawBody + } + return json.RawMessage(result) +} + +// replaceModelInRawBody replaces the top-level "model" field in raw JSON body +// with the actual model ID. Uses JSON unmarshal/marshal rather than string +// search so that nested occurrences of "model" in user content, tool schemas, +// or escaped strings are never touched. +func replaceModelInRawBody(rawBody json.RawMessage, modelID string) json.RawMessage { + var obj map[string]json.RawMessage + if err := json.Unmarshal(rawBody, &obj); err != nil { + slog.Error("could not parse request body for model replacement, using original", + "error", err) + return rawBody + } + encoded, err := json.Marshal(modelID) + if err != nil { + // json.Marshal on a string should never fail, but guard anyway. + slog.Error("failed to marshal model ID for body replacement", + "error", err, "model_id", modelID) + return rawBody + } + obj["model"] = encoded + result, err := json.Marshal(obj) + if err != nil { + slog.Error("could not marshal request body after model replacement, using original", + "error", err) + return rawBody + } + slog.Debug("replaced model in request body", + "new_model", modelID, + "success", true) + return json.RawMessage(result) } // handleAnthropicStreaming sends a raw Anthropic request to the Anthropic endpoint. @@ -528,7 +677,14 @@ func (h *MessagesHandler) handleAnthropicStreaming( rawBody json.RawMessage, modelID string, model config.ModelConfig, + idleTimeout time.Duration, + cancel context.CancelFunc, + clientCtx context.Context, ) error { + // Sanitize Anthropic-specific fields (e.g., tool type shorthands) that + // upstream models may not understand. + rawBody = sanitizeAnthropicBody(rawBody) + h.logger.Debug("sending anthropic streaming request", "model_id", modelID, "body_preview", string(rawBody)[:min(len(rawBody), 200)]) @@ -538,16 +694,50 @@ func (h *MessagesHandler) handleAnthropicStreaming( return err } defer func() { _ = resp.Body.Close() }() - - _, err = io.Copy(w, resp.Body) - if err != nil { - if ctx.Err() == context.Canceled { + defer cancel() + + // Stream the body chunk-by-chunk with an idle watchdog. The stream lives + // as long as data keeps flowing and is aborted when no byte arrives + // within idleTimeout. + buf := make([]byte, 4096) + ping := transformer.StartIdleWatchdog(ctx, cancel, idleTimeout) + for { + select { + case <-ctx.Done(): + // ctx is canceled by either the idle watchdog or client disconnect. + // Distinguish: watchdog fires while client is still connected. + if clientCtx.Err() == nil { + return transformer.ErrStreamIdle + } return transformer.ErrClientDisconnected + default: + } + n, rerr := resp.Body.Read(buf) + if n > 0 { + ping() + if _, werr := w.Write(buf[:n]); werr != nil { + return transformer.ErrClientDisconnected + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + if rerr == io.EOF { + return nil + } + if rerr != nil { + if transformer.IsIdleTimeout(rerr) { + return transformer.ErrStreamIdle + } + if errors.Is(rerr, context.Canceled) || ctx.Err() == context.Canceled { + if clientCtx.Err() == nil { + return transformer.ErrStreamIdle + } + return transformer.ErrClientDisconnected + } + return fmt.Errorf("failed to copy response: %w", rerr) } - return fmt.Errorf("failed to copy response: %w", err) } - - return nil } // sendStreamError sends an error event in the SSE stream. @@ -563,7 +753,9 @@ func (h *MessagesHandler) sendStreamError(w http.ResponseWriter, message string) } data, _ := json.Marshal(errorEvent) - _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(data)) + if _, err := fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(data)); err != nil { + h.logger.Debug("failed to write stream error event", "error", err, "message", message) + } if f, ok := w.(http.Flusher); ok { f.Flush() @@ -590,7 +782,11 @@ func (h *MessagesHandler) handleNonStreaming( endpointType := client.ClassifyEndpoint(model.ModelID) switch endpointType { case client.EndpointAnthropic: - return h.executeAnthropicRequest(ctx, rawBody, model) + if model.AnthropicToolsDisabled { + // Fall through to OpenAI-compatible handling below. + } else { + return h.executeAnthropicRequest(ctx, replaceModelInRawBody(rawBody, model.ModelID), model) + } case client.EndpointResponses: return h.executeResponsesRequest(ctx, anthropicReq, model) case client.EndpointGemini: @@ -600,7 +796,7 @@ func (h *MessagesHandler) handleNonStreaming( } } else if client.IsAnthropicModel(model.ModelID) { // Go provider Anthropic-native models (MiniMax, Qwen) - return h.executeAnthropicRequest(ctx, rawBody, model) + return h.executeAnthropicRequest(ctx, replaceModelInRawBody(rawBody, model.ModelID), model) } // OpenAI-compatible models (both Go and Zen) @@ -634,6 +830,10 @@ func (h *MessagesHandler) executeAnthropicRequest( rawBody json.RawMessage, model config.ModelConfig, ) ([]byte, error) { + // Sanitize Anthropic-specific fields (e.g., tool type shorthands) that + // upstream models may not understand. + rawBody = sanitizeAnthropicBody(rawBody) + resp, err := h.client.SendAnthropicRequest(ctx, rawBody, false, model) if err != nil { return nil, fmt.Errorf("anthropic request failed: %w", err) @@ -750,7 +950,7 @@ func (h *MessagesHandler) sendError(w http.ResponseWriter, statusCode int, messa "error", err, ) - if rw, ok := w.(*responseWriter); ok && rw.wroteHeader { + if rw, ok := w.(*responseWriter); ok && rw.headerWritten() { return } diff --git a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index f3dcac1..aa53c62 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -1,6 +1,7 @@ package handlers import ( + "encoding/json" "log/slog" "testing" @@ -8,6 +9,8 @@ import ( "oc-go-cc/internal/router" ) +func boolPtr(b bool) *bool { return &b } + func TestAppendUniqueModels_DedupsByModelID(t *testing.T) { base := []config.ModelConfig{ {Provider: "opencode-go", ModelID: "kimi-k2.6"}, @@ -300,7 +303,7 @@ func TestBuildModelChain_UnknownModel_FallsThroughToScenarioRoute(t *testing.T) // Requested model has no entry in model_overrides and not in models map, // and respect_requested_model is false → scenario routing. cfg := &config.Config{ - RespectRequestedModel: false, + RespectRequestedModel: boolPtr(false), Models: map[string]config.ModelConfig{ "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, }, @@ -337,3 +340,129 @@ func equalStrings(a, b []string) bool { } return true } + +func TestSanitizeAnthropicBody_RemovesToolTypeField(t *testing.T) { + rawBody := json.RawMessage(`{ + "model": "minimax-m3", + "tools": [ + { + "type": "custom", + "name": "my_tool", + "description": "A test tool", + "input_schema": {"type": "object"} + }, + { + "type": "custom", + "name": "other_tool", + "description": "Another tool", + "input_schema": {"type": "object"} + } + ] + }`) + + result := sanitizeAnthropicBody(rawBody) + + var body map[string]any + if err := json.Unmarshal(result, &body); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + + tools, ok := body["tools"].([]any) + if !ok { + t.Fatal("expected tools array in result") + } + + for i, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + t.Fatalf("tool %d is not a map", i) + } + if _, hasType := toolMap["type"]; hasType { + t.Errorf("tool %d still has type field after sanitization", i) + } + if name, ok := toolMap["name"]; !ok || name != ([]string{"my_tool", "other_tool"})[i] { + t.Errorf("tool %d name field was corrupted", i) + } + } +} + +func TestSanitizeAnthropicBody_NoTools(t *testing.T) { + rawBody := json.RawMessage(`{"model": "minimax-m3", "messages": []}`) + result := sanitizeAnthropicBody(rawBody) + + // Should return the original body unchanged + if string(result) != string(rawBody) { + t.Error("body without tools should be returned unchanged") + } +} + +func TestSanitizeAnthropicBody_ToolsWithoutType(t *testing.T) { + rawBody := json.RawMessage(`{ + "tools": [ + { + "name": "my_tool", + "description": "No type field", + "input_schema": {"type": "object"} + } + ] + }`) + result := sanitizeAnthropicBody(rawBody) + + // Should return the original body unchanged (no type field to remove) + if string(result) != string(rawBody) { + t.Error("body with tools without type should be returned unchanged") + } +} + +func TestSanitizeAnthropicBody_InvalidJSON(t *testing.T) { + rawBody := json.RawMessage(`{invalid json}`) + result := sanitizeAnthropicBody(rawBody) + + // Should return original body unchanged on invalid JSON + if string(result) != string(rawBody) { + t.Error("invalid JSON should be returned unchanged") + } +} + +func TestSanitizeAnthropicBody_EmptyBody(t *testing.T) { + rawBody := json.RawMessage(`{}`) + result := sanitizeAnthropicBody(rawBody) + + if string(result) != string(rawBody) { + t.Error("empty body should be returned unchanged") + } +} + +func TestSanitizeAnthropicBody_KeepsOtherFields(t *testing.T) { + rawBody := json.RawMessage(`{ + "model": "minimax-m3", + "system": "You are a helpful assistant", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4096, + "tools": [ + { + "type": "custom", + "name": "test_tool", + "description": "desc", + "input_schema": {"type": "object", "properties": {}} + } + ] + }`) + result := sanitizeAnthropicBody(rawBody) + + var body map[string]any + if err := json.Unmarshal(result, &body); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + + // Check that non-tool fields are preserved + if body["model"] != "minimax-m3" { + t.Error("model field was corrupted") + } + if body["system"] != "You are a helpful assistant" { + t.Error("system field was corrupted") + } + if body["max_tokens"] != float64(4096) { + t.Error("max_tokens field was corrupted") + } +} diff --git a/internal/router/fallback.go b/internal/router/fallback.go index e741e41..862d25a 100644 --- a/internal/router/fallback.go +++ b/internal/router/fallback.go @@ -208,13 +208,21 @@ func (h *FallbackHandler) ExecuteWithFallback( }, body, nil } - cb.RecordFailure() - h.logger.Warn("model failed, trying fallback", - "model", model.ModelID, - "error", err, - "remaining", totalModels-i-1, - "circuit_state", cb.State(), - ) + if IsRetryableError(err) { + cb.RecordFailure() + h.logger.Warn("model failed, trying fallback", + "model", model.ModelID, + "error", err, + "remaining", totalModels-i-1, + "circuit_state", cb.State(), + ) + } else { + h.logger.Warn("non-retryable error (skipping circuit breaker), trying fallback", + "model", model.ModelID, + "error", err, + "remaining", totalModels-i-1, + ) + } } return &FallbackResult{ @@ -243,13 +251,22 @@ func IsRetryableError(err error) bool { } errStr := err.Error() - // Retry on network errors, timeouts, rate limits, server errors + + // 4xx client errors are not retryable — the request format itself is invalid + // for that model, and retrying won't fix it. + if strings.Contains(errStr, "API error 4") { + return false + } + + // Retry on network errors, timeouts, rate limits (from non-4xx paths), + // and server errors (5xx). 4xx client errors are already excluded by + // the "API error 4" check above — 429 is correctly non-retryable, so + // the circuit breaker doesn't open for rate limits. retryable := []string{ "timeout", "connection refused", "connection reset", "rate limit", - "429", "503", "502", "500", diff --git a/internal/router/fallback_test.go b/internal/router/fallback_test.go new file mode 100644 index 0000000..b7a79d9 --- /dev/null +++ b/internal/router/fallback_test.go @@ -0,0 +1,155 @@ +package router + +import ( + "context" + "errors" + "fmt" + "testing" + + "oc-go-cc/internal/config" +) + +func TestIsRetryableError_ClientsErrorsNotRetryable(t *testing.T) { + tests := []struct { + err string + want bool + }{ + // 4xx errors should NOT be retryable + {err: "API error 400: bad request", want: false}, + {err: "API error 401: unauthorized", want: false}, + {err: "API error 403: forbidden", want: false}, + {err: "API error 404: not found", want: false}, + {err: "API error 422: unprocessable", want: false}, + {err: "API error 429: rate limit", want: false}, + + // 5xx and network errors should be retryable (existing behavior) + {err: "API error 500: internal error", want: true}, + {err: "API error 502: bad gateway", want: true}, + {err: "API error 503: service unavailable", want: true}, + {err: "request timeout", want: true}, + {err: "connection refused", want: true}, + {err: "connection reset by peer", want: true}, + {err: "rate limit exceeded", want: true}, + + // Edge cases + {err: "", want: false}, + {err: "random error", want: false}, + {err: "API error 400", want: false}, + {err: "API error 500", want: true}, + } + + for _, tt := range tests { + t.Run(tt.err, func(t *testing.T) { + var err error + if tt.err != "" { + err = errors.New(tt.err) + } + if got := IsRetryableError(err); got != tt.want { + t.Errorf("IsRetryableError(%q) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestExecuteWithFallback_NonRetryableDoesNotOpenCircuit(t *testing.T) { + h := NewFallbackHandler(nil, 1, 0) // 1 failure = open circuit + + models := []config.ModelConfig{ + {ModelID: "model-a"}, + {ModelID: "model-b"}, + } + + attempts := 0 + _, _, err := h.ExecuteWithFallback( + context.Background(), + models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + attempts++ + // Non-retryable 400 error — should NOT open circuit breaker + return nil, fmt.Errorf("API error 400: bad request") + }, + ) + + if err == nil { + t.Fatal("expected all models to fail") + } + + // Circuit breaker should still be closed since errors were non-retryable + cb := h.getCircuitBreaker("model-a") + if cb.State() != CircuitClosed { + t.Errorf("model-a circuit should be closed after non-retryable errors, got %v", cb.State()) + } + + // All models were tried + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestExecuteWithFallback_RetryableOpensCircuit(t *testing.T) { + h := NewFallbackHandler(nil, 1, 0) + + models := []config.ModelConfig{ + {ModelID: "model-a"}, + {ModelID: "model-b"}, + } + + _, _, err := h.ExecuteWithFallback( + context.Background(), + models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + // Retryable 500 error — should open circuit breaker + return nil, fmt.Errorf("API error 500: internal error") + }, + ) + + if err == nil { + t.Fatal("expected all models to fail") + } + + // Circuit breaker should be OPEN after retryable failure + cb := h.getCircuitBreaker("model-a") + if cb.State() != CircuitOpen { + t.Errorf("model-a circuit should be open after retryable error, got %v", cb.State()) + } +} + +func TestExecuteWithFallback_NonRetryableThenRetryable(t *testing.T) { + h := NewFallbackHandler(nil, 1, 0) + callCount := 0 + + models := []config.ModelConfig{ + {ModelID: "model-a"}, + {ModelID: "model-b"}, + } + + _, _, err := h.ExecuteWithFallback( + context.Background(), + models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + // Non-retryable: model-a should NOT get circuit opened + return nil, fmt.Errorf("API error 400: bad request") + } + // Retryable: model-b should get circuit opened + return nil, fmt.Errorf("API error 500: internal error") + }, + ) + + if err == nil { + t.Fatal("expected all models to fail") + } + + // model-a circuit should be closed (non-retryable) + cbA := h.getCircuitBreaker("model-a") + if cbA.State() != CircuitClosed { + t.Errorf("model-a circuit should be closed after non-retryable error, got %v", cbA.State()) + } + + // model-b circuit should be open (retryable) + cbB := h.getCircuitBreaker("model-b") + if cbB.State() != CircuitOpen { + t.Errorf("model-b circuit should be open after retryable error, got %v", cbB.State()) + } +} diff --git a/internal/router/model_router.go b/internal/router/model_router.go index b16cb4f..9e16008 100644 --- a/internal/router/model_router.go +++ b/internal/router/model_router.go @@ -18,6 +18,16 @@ func NewModelRouter(atomic *config.AtomicConfig) *ModelRouter { return &ModelRouter{atomic: atomic} } +// isRespectRequestedModel returns true when the client-specified model should be +// used as the primary routing target. nil (unset in config) defaults to true; +// an explicit *false from the user config is honoured. +func isRespectRequestedModel(cfg *config.Config) bool { + if cfg.RespectRequestedModel == nil { + return true // default when not explicitly set + } + return *cfg.RespectRequestedModel +} + // RouteResult contains the selected model and fallback chain. type RouteResult struct { Primary config.ModelConfig @@ -29,7 +39,7 @@ type RouteResult struct { // scenario-based routing. Returns the route result and true if it matched, // or zero value and false if scenario routing should proceed normally. func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel string) (RouteResult, bool) { - if !cfg.RespectRequestedModel || requestedModel == "" { + if !isRespectRequestedModel(cfg) || requestedModel == "" { return RouteResult{}, false } @@ -138,11 +148,11 @@ func (rr *RouteResult) GetModelChain() []config.ModelConfig { // RouteForStreaming determines which model to use for streaming requests. // Prioritizes fast TTFT (time-to-first-token) over capability. // If respect_requested_model is enabled and requestedModel is provided, it overrides scenario-based routing. -func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount int, requestedModel string) RouteResult { +func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount int, requestedModel string) (RouteResult, error) { cfg := r.atomic.Get() if result, ok := r.resolveRequestedModel(cfg, requestedModel); ok { - return result + return result, nil } // Otherwise, use scenario-based routing for streaming @@ -158,6 +168,9 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in primary = cfg.Models["default"] } } + if primary.ModelID == "" { + return RouteResult{}, fmt.Errorf("no model configured for streaming; neither scenario %q, \"fast\", nor \"default\" exist in models map", result.Scenario) + } // Get fallbacks for scenario fallbacks := cfg.Fallbacks[string(result.Scenario)] @@ -170,5 +183,5 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in Primary: primary, Fallbacks: fallbacks, Scenario: result.Scenario, - } + }, nil } diff --git a/internal/router/model_router_test.go b/internal/router/model_router_test.go index 608f593..729d4e8 100644 --- a/internal/router/model_router_test.go +++ b/internal/router/model_router_test.go @@ -6,13 +6,15 @@ import ( "oc-go-cc/internal/config" ) +func boolPtr(b bool) *bool { return &b } + func newTestAtomicConfig(cfg *config.Config) *config.AtomicConfig { return config.NewAtomicConfig(cfg, "/tmp/test-config.json") } func TestRoute_RespectRequestedModel_BypassesScenarioRouting(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: true, + RespectRequestedModel: boolPtr(true), Models: map[string]config.ModelConfig{ "default": { Provider: "opencode-go", @@ -60,7 +62,7 @@ func TestRoute_RespectRequestedModel_BypassesScenarioRouting(t *testing.T) { func TestRoute_RespectRequestedModel_False_UsesScenarioRouting(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: false, + RespectRequestedModel: boolPtr(false), Models: map[string]config.ModelConfig{ "default": {ModelID: "kimi-k2.6"}, "complex": {ModelID: "glm-5.1"}, @@ -88,7 +90,7 @@ func TestRoute_RespectRequestedModel_False_UsesScenarioRouting(t *testing.T) { func TestRoute_RespectRequestedModel_EmptyModel_FallsThrough(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: true, + RespectRequestedModel: boolPtr(true), Models: map[string]config.ModelConfig{ "default": {ModelID: "kimi-k2.6"}, }, @@ -115,7 +117,7 @@ func TestRoute_RespectRequestedModel_EmptyModel_FallsThrough(t *testing.T) { func TestRoute_RespectRequestedModel_UnknownModel_UsesDefaults(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: true, + RespectRequestedModel: boolPtr(true), Models: map[string]config.ModelConfig{ "default": { Provider: "opencode-go", @@ -153,7 +155,7 @@ func TestRoute_RespectRequestedModel_UnknownModel_UsesDefaults(t *testing.T) { func TestRouteForStreaming_RespectRequestedModel_BypassesScenarioRouting(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: true, + RespectRequestedModel: boolPtr(true), Models: map[string]config.ModelConfig{ "default": {ModelID: "qwen3.6-plus"}, "kimi-k2.6": { @@ -172,7 +174,10 @@ func TestRouteForStreaming_RespectRequestedModel_BypassesScenarioRouting(t *test {Role: "user", Content: "Hello"}, } - result := router.RouteForStreaming(messages, 100, "kimi-k2.6") + result, err := router.RouteForStreaming(messages, 100, "kimi-k2.6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if result.Primary.ModelID != "kimi-k2.6" { t.Errorf("expected model kimi-k2.6, got %s", result.Primary.ModelID) } @@ -183,7 +188,7 @@ func TestRouteForStreaming_RespectRequestedModel_BypassesScenarioRouting(t *test func TestRouteForStreaming_RespectRequestedModel_False_UsesScenarioRouting(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: false, + RespectRequestedModel: boolPtr(false), Models: map[string]config.ModelConfig{ "default": {ModelID: "qwen3.6-plus"}, }, @@ -198,7 +203,10 @@ func TestRouteForStreaming_RespectRequestedModel_False_UsesScenarioRouting(t *te {Role: "user", Content: "Hello"}, } - result := router.RouteForStreaming(messages, 100, "kimi-k2.6") + result, err := router.RouteForStreaming(messages, 100, "kimi-k2.6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Should use streaming scenario routing, not the requested model if result.Primary.ModelID != "qwen3.6-plus" { t.Errorf("expected streaming model qwen3.6-plus, got %s", result.Primary.ModelID) @@ -207,7 +215,7 @@ func TestRouteForStreaming_RespectRequestedModel_False_UsesScenarioRouting(t *te func TestResolveRequestedModel_UsesFallbacks(t *testing.T) { cfg := &config.Config{ - RespectRequestedModel: true, + RespectRequestedModel: boolPtr(true), Models: map[string]config.ModelConfig{ "kimi-k2.6": {ModelID: "kimi-k2.6"}, }, diff --git a/internal/server/server.go b/internal/server/server.go index 0c96909..af7e87d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -72,11 +72,17 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { // Create HTTP server. addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) httpSrv := &http.Server{ - Addr: addr, - Handler: mux, - ReadTimeout: 30 * time.Second, - WriteTimeout: 5 * time.Minute, - IdleTimeout: 120 * time.Second, + Addr: addr, + Handler: mux, + ReadTimeout: 120 * time.Second, + // WriteTimeout is disabled (zero). Long-running SSE streams must not be + // killed mid-flight. Stuck upstream connections are handled by the + // per-stream idle watchdog (transformer/idle.go) which cancels the + // upstream context when no bytes arrive within the model's idle timeout. + // IdleTimeout here governs keep-alive between separate HTTP requests on + // the same TCP connection; it does NOT affect in-stream byte gaps. + WriteTimeout: 0, + IdleTimeout: 300 * time.Second, } srv := &Server{ diff --git a/internal/transformer/idle.go b/internal/transformer/idle.go new file mode 100644 index 0000000..1fa5603 --- /dev/null +++ b/internal/transformer/idle.go @@ -0,0 +1,66 @@ +// Package transformer handles request/response transformation and token counting. +package transformer + +import ( + "context" + "sync" + "time" +) + +// StartIdleWatchdog launches a goroutine that calls cancel() if no call to the +// returned ping function occurs within idleTimeout. The caller must invoke ping() +// after every successful byte read from the upstream stream. +// +// The watchdog goroutine exits when ctx is done (e.g., the stream completed or +// the caller cancelled the context). The caller MUST cancel ctx when the stream +// is finished to avoid leaking the goroutine. +// +// Pass idleTimeout <= 0 to disable the watchdog (the returned ping is a no-op). +// +// Typical usage: +// +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// ping := StartIdleWatchdog(ctx, cancel, idleTimeout) +// // In the read loop: +// n, err := body.Read(buf) +// if n > 0 { +// ping() +// // process bytes +// } +func StartIdleWatchdog(ctx context.Context, cancel context.CancelFunc, idleTimeout time.Duration) func() { + if idleTimeout <= 0 { + return func() {} + } + + var mu sync.Mutex + timer := time.NewTimer(idleTimeout) + + go func() { + defer timer.Stop() + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + cancel() + return + } + } + }() + + return func() { + mu.Lock() + // Reset the timer on every ping so the deadline is always idleTimeout + // from the most recent byte, not from the last timer fire. + if !timer.Stop() { + // Timer already fired; drain the channel to avoid a spurious wake. + select { + case <-timer.C: + default: + } + } + timer.Reset(idleTimeout) + mu.Unlock() + } +} diff --git a/internal/transformer/request.go b/internal/transformer/request.go index 404c552..11a89f0 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -52,6 +52,17 @@ func needsPlaceholderReasoning(modelID string) bool { return strings.HasPrefix(modelID, "kimi-") } +// constrainTemperature overrides model-specific temperature constraints. +// Some models require specific temperature values — return the constrained +// value or the original if no constraint applies. +func constrainTemperature(modelID string, temp float64) float64 { + // Moonshot AI (kimi-k2.7-code) only allows temperature=1. + if modelID == "kimi-k2.7-code" { + return 1.0 + } + return temp +} + // stripCacheControl removes cache_control from all messages in the list. // The caller must not hold references to the slice elements. func stripCacheControl(messages []types.ChatMessage) { @@ -100,10 +111,14 @@ func (t *RequestTransformer) TransformRequest( openaiReq.MaxTokens = &maxTokens } - // Apply model-specific overrides + // Apply model-specific overrides and temperature constraints if model.Temperature > 0 { openaiReq.Temperature = &model.Temperature } + if openaiReq.Temperature != nil { + temp := constrainTemperature(model.ModelID, *openaiReq.Temperature) + openaiReq.Temperature = &temp + } if model.MaxTokens > 0 { maxTokens := model.MaxTokens openaiReq.MaxTokens = &maxTokens diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 7b68818..f9de7c9 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -3,6 +3,7 @@ package transformer import ( "bytes" "encoding/json" + "fmt" "strings" "testing" @@ -1523,3 +1524,34 @@ func TestTransformRequestStandardModelIgnoresThinkingAndEffort(t *testing.T) { t.Fatalf("expected Thinking to be nil for standard model, got %s", string(openaiReq.Thinking)) } } + +func TestConstrainTemperature(t *testing.T) { + tests := []struct { + modelID string + input float64 + want float64 + }{ + // kimi-k2.7-code forces temperature to 1.0 + {modelID: "kimi-k2.7-code", input: 0.7, want: 1.0}, + {modelID: "kimi-k2.7-code", input: 0.0, want: 1.0}, + {modelID: "kimi-k2.7-code", input: 1.5, want: 1.0}, + + // Other kimi models are not constrained + {modelID: "kimi-k2.6", input: 0.7, want: 0.7}, + {modelID: "kimi-k2.5", input: 0.5, want: 0.5}, + + // Other models are not constrained + {modelID: "minimax-m3", input: 0.7, want: 0.7}, + {modelID: "deepseek-v4-pro", input: 0.5, want: 0.5}, + {modelID: "glm-5.1", input: 0.3, want: 0.3}, + {modelID: "qwen3.7-plus", input: 0.9, want: 0.9}, + } + + for _, tt := range tests { + t.Run(tt.modelID+"/"+fmt.Sprint(tt.input), func(t *testing.T) { + if got := constrainTemperature(tt.modelID, tt.input); got != tt.want { + t.Errorf("constrainTemperature(%q, %f) = %f, want %f", tt.modelID, tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go index 646ec44..ced8c6e 100644 --- a/internal/transformer/stream.go +++ b/internal/transformer/stream.go @@ -5,8 +5,10 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "sort" "strings" @@ -18,6 +20,27 @@ import ( // ErrClientDisconnected is returned when the client disconnects during streaming. var ErrClientDisconnected = fmt.Errorf("client disconnected") +// ErrStreamIdle is returned when no bytes arrive within idleTimeout on the +// upstream stream. The connection is stale (e.g. backend hang or network +// partition). The handler decides whether to fall back to another model. +var ErrStreamIdle = fmt.Errorf("upstream stream idle") + +// IsIdleTimeout reports whether err is a read-timeout (network deadline +// exceeded on an otherwise live stream). +func IsIdleTimeout(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + return true + } + return false +} + // StreamHandler handles streaming SSE transformation from OpenAI to Anthropic format. type StreamHandler struct { responseTransformer *ResponseTransformer @@ -36,11 +59,18 @@ func NewStreamHandler() *StreamHandler { // // CRITICAL: This function reads directly from resp.Body without buffering to minimize latency. // Per deep research: "Don't use bufio.Scanner or bufio.Reader on the response body - it adds buffering" +// +// idleTimeout is the maximum gap between bytes on the upstream stream. The +// stream lives as long as data keeps flowing; only an idle period longer than +// idleTimeout is treated as a stuck connection and surfaces as ErrStreamIdle. +// Pass 0 to disable (stream lives until EOF or error). func (h *StreamHandler) ProxyStream( w http.ResponseWriter, openaiResp io.ReadCloser, originalModel string, clientCtx context.Context, + idleTimeout time.Duration, + cancel context.CancelFunc, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -75,10 +105,17 @@ func (h *StreamHandler) ProxyStream( stopSent := false toolUseCount := 0 startedToolCalls := make(map[int]int) // maps OpenAI tool call index → Anthropic content block index + decodeErrors := 0 // consecutive SSE decode failures // Read in larger chunks for efficiency, then parse lines readBuf := make([]byte, 4096) + // Start the idle watchdog. Each successful read pings the watchdog so + // the stream lives as long as data keeps flowing. If no bytes arrive + // within idleTimeout, cancel() is called, which aborts the upstream + // HTTP request and causes the next Read to return a context error. + ping := StartIdleWatchdog(clientCtx, cancel, idleTimeout) + for { // Check if client disconnected select { @@ -90,6 +127,9 @@ func (h *StreamHandler) ProxyStream( // Read chunk from upstream n, err := openaiResp.Read(readBuf) if n > 0 { + // Data is flowing — reset the idle watchdog so the stream + // lives as long as data keeps arriving. + ping() // Process bytes immediately for i := 0; i < n; i++ { b := readBuf[i] @@ -98,7 +138,7 @@ func (h *StreamHandler) ProxyStream( lineBuf.Reset() // Process complete line - if err := h.processSSELine(w, flusher, line, &contentIndex, &contentStarted, &reasoningStarted, &stopSent, &toolUseCount, startedToolCalls, originalModel); err != nil { + if err := h.processSSELine(w, flusher, line, &contentIndex, &contentStarted, &reasoningStarted, &stopSent, &toolUseCount, startedToolCalls, originalModel, &decodeErrors); err != nil { return err } } else { @@ -111,13 +151,22 @@ func (h *StreamHandler) ProxyStream( // Process any remaining data in buffer if lineBuf.Len() > 0 { line := lineBuf.String() - if err := h.processSSELine(w, flusher, line, &contentIndex, &contentStarted, &reasoningStarted, &stopSent, &toolUseCount, startedToolCalls, originalModel); err != nil { + if err := h.processSSELine(w, flusher, line, &contentIndex, &contentStarted, &reasoningStarted, &stopSent, &toolUseCount, startedToolCalls, originalModel, &decodeErrors); err != nil { return err } } break } if err != nil { + if IsIdleTimeout(err) { + return ErrStreamIdle + } + // When the idle watchdog fires, it cancels the upstream context + // which produces context.Canceled on Read. Distinguish that + // from a client disconnect by checking clientCtx. + if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + return ErrStreamIdle + } return fmt.Errorf("failed to read stream: %w", err) } } @@ -150,12 +199,7 @@ func (h *StreamHandler) ProxyStream( return entries[i].blockIdx < entries[j].blockIdx }) for _, e := range entries { - idx := e.blockIdx - stopEvent := types.MessageEvent{ - Type: "content_block_stop", - Index: &idx, - } - if err := writeSSEEvent(w, stopEvent); err != nil { + if err := writeContentBlockStop(w, e.blockIdx); err != nil { return ErrClientDisconnected } } @@ -207,6 +251,7 @@ func (h *StreamHandler) processSSELine( toolUseCount *int, startedToolCalls map[int]int, originalModel string, + decodeErrors *int, ) error { line = strings.TrimSpace(line) @@ -241,20 +286,29 @@ func (h *StreamHandler) processSSELine( !strings.Contains(data, `"tool_calls"`) && !strings.Contains(data, `"usage"`) { if idx := strings.Index(data, `"delta":{"content":"`); idx != -1 { - // Extract content directly + // Walk past JSON escape sequences to find the real closing + // quote. A naive strings.Index would stop at an escaped + // \" inside the content. start := idx + len(`"delta":{"content":"`) - end := strings.Index(data[start:], `"`) + suffix := data[start:] + end := -1 + for i := 0; i < len(suffix); i++ { + if suffix[i] == '\\' { + i++ // skip the escaped character + continue + } + if suffix[i] == '"' { + end = i + break + } + } if end != -1 { content := data[start : start+end] if content != "" { if !*contentStarted { // If reasoning was already started, close it first if *reasoningStarted { - stopEvent := types.MessageEvent{ - Type: "content_block_stop", - Index: contentIndex, - } - if err := writeSSEEvent(w, stopEvent); err != nil { + if err := writeContentBlockStop(w, *contentIndex); err != nil { return ErrClientDisconnected } *contentIndex++ @@ -287,6 +341,10 @@ func (h *StreamHandler) processSSELine( } flusher.Flush() } + // Valid SSE line accepted via fast path — reset the + // consecutive decode failure counter so interleaved valid + // chunks don't accumulate spurious "too many failures". + *decodeErrors = 0 return nil } } @@ -295,9 +353,16 @@ func (h *StreamHandler) processSSELine( // For tool calls and other complex cases, fall back to full JSON parsing var chunk types.ChatCompletionChunk if err := json.Unmarshal([]byte(data), &chunk); err != nil { - // Skip malformed chunks - don't fail the whole stream + // Track consecutive decode failures. A transient glitch is tolerated, + // but persistent corruption terminates the stream rather than silently + // dropping content. + *decodeErrors++ + if *decodeErrors > 3 { + return fmt.Errorf("too many consecutive SSE decode failures (%d)", *decodeErrors) + } return nil } + *decodeErrors = 0 if len(chunk.Choices) == 0 { if chunk.Usage != nil { @@ -423,7 +488,11 @@ func (h *StreamHandler) processSSELine( // already fully processed. continue } - if *contentStarted || *reasoningStarted { + // Close any existing content/reasoning block before opening the + // tool block. Capture the state first so we know whether to + // advance contentIndex — the close itself clears the flags. + hadStartedBlock := *contentStarted || *reasoningStarted + if hadStartedBlock { stopEvent := types.MessageEvent{ Type: "content_block_stop", Index: contentIndex, @@ -435,7 +504,15 @@ func (h *StreamHandler) processSSELine( *reasoningStarted = false } // First time seeing this logical tool call — start a new block. - *contentIndex++ + // Only increment contentIndex when a previous text or reasoning + // block was already started, OR when a prior tool call has already + // claimed index 0 (parallel or sequential tool calls). If nothing + // was started yet (single-tool response), the first tool block + // keeps contentIndex at 0 so the Anthropic SSE content block + // indices are contiguous. + if hadStartedBlock || len(startedToolCalls) > 0 { + *contentIndex++ + } *toolUseCount++ blockIdx = *contentIndex startedToolCalls[oi] = blockIdx @@ -576,6 +653,14 @@ func usageInfoToAnthropic(usage *types.UsageInfo) *types.Usage { } } +// writeContentBlockStop writes a content_block_stop SSE event at the given index. +func writeContentBlockStop(w http.ResponseWriter, index int) error { + return writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_stop", + Index: &index, + }) +} + // writeSSEEvent writes a single SSE event to the HTTP response writer. // Format: "event: \ndata: \n\n" func writeSSEEvent(w http.ResponseWriter, event types.MessageEvent) error { @@ -599,6 +684,8 @@ func (h *StreamHandler) ProxyResponsesStream( responsesResp io.ReadCloser, originalModel string, clientCtx context.Context, + idleTimeout time.Duration, + cancel context.CancelFunc, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -627,6 +714,8 @@ func (h *StreamHandler) ProxyResponsesStream( stopSent := false readBuf := make([]byte, 4096) + ping := StartIdleWatchdog(clientCtx, cancel, idleTimeout) + for { select { case <-clientCtx.Done(): @@ -636,6 +725,7 @@ func (h *StreamHandler) ProxyResponsesStream( n, err := responsesResp.Read(readBuf) if n > 0 { + ping() for i := 0; i < n; i++ { b := readBuf[i] if b == '\n' { @@ -660,6 +750,12 @@ func (h *StreamHandler) ProxyResponsesStream( break } if err != nil { + if IsIdleTimeout(err) { + return ErrStreamIdle + } + if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + return ErrStreamIdle + } return fmt.Errorf("failed to read stream: %w", err) } } @@ -777,6 +873,8 @@ func (h *StreamHandler) ProxyGeminiStream( geminiResp io.ReadCloser, originalModel string, clientCtx context.Context, + idleTimeout time.Duration, + cancel context.CancelFunc, ) error { flusher, ok := w.(http.Flusher) if !ok { @@ -805,6 +903,8 @@ func (h *StreamHandler) ProxyGeminiStream( stopSent := false readBuf := make([]byte, 4096) + ping := StartIdleWatchdog(clientCtx, cancel, idleTimeout) + for { select { case <-clientCtx.Done(): @@ -814,6 +914,7 @@ func (h *StreamHandler) ProxyGeminiStream( n, err := geminiResp.Read(readBuf) if n > 0 { + ping() for i := 0; i < n; i++ { b := readBuf[i] if b == '\n' { @@ -838,6 +939,12 @@ func (h *StreamHandler) ProxyGeminiStream( break } if err != nil { + if IsIdleTimeout(err) { + return ErrStreamIdle + } + if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + return ErrStreamIdle + } return fmt.Errorf("failed to read stream: %w", err) } } diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index fc4458d..f592b7c 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -75,7 +75,7 @@ func TestProxyStream_ReasoningContentFastPath(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -134,7 +134,7 @@ func TestProxyStream_ReasoningThenText(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -187,7 +187,7 @@ func TestProxyStream_TextOnlyStillWorks(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -221,7 +221,7 @@ func TestProxyStream_UsageOnlyChunk(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx); err != nil { + if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -267,7 +267,7 @@ func TestProxyStream_PartialCacheTokensStreaming(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx); err != nil { + if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -309,7 +309,7 @@ func TestProxyStream_NoDuplicateMessageDelta(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx); err != nil { + if err := handler.ProxyStream(w, body, "deepseek-v4-pro", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -355,7 +355,7 @@ func TestProxyStream_ReasoningJSONFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -389,7 +389,7 @@ func TestProxyStream_EmptyReasoningContentSkipped(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -423,7 +423,7 @@ func TestProxyStream_ReasoningAndContentInSameChunk(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -491,7 +491,7 @@ func TestProxyStream_ReasoningBeforeContentFastPathRegression(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "deepseek-v4-flash", ctx); err != nil { + if err := handler.ProxyStream(w, body, "deepseek-v4-flash", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -542,7 +542,7 @@ func TestProxyStream_ToolCallFinishReasonWithUsage(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -587,15 +587,14 @@ func TestProxyStream_SingleToolCall(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } events := parseSSEEvents(t, w.buf.String()) - // Expected: message_start, tool_start(idx=1), 2x input_json_delta (3rd arg arrives - // with finish_reason in same chunk, fast path returns before processing delta), - // tool_stop(idx=1), message_delta, message_stop = 7 + // Expected: message_start, tool_start(idx=0), 2x input_json_delta, + // tool_stop(idx=0), message_delta, message_stop = 7 if len(events) != 7 { t.Fatalf("expected 7 events, got %d: %+v", len(events), events) } @@ -663,7 +662,7 @@ func TestProxyStream_MultipleParallelToolCalls(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -725,7 +724,7 @@ func TestProxyStream_ToolCallGhostChunk(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -758,7 +757,7 @@ func TestProxyStream_MixedTextAndToolCall(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -812,7 +811,7 @@ func TestProxyStream_MixedReasoningAndToolCall(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -856,7 +855,7 @@ func TestProxyStream_ToolCallFinishReasonFastPath(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -889,7 +888,7 @@ func TestProxyStream_ContentAndFinishReasonInSameChunk(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -933,7 +932,7 @@ func TestProxyStream_ToolCallAndFinishReasonInSameChunk(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -942,8 +941,8 @@ func TestProxyStream_ToolCallAndFinishReasonInSameChunk(t *testing.T) { // Expected events: // 0: message_start // 1: content_block_start (index 1, type tool_use) - // 2: content_block_delta (index 1, partial_json "{\"loc\":\"Beijing\"}") - // 3: content_block_stop (index 1) + // 2: content_block_delta (index 0, partial_json "{\"loc\":\"Beijing\"}") + // 3: content_block_stop (index 0) // 4: message_delta (stop_reason: tool_use) // 5: message_stop if len(events) != 6 { @@ -956,8 +955,8 @@ func TestProxyStream_ToolCallAndFinishReasonInSameChunk(t *testing.T) { if events[2].Type != "content_block_delta" || events[2].Delta == nil || events[2].Delta.PartialJSON != `{"loc":"Beijing"}` { t.Errorf("event[2] = %+v, want content_block_delta", events[2]) } - if events[3].Type != "content_block_stop" || events[3].Index == nil || *events[3].Index != 1 { - t.Errorf("event[3] = %+v, want content_block_stop(1)", events[3]) + if events[3].Type != "content_block_stop" || events[3].Index == nil || *events[3].Index != 0 { + t.Errorf("event[3] = %+v, want content_block_stop(0)", events[3]) } if events[4].Type != "message_delta" || events[4].Delta == nil || events[4].Delta.StopReason != "tool_use" { t.Errorf("event[4] = %+v, want message_delta(tool_use)", events[4]) @@ -975,7 +974,7 @@ func TestProxyStream_NoUsageFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx); err != nil { + if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -1011,7 +1010,7 @@ func TestProxyStream_NoFinishReasonFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx); err != nil { + if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } @@ -1046,7 +1045,7 @@ func TestProxyStream_EOFFallbackStopReasonToolUse(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx); err != nil { + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { t.Fatalf("ProxyStream error: %v", err) } diff --git a/scripts/e2e-test.sh b/scripts/e2e-test.sh new file mode 100755 index 0000000..7058f4a --- /dev/null +++ b/scripts/e2e-test.sh @@ -0,0 +1,301 @@ +#!/usr/bin/env bash +set -euo pipefail + +# End-to-end test for model tool format and fallback fixes. +# Tests each model with a tool-containing request to verify: +# - No "Unknown server-tool shorthand" 400 errors +# - No temperature constraint violations +# - All Go provider models work through the transform path +# +# Usage: +# source .env && ./scripts/e2e-test.sh [--build] + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$REPO_ROOT" + +# --- Config --- +PORT="${OC_GO_CC_PORT:-3457}" +HOST="${OC_GO_CC_HOST:-127.0.0.1}" +BASE_URL="http://${HOST}:${PORT}" +TIMEOUT_SEC=60 +pass=0 +fail=0 + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +cleanup() { + echo "=== Cleaning up ===" + if [ -n "${PROXY_PID:-}" ]; then + kill "${PROXY_PID}" 2>/dev/null || true + wait "${PROXY_PID}" 2>/dev/null || true + fi + ./bin/oc-go-cc stop 2>/dev/null || true + sleep 1 + rm -f ~/.config/oc-go-cc/oc-go-cc.pid +} +trap cleanup EXIT + +# --- Build --- +if [ "${1:-}" = "--skip-build" ]; then + echo -e "${YELLOW}Skipping build...${NC}" +else + echo "=== Building oc-go-cc ===" + make build + echo "" +fi + +# --- Source .env (must be done BEFORE start) --- +if [ -f .env ]; then + set -a + # shellcheck disable=SC1091 + source .env + set +a +fi + +if [ -z "${OC_GO_CC_API_KEY:-}" ]; then + echo -e "${RED}Error: OC_GO_CC_API_KEY not set. Create a .env file or export it.${NC}" + exit 1 +fi + +# --- Start proxy --- +echo "=== Starting proxy on ${HOST}:${PORT} ===" +cleanup +# Run server in foreground but background it with & so we can capture the PID. +# Do NOT use -b (daemonize) because that forks and exits the parent, making $! +# capture the wrong PID. +./bin/oc-go-cc serve --port "$PORT" > /tmp/oc-go-cc-e2e.log 2>&1 & +PROXY_PID=$! +echo "Server PID: ${PROXY_PID}" + +# Wait for health check with timeout +HEALTH_OK=false +for i in $(seq 1 10); do + if curl -sf "${BASE_URL}/health" > /dev/null 2>&1; then + HEALTH_OK=true + break + fi + sleep 1 +done +if [ "${HEALTH_OK}" != "true" ]; then + echo -e "${RED}Proxy failed to start${NC}" + cat /tmp/oc-go-cc-e2e.log 2>/dev/null || true + exit 1 +fi +echo -e "${GREEN}Proxy is running${NC}" +echo "" + +# --- Test helper --- +test_model() { + local model=$1 + local label=$2 + + echo -n " [$label] ${model} ... " + + REQUEST_BODY=$(cat <<'JSON' +{ + "model": "MODEL_PLACEHOLDER", + "tools": [ + { + "type": "custom", + "name": "read_file", + "description": "Read a file from the filesystem", + "input_schema": { + "type": "object", + "properties": { + "path": {"type": "string"} + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Say hello and nothing else"} + ] + } + ], + "max_tokens": 100, + "stream": false +} +JSON +) + REQUEST_BODY="${REQUEST_BODY//MODEL_PLACEHOLDER/$model}" + + HTTP_CODE=$(curl -s -o /tmp/oc-go-cc-e2e-response.json -w '%{http_code}' \ + -X POST "${BASE_URL}/v1/messages" \ + -H "Content-Type: application/json" \ + -H "x-api-key: ${OC_GO_CC_API_KEY}" \ + -d "$REQUEST_BODY" \ + --max-time "$TIMEOUT_SEC") + + if [ "$HTTP_CODE" = 200 ]; then + # Extract the text response for verification + TEXT=$(python3 -c " +import json +with open('/tmp/oc-go-cc-e2e-response.json') as f: + d = json.load(f) +blocks = d.get('content', []) +for b in blocks: + if b.get('type') == 'text': + print(b.get('text', '')) +" 2>/dev/null) + echo -e "${GREEN}PASS${NC} (200, response: \"${TEXT}\")" + pass=$((pass + 1)) + else + ERROR_MSG=$(head -c 300 /tmp/oc-go-cc-e2e-response.json 2>/dev/null || echo "") + echo -e "${RED}FAIL${NC} (HTTP ${HTTP_CODE})" + echo " Response: ${ERROR_MSG}" + fail=$((fail + 1)) + fi +} + +# --- Streaming test helper --- +test_streaming_model() { + local model=$1 + local label=$2 + + echo -n " [${label}] ${model} (streaming) ... " + + REQUEST_BODY=$(cat <<'JSON' +{ + "model": "MODEL_PLACEHOLDER", + "tools": [ + { + "type": "custom", + "name": "read_file", + "description": "Read a file from the filesystem", + "input_schema": { + "type": "object", + "properties": { + "path": {"type": "string"} + } + } + } + ], + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Say hello and nothing else"} + ] + } + ], + "max_tokens": 100, + "stream": true +} +JSON +) + REQUEST_BODY="${REQUEST_BODY//MODEL_PLACEHOLDER/$model}" + + HTTP_CODE=$(curl -s -o /tmp/oc-go-cc-e2e-stream-response.txt -w '%{http_code}' \ + -X POST "${BASE_URL}/v1/messages" \ + -H "Content-Type: application/json" \ + -H "x-api-key: ${OC_GO_CC_API_KEY}" \ + -d "$REQUEST_BODY" \ + --max-time "$TIMEOUT_SEC") + + if [ "$HTTP_CODE" = 200 ]; then + # Verify it's a valid SSE stream: must have message_start and message_stop + if grep -q "event: message_start" /tmp/oc-go-cc-e2e-stream-response.txt && \ + grep -q "event: message_stop" /tmp/oc-go-cc-e2e-stream-response.txt; then + echo -e "${GREEN}PASS${NC} (200, valid SSE stream)" + pass=$((pass + 1)) + else + echo -e "${RED}FAIL${NC} (200 but missing message_start/message_stop — corrupted SSE)" + head -c 400 /tmp/oc-go-cc-e2e-stream-response.txt + fail=$((fail + 1)) + fi + else + ERROR_MSG=$(head -c 300 /tmp/oc-go-cc-e2e-stream-response.txt 2>/dev/null || echo "") + echo -e "${RED}FAIL${NC} (HTTP ${HTTP_CODE})" + echo " Response: ${ERROR_MSG}" + fail=$((fail + 1)) + fi +} + +# --- Long streaming test helper (exercises heartbeat path) --- +test_streaming_long() { + local model=$1 + local label=$2 + + echo -n " [${label}] ${model} (streaming long) ... " + + REQUEST_BODY=$(cat <<'JSON' +{ + "model": "MODEL_PLACEHOLDER", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Write a paragraph about the importance of testing in software engineering. Aim for 200 words."} + ] + } + ], + "max_tokens": 500, + "stream": true +} +JSON +) + REQUEST_BODY="${REQUEST_BODY//MODEL_PLACEHOLDER/$model}" + + HTTP_CODE=$(curl -s -o /tmp/oc-go-cc-e2e-stream-long.txt -w '%{http_code}' \ + -X POST "${BASE_URL}/v1/messages" \ + -H "Content-Type: application/json" \ + -H "x-api-key: ${OC_GO_CC_API_KEY}" \ + -d "$REQUEST_BODY" \ + --max-time 120) + + if [ "$HTTP_CODE" = 200 ]; then + if grep -q "event: message_start" /tmp/oc-go-cc-e2e-stream-long.txt && \ + grep -q "event: message_stop" /tmp/oc-go-cc-e2e-stream-long.txt; then + DELTA_COUNT=$(grep -c "event: content_block_delta" /tmp/oc-go-cc-e2e-stream-long.txt 2>/dev/null || echo 0) + echo -e "${GREEN}PASS${NC} (200, ${DELTA_COUNT} content deltas, valid SSE)" + pass=$((pass + 1)) + else + echo -e "${RED}FAIL${NC} (200 but invalid SSE — missing start/stop)" + head -c 400 /tmp/oc-go-cc-e2e-stream-long.txt + fail=$((fail + 1)) + fi + else + ERROR_MSG=$(head -c 300 /tmp/oc-go-cc-e2e-stream-long.txt 2>/dev/null || echo "") + echo -e "${RED}FAIL${NC} (HTTP ${HTTP_CODE})" + echo " Response: ${ERROR_MSG}" + fail=$((fail + 1)) + fi +} + +# --- Test cases --- +echo "=== E2E Model Tests (with tools and custom type) ===" +echo "" + +test_model "minimax-m3" "Tools format fix (was 400)" +test_model "deepseek-v4-flash" "Baseline" +test_model "kimi-k2.7-code" "Temperature fix (was 400)" +test_model "deepseek-v4-pro" "Thinking model" +test_model "qwen3.7-plus" "Go provider qwen (transform path)" +test_model "qwen3.7-max" "Anthropic endpoint + sanitization" + +echo "" +echo "=== E2E Streaming Tests (SSE proxying, heartbeat safety) ===" +echo "" + +test_streaming_model "deepseek-v4-flash" "Streaming + tools" +test_streaming_model "deepseek-v4-pro" "Streaming + thinking" +test_streaming_model "kimi-k2.7-code" "Streaming Go provider" +test_streaming_model "minimax-m3" "Streaming Anthropic endpoint" +test_streaming_long "deepseek-v4-flash" "Long stream (heartbeat)" + +echo "" +echo "=== Results ===" +echo -e "${GREEN}Passed: ${pass}${NC}" +echo -e "${RED}Failed: ${fail}${NC}" +echo "" + +if [ "$fail" -gt 0 ]; then + exit 1 +fi