From 8dbebbe0e3d59062176196adde4f3c8eb3845f2e Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 12:28:14 +0200 Subject: [PATCH 1/7] feat: Implement OpenCode providers for Go and Zen with streaming support - Added OpenCodeGoProvider and OpenCodeZenProvider to handle requests for the respective backends. - Implemented streaming capabilities for both providers, supporting various wire formats including Anthropic, OpenAI, and Gemini. - Introduced StreamProxy to manage SSE stream forwarding and transformation of events. - Enhanced transformer functions to convert between normalized requests and wire formats. - Updated server initialization to register new providers and handle requests accordingly. --- internal/config/config.go | 1 + internal/core/errors.go | 33 ++ internal/core/normalize.go | 146 +++++++ internal/core/normalized.go | 57 +++ internal/core/provider.go | 92 +++++ internal/core/registry.go | 58 +++ internal/core/validate.go | 47 +++ internal/handlers/messages.go | 97 ++++- internal/handlers/streaming.go | 114 ++++++ internal/provider/opencode_go.go | 279 +++++++++++++ internal/provider/opencode_zen.go | 470 ++++++++++++++++++++++ internal/provider/provider.go | 48 +++ internal/server/server.go | 8 + internal/transformer/normalized_bridge.go | 379 +++++++++++++++++ 14 files changed, 1813 insertions(+), 16 deletions(-) create mode 100644 internal/core/errors.go create mode 100644 internal/core/normalize.go create mode 100644 internal/core/normalized.go create mode 100644 internal/core/provider.go create mode 100644 internal/core/registry.go create mode 100644 internal/core/validate.go create mode 100644 internal/handlers/streaming.go create mode 100644 internal/provider/opencode_go.go create mode 100644 internal/provider/opencode_zen.go create mode 100644 internal/provider/provider.go create mode 100644 internal/transformer/normalized_bridge.go diff --git a/internal/config/config.go b/internal/config/config.go index 655aa0d..9df4a39 100644 
--- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,7 @@ type Config struct { type ModelConfig struct { Provider string `json:"provider"` ModelID string `json:"model_id"` + WireFormat string `json:"wire_format,omitempty"` // "auto" (default), "openai", "anthropic", "responses", "gemini" Temperature float64 `json:"temperature"` MaxTokens int `json:"max_tokens"` ContextThreshold int `json:"context_threshold"` diff --git a/internal/core/errors.go b/internal/core/errors.go new file mode 100644 index 0000000..ae4e32e --- /dev/null +++ b/internal/core/errors.go @@ -0,0 +1,33 @@ +package core + +import "errors" + +// Sentinel errors for common provider and routing failures. +var ( + ErrModelNotFound = errors.New("model not found") + ErrProviderNotFound = errors.New("provider not found") + ErrUnsupportedCapability = errors.New("capability not supported by model") + ErrRateLimited = errors.New("rate limited by provider") + ErrStreamIdle = errors.New("upstream stream idle") + ErrClientDisconnected = errors.New("client disconnected") +) + +// NormalizedError wraps a provider error with structured context. +type NormalizedError struct { + Kind string // "api_error", "rate_limit", "invalid_request", etc. + Message string + Retryable bool + StatusCode int + Provider string + ModelID string +} + +// Error implements the error interface. +func (e *NormalizedError) Error() string { + return e.Message +} + +// IsRetryable returns true if the error is safe to retry with a fallback model. 
+func (e *NormalizedError) IsRetryable() bool {
+	return e.Retryable
+}
diff --git a/internal/core/normalize.go b/internal/core/normalize.go
new file mode 100644
index 0000000..fb08789
--- /dev/null
+++ b/internal/core/normalize.go
@@ -0,0 +1,146 @@
+package core
+
+import (
+	"encoding/json"
+
+	"github.com/routatic/proxy/pkg/types"
+)
+
+// thinkingConfig mirrors the Anthropic thinking field structure so we can
+// decode it without coupling to a specific json.RawMessage layout.
+type thinkingConfig struct {
+	Type         string `json:"type"`
+	BudgetTokens int    `json:"budget_tokens,omitempty"`
+}
+
+// NormalizeRequest converts an Anthropic MessageRequest to a NormalizedRequest.
+// Text is concatenated per message and image blocks become "[Image]" markers.
+func NormalizeRequest(anthropicReq *types.MessageRequest) *NormalizedRequest {
+	nr := &NormalizedRequest{
+		Model:     anthropicReq.Model,
+		MaxTokens: anthropicReq.MaxTokens,
+		Stream:    anthropicReq.Stream != nil && *anthropicReq.Stream,
+	}
+
+	// Extract system prompt (string or array of content blocks).
+	nr.SystemPrompt = anthropicReq.SystemText()
+
+	// Set temperature if provided.
+	if anthropicReq.Temperature != nil {
+		nr.Temperature = anthropicReq.Temperature
+	}
+
+	// Extract reasoning effort and thinking budget.
+	if len(anthropicReq.Thinking) > 0 {
+		var tc thinkingConfig
+		if err := json.Unmarshal(anthropicReq.Thinking, &tc); err == nil {
+			nr.ReasoningEffort = tc.Type
+			nr.ThinkingBudget = tc.BudgetTokens
+		}
+	}
+
+	// Convert messages.
+	for _, msg := range anthropicReq.Messages {
+		nm := NormalizedMessage{
+			Role: msg.Role,
+		}
+
+		blocks := msg.ContentBlocks()
+		for _, block := range blocks {
+			switch block.Type {
+			case "text":
+				nm.Content += block.Text
+			case "tool_use":
+				nm.ToolCalls = append(nm.ToolCalls, NormalizedToolCall{
+					ID:        block.ID,
+					Name:      block.Name,
+					Arguments: string(block.Input),
+				})
+			case "tool_result":
+				nm.ToolCallID = block.ToolUseID
+				nm.Content += block.TextContent()
+			case "thinking":
+				nm.Thinking += block.Thinking
+			case "image":
+				nm.Content += "[Image]"
+			}
+		}
+
+		nr.Messages = append(nr.Messages, nm)
+	}
+
+	// Convert tools.
+	for _, tool := range anthropicReq.Tools {
+		nt := NormalizedToolDef{
+			Name:        tool.Name,
+			Description: tool.Description,
+			InputSchema: tool.InputSchema,
+		}
+		nr.Tools = append(nr.Tools, nt)
+	}
+
+	return nr
+}
+
+// DenormalizeResponse converts a NormalizedResponse to an Anthropic MessageResponse.
+func DenormalizeResponse(nr *NormalizedResponse) *types.MessageResponse {
+	resp := &types.MessageResponse{
+		ID:    nr.ID,
+		Type:  "message",
+		Model: nr.Model,
+		Usage: types.Usage{
+			InputTokens:              nr.Usage.InputTokens,
+			OutputTokens:             nr.Usage.OutputTokens,
+			CacheCreationInputTokens: nr.Usage.CacheCreationTokens,
+			CacheReadInputTokens:     nr.Usage.CacheReadTokens,
+		},
+	}
+
+	// Build content blocks from messages.
+	for _, msg := range nr.Messages {
+		switch msg.Role {
+		case "assistant":
+			resp.Role = "assistant"
+
+			// Add thinking block if present.
+			if msg.Thinking != "" {
+				resp.Content = append(resp.Content, types.ContentBlock{
+					Type:     "thinking",
+					Thinking: msg.Thinking,
+				})
+			}
+
+			// Add text block if present.
+			if msg.Content != "" {
+				resp.Content = append(resp.Content, types.ContentBlock{
+					Type: "text",
+					Text: msg.Content,
+				})
+			}
+
+			// Add tool_use blocks.
+			for _, tc := range msg.ToolCalls {
+				resp.Content = append(resp.Content, types.ContentBlock{
+					Type:  "tool_use",
+					ID:    tc.ID,
+					Name:  tc.Name,
+					Input: []byte(tc.Arguments),
+				})
+			}
+		}
+	}
+
+	// Determine stop reason after the loop so it is set even with no messages.
+	switch nr.StopReason {
+	case "end_turn":
+		resp.StopReason = "end_turn"
+	case "max_tokens":
+		resp.StopReason = "max_tokens"
+	case "tool_use":
+		resp.StopReason = "tool_use"
+	default:
+		resp.StopReason = "end_turn"
+	}
+
+	return resp
+}
diff --git a/internal/core/normalized.go b/internal/core/normalized.go
new file mode 100644
index 0000000..780eb6a
--- /dev/null
+++ b/internal/core/normalized.go
@@ -0,0 +1,57 @@
+package core
+
+// NormalizedMessage is a single message in the internal canonical format.
+// All wire formats (Anthropic, OpenAI, Responses, Gemini) map to and from
+// this representation.
+type NormalizedMessage struct {
+	Role       string               // "user", "assistant", "system", "tool"
+	Content    string               // Concatenated text content
+	ToolCalls  []NormalizedToolCall // Present on assistant messages
+	ToolCallID string               // Present on tool-result messages
+	Thinking   string               // Reasoning/thinking content (assistant only)
+}
+
+// NormalizedToolCall represents a tool invocation in the internal format.
+type NormalizedToolCall struct {
+	ID        string
+	Name      string
+	Arguments string // JSON string
+}
+
+// NormalizedRequest is the canonical internal request format.
+type NormalizedRequest struct {
+	Model           string
+	SystemPrompt    string
+	Messages        []NormalizedMessage
+	MaxTokens       int
+	Temperature     *float64
+	TopP            *float64
+	Stream          bool
+	Tools           []NormalizedToolDef
+	ReasoningEffort string // "low", "medium", "high"
+	ThinkingBudget  int    // budget_tokens for thinking mode
+}
+
+// NormalizedToolDef is a tool definition in the internal format.
+type NormalizedToolDef struct {
+	Name        string
+	Description string
+	InputSchema []byte // JSON bytes of the schema
+}
+
+// NormalizedResponse is the canonical internal response format.
+type NormalizedResponse struct { + ID string + Model string + Messages []NormalizedMessage + StopReason string // "end_turn", "max_tokens", "tool_use" + Usage NormalizedUsage +} + +// NormalizedUsage holds token counts in the internal format. +type NormalizedUsage struct { + InputTokens int + OutputTokens int + CacheReadTokens int + CacheCreationTokens int +} diff --git a/internal/core/provider.go b/internal/core/provider.go new file mode 100644 index 0000000..03e141e --- /dev/null +++ b/internal/core/provider.go @@ -0,0 +1,92 @@ +// Package core defines the provider abstraction, wire format types, and +// capability metadata that form the foundation of the routing engine. +package core + +import ( + "context" + "io" + "time" + + "github.com/routatic/proxy/internal/config" +) + +// WireFormat describes the upstream API format a provider uses for a given model. +type WireFormat int + +const ( + // WireFormatOpenAIChat is the OpenAI Chat Completions format (/v1/chat/completions). + WireFormatOpenAIChat WireFormat = iota + // WireFormatAnthropic is the Anthropic Messages format (/v1/messages). + WireFormatAnthropic + // WireFormatOpenAIResponses is the OpenAI Responses format (/v1/responses). + WireFormatOpenAIResponses + // WireFormatGemini is the Google Gemini format (/v1/models/{id}). + WireFormatGemini +) + +// String returns a human-readable name for the wire format. +func (w WireFormat) String() string { + switch w { + case WireFormatOpenAIChat: + return "openai" + case WireFormatAnthropic: + return "anthropic" + case WireFormatOpenAIResponses: + return "responses" + case WireFormatGemini: + return "gemini" + default: + return "unknown" + } +} + +// ProviderCapabilities describes what a provider can do at the provider level. +// Per-model refinements are returned by ModelCapabilities. 
+type ProviderCapabilities struct { + SupportsStreaming bool + SupportsTools bool + SupportsThinking bool + SupportsImageInput bool + MaxContextLength int // in tokens + DefaultMaxTokens int + KnownModels []string +} + +// ExecuteResult holds the result of a non-streaming provider call. +type ExecuteResult struct { + Body []byte + ModelID string + Latency time.Duration +} + +// Provider is the abstraction for an upstream LLM provider. +type Provider interface { + // Name returns the provider identifier (e.g. "opencode-go", "opencode-zen"). + Name() string + + // Capabilities returns provider-level capabilities. + Capabilities() ProviderCapabilities + + // ModelCapabilities returns per-model capabilities. Returns false if the + // model is unknown to this provider. + ModelCapabilities(modelID string) (ProviderCapabilities, bool) + + // WireFormat returns the wire format for the given model on this provider. + WireFormat(modelID string) WireFormat + + // Execute sends a non-streaming request and returns the response. + Execute(ctx context.Context, req *NormalizedRequest, model config.ModelConfig) (*ExecuteResult, error) + + // Stream sends a streaming request and returns an io.ReadCloser for SSE + // events. The stream emits raw SSE bytes; the handler is responsible for + // forwarding them. + Stream(ctx context.Context, req *NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) + + // RoundTripName returns the model ID to use in the upstream request. This + // may differ from the config's ModelID (e.g. for model overrides). + RoundTripName(model config.ModelConfig) string + + // StreamIdleTimeout returns the maximum gap between bytes on an active + // stream before it is treated as stuck and aborted. 
+ StreamIdleTimeout(model config.ModelConfig) time.Duration +} diff --git a/internal/core/registry.go b/internal/core/registry.go new file mode 100644 index 0000000..df28324 --- /dev/null +++ b/internal/core/registry.go @@ -0,0 +1,58 @@ +package core + +import ( + "fmt" + "sync" +) + +// ProviderRegistry provides thread-safe access to registered providers. +type ProviderRegistry struct { + mu sync.RWMutex + providers map[string]Provider +} + +// NewProviderRegistry creates a new provider registry. +func NewProviderRegistry() *ProviderRegistry { + return &ProviderRegistry{ + providers: make(map[string]Provider), + } +} + +// Register adds a provider. Returns an error if the name is already registered. +func (r *ProviderRegistry) Register(p Provider) error { + r.mu.Lock() + defer r.mu.Unlock() + name := p.Name() + if _, ok := r.providers[name]; ok { + return fmt.Errorf("provider %q already registered", name) + } + r.providers[name] = p + return nil +} + +// Get retrieves a provider by name. Returns false if not found. +func (r *ProviderRegistry) Get(name string) (Provider, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + p, ok := r.providers[name] + return p, ok +} + +// MustGet retrieves a provider by name, panicking if missing. +func (r *ProviderRegistry) MustGet(name string) Provider { + if p, ok := r.Get(name); ok { + return p + } + panic(fmt.Sprintf("provider %q not registered", name)) +} + +// List returns all registered provider names. +func (r *ProviderRegistry) List() []string { + r.mu.RLock() + defer r.mu.RUnlock() + names := make([]string, 0, len(r.providers)) + for n := range r.providers { + names = append(names, n) + } + return names +} diff --git a/internal/core/validate.go b/internal/core/validate.go new file mode 100644 index 0000000..41764a9 --- /dev/null +++ b/internal/core/validate.go @@ -0,0 +1,47 @@ +package core + +import "fmt" + +// ValidateRequest checks a NormalizedRequest for structural validity. 
+func ValidateRequest(req *NormalizedRequest) error { + if req.Model == "" { + return fmt.Errorf("model is required") + } + if len(req.Messages) == 0 { + return fmt.Errorf("messages is required") + } + + // Validate message ordering: user, assistant, tool-result alternation. + for i, msg := range req.Messages { + switch msg.Role { + case "user", "assistant", "system", "tool": + // Valid roles + default: + return fmt.Errorf("messages[%d]: invalid role %q", i, msg.Role) + } + + // Tool-result messages must have a ToolCallID. + if msg.Role == "tool" && msg.ToolCallID == "" { + return fmt.Errorf("messages[%d]: tool-result message missing tool_call_id", i) + } + + // Assistant messages with tool calls must have non-empty tool calls. + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + for j, tc := range msg.ToolCalls { + if tc.ID == "" { + return fmt.Errorf("messages[%d].tool_calls[%d]: missing id", i, j) + } + if tc.Name == "" { + return fmt.Errorf("messages[%d].tool_calls[%d]: missing name", i, j) + } + } + } + } + + // Validate max_tokens bounds. + if req.MaxTokens < 0 { + return fmt.Errorf("max_tokens must be non-negative") + } + + return nil +} diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 92848cb..06d4aa6 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -15,6 +15,7 @@ import ( "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" + "github.com/routatic/proxy/internal/core" "github.com/routatic/proxy/internal/metrics" "github.com/routatic/proxy/internal/middleware" "github.com/routatic/proxy/internal/router" @@ -25,18 +26,20 @@ import ( // MessagesHandler handles /v1/messages requests. 
type MessagesHandler struct { - client *client.OpenCodeClient - modelRouter *router.ModelRouter - fallbackHandler *router.FallbackHandler + client *client.OpenCodeClient // kept for backward compat during migration + providerRegistry *core.ProviderRegistry // new: provider dispatch + modelRouter *router.ModelRouter + fallbackHandler *router.FallbackHandler + streamProxy *StreamProxy // new: SSE proxy by wire format requestTransformer *transformer.RequestTransformer responseTransformer *transformer.ResponseTransformer streamHandler *transformer.StreamHandler - tokenCounter *token.Counter - logger *slog.Logger - rateLimiter *middleware.RateLimiter - requestDedup *middleware.RequestDeduplicator - requestIDGen *middleware.RequestIDGenerator - metrics *metrics.Metrics + tokenCounter *token.Counter + logger *slog.Logger + rateLimiter *middleware.RateLimiter + requestDedup *middleware.RequestDeduplicator + requestIDGen *middleware.RequestIDGenerator + metrics *metrics.Metrics } // responseWriter wraps http.ResponseWriter to track if headers were written. @@ -107,6 +110,7 @@ func (w *responseWriter) WriteKeepalive() { // NewMessagesHandler creates a new messages handler. 
func NewMessagesHandler( openCodeClient *client.OpenCodeClient, + providerRegistry *core.ProviderRegistry, modelRouter *router.ModelRouter, fallbackHandler *router.FallbackHandler, tokenCounter *token.Counter, @@ -114,8 +118,10 @@ func NewMessagesHandler( ) *MessagesHandler { return &MessagesHandler{ client: openCodeClient, + providerRegistry: providerRegistry, modelRouter: modelRouter, fallbackHandler: fallbackHandler, + streamProxy: NewStreamProxy(), requestTransformer: transformer.NewRequestTransformer(), responseTransformer: transformer.NewResponseTransformer(), streamHandler: transformer.NewStreamHandler(), @@ -240,12 +246,13 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) "tokens", tokenCount, ) + normalizedReq := core.NormalizeRequest(&anthropicReq) + normalizedReq.Stream = isStreaming + if isStreaming { - // Streaming: use ProxyStream for real-time SSE transformation - h.handleStreaming(w, r, &anthropicReq, modelChain, rawBody) + h.handleStreaming(w, r, &anthropicReq, normalizedReq, modelChain, rawBody) } else { - // Non-streaming: execute with fallback and return full response - h.handleNonStreaming(w, r, &anthropicReq, modelChain, rawBody) + h.handleNonStreaming(w, r, &anthropicReq, normalizedReq, modelChain, rawBody) } } @@ -427,6 +434,52 @@ func (h *MessagesHandler) handleStreaming( return true // continue to next model } + // Try new provider-based dispatch first. 
+ if prov, ok := h.providerRegistry.Get(model.Provider); ok { + normalizedReq := core.NormalizeRequest(anthropicReq) + normalizedReq.Stream = true + + caps, ok := prov.ModelCapabilities(model.ModelID) + if !ok || !caps.SupportsStreaming { + h.logger.Warn("model does not support streaming", "model", model.ModelID, "provider", model.Provider) + cancel() + continue + } + + streamBody, err := prov.Stream(ctx, normalizedReq, model) + if err != nil { + cancel() + if clientCtx.Err() == context.Canceled { + h.logger.Debug("client disconnected during upstream request") + return + } + h.logger.Warn("streaming request failed via provider", "model", model.ModelID, "provider", model.Provider, "error", err) + continue + } + + wireFormat := prov.WireFormat(model.ModelID) + if err := h.streamProxy.ProxyStream(rw, streamBody, wireFormat, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + _ = streamBody.Close() + if err == transformer.ErrClientDisconnected { + h.logger.Debug("client disconnected during stream") + return + } + if !handleStreamError(err, model, wireFormat.String()) { + return + } + continue + } + + _ = streamBody.Close() + recordStreamSuccess(model) + return + } + + // Legacy path for backward compatibility while old client is still in + // use. Falls through to the old endpoint-classification logic. + h.logger.Warn("provider not found in registry, falling back to old client", + "provider", model.Provider, "model", model.ModelID) + // Zen models use their own endpoint classification if client.IsZen(model) { endpointType := client.ClassifyEndpoint(model.ModelID) @@ -479,8 +532,6 @@ func (h *MessagesHandler) handleStreaming( if !handleStreamError(err, model, "anthropic") { return } - // For non-idle errors after SSE payload started, send a more - // specific error message since this is the last attempt. 
if err != transformer.ErrStreamIdle && rw.ssePayloadWritten { h.sendStreamError(rw, fmt.Sprintf("all upstream models failed after SSE payload started: %v", err)) h.metrics.RecordFailure() @@ -777,7 +828,21 @@ func (h *MessagesHandler) handleNonStreaming( ctx, modelChain, func(ctx context.Context, model config.ModelConfig) ([]byte, error) { - // Zen models use their own endpoint classification + // Try new provider-based dispatch first. + if prov, ok := h.providerRegistry.Get(model.Provider); ok { + normalizedReq := core.NormalizeRequest(anthropicReq) + normalizedReq.Stream = false + execResult, execErr := prov.Execute(ctx, normalizedReq, model) + if execErr != nil { + return nil, execErr + } + return execResult.Body, nil + } + + h.logger.Warn("provider not found in registry, falling back to old client", + "provider", model.Provider, "model", model.ModelID) + + // Legacy path: Zen models use their own endpoint classification if client.IsZen(model) { endpointType := client.ClassifyEndpoint(model.ModelID) switch endpointType { diff --git a/internal/handlers/streaming.go b/internal/handlers/streaming.go new file mode 100644 index 0000000..8e00867 --- /dev/null +++ b/internal/handlers/streaming.go @@ -0,0 +1,114 @@ +package handlers + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/routatic/proxy/internal/core" + "github.com/routatic/proxy/internal/transformer" +) + +// StreamProxy handles SSE stream forwarding from various upstream wire formats +// to Anthropic-format SSE events. It wraps transformer.StreamHandler and +// dispatches by WireFormat. +type StreamProxy struct { + handler *transformer.StreamHandler +} + +// NewStreamProxy creates a new StreamProxy. +func NewStreamProxy() *StreamProxy { + return &StreamProxy{ + handler: transformer.NewStreamHandler(), + } +} + +// ProxyStream proxies an upstream SSE stream to the response writer, transforming +// events from the wire format to Anthropic SSE events. 
+func (sp *StreamProxy) ProxyStream(
+	w http.ResponseWriter,
+	body io.ReadCloser,
+	wireFormat core.WireFormat,
+	modelID string,
+	clientCtx context.Context,
+	idleTimeout time.Duration,
+	cancel context.CancelFunc,
+) error {
+	switch wireFormat {
+	case core.WireFormatAnthropic:
+		return sp.proxyAnthropicPassthroughStream(w, body, idleTimeout, clientCtx, cancel)
+	case core.WireFormatOpenAIResponses:
+		return sp.handler.ProxyResponsesStream(w, body, modelID, clientCtx, idleTimeout, cancel)
+	case core.WireFormatGemini:
+		return sp.handler.ProxyGeminiStream(w, body, modelID, clientCtx, idleTimeout, cancel)
+	default: // WireFormatOpenAIChat and any unrecognized format
+		return sp.proxyOpenAIStream(w, body, modelID, clientCtx, idleTimeout, cancel)
+	}
+}
+
+// proxyOpenAIStream delegates to the transformer's ProxyStream.
+func (sp *StreamProxy) proxyOpenAIStream(
+	w http.ResponseWriter,
+	body io.ReadCloser,
+	modelID string,
+	clientCtx context.Context,
+	idleTimeout time.Duration,
+	cancel context.CancelFunc,
+) error {
+	return sp.handler.ProxyStream(w, body, modelID, clientCtx, idleTimeout, cancel)
+}
+
+// proxyAnthropicPassthroughStream forwards raw Anthropic SSE bytes directly to
+// the client, with an idle watchdog. No transformation is needed since the
+// upstream already speaks Anthropic format.
+func (sp *StreamProxy) proxyAnthropicPassthroughStream(
+	w http.ResponseWriter,
+	body io.ReadCloser,
+	idleTimeout time.Duration,
+	clientCtx context.Context,
+	cancel context.CancelFunc,
+) error {
+	defer body.Close()
+	defer cancel()
+
+	buf := make([]byte, 4096)
+	ping := transformer.StartIdleWatchdog(clientCtx, cancel, idleTimeout)
+	for {
+		select {
+		case <-clientCtx.Done():
+			// Done is closed, so Err is guaranteed non-nil here: the client
+			// context has ended. Aborts that leave the client context alive
+			// are reported through the read-error branch further down.
+			return transformer.ErrClientDisconnected
+		default:
+		}
+		n, rerr := body.Read(buf)
+		if n > 0 {
+			ping()
+			if _, werr := w.Write(buf[:n]); werr != nil {
+				return transformer.ErrClientDisconnected
+			}
+			if f, ok := w.(http.Flusher); ok {
+				f.Flush()
+			}
+		}
+		if rerr == io.EOF {
+			return nil
+		}
+		if rerr != nil {
+			if transformer.IsIdleTimeout(rerr) {
+				return transformer.ErrStreamIdle
+			}
+			if errors.Is(rerr, context.Canceled) || clientCtx.Err() == context.Canceled {
+				if clientCtx.Err() == nil {
+					return transformer.ErrStreamIdle
+				}
+				return transformer.ErrClientDisconnected
+			}
+			return fmt.Errorf("failed to copy response: %w", rerr)
+		}
+	}
+}
diff --git a/internal/provider/opencode_go.go b/internal/provider/opencode_go.go
new file mode 100644
index 0000000..4161733
--- /dev/null
+++ b/internal/provider/opencode_go.go
@@ -0,0 +1,279 @@
+package provider
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+
+	"github.com/routatic/proxy/internal/config"
+	"github.com/routatic/proxy/internal/core"
+	"github.com/routatic/proxy/internal/transformer"
+	"github.com/routatic/proxy/pkg/types"
+)
+
+// OpenCodeGoProvider implements core.Provider for the OpenCode Go backend.
+type OpenCodeGoProvider struct {
+	baseProvider
+}
+
+// NewOpenCodeGoProvider creates a new OpenCodeGoProvider.
+func NewOpenCodeGoProvider(atomic *config.AtomicConfig) *OpenCodeGoProvider {
+	return &OpenCodeGoProvider{baseProvider: newBaseProvider(atomic)}
+}
+
+// Name returns the provider identifier.
+func (p *OpenCodeGoProvider) Name() string { return "opencode-go" }
+
+// Capabilities returns provider-level capabilities.
+func (p *OpenCodeGoProvider) Capabilities() core.ProviderCapabilities {
+	return core.ProviderCapabilities{
+		SupportsStreaming:  true,
+		SupportsTools:      true,
+		SupportsThinking:   true,
+		SupportsImageInput: true,
+		MaxContextLength:   128_000,
+		DefaultMaxTokens:   4096,
+	}
+}
+
+// ModelCapabilities returns per-model capabilities derived from the
+// provider-level defaults. No model allow-list is kept here, so every model
+// ID is accepted and the second result is always true.
+func (p *OpenCodeGoProvider) ModelCapabilities(modelID string) (core.ProviderCapabilities, bool) {
+	caps := p.Capabilities()
+	// qwen3.7-max and the MiniMax models get a 1M-token context window on
+	// the Go provider; all other models keep the provider-level default.
+	switch modelID {
+	case "qwen3.7-max", "minimax-m2.5", "minimax-m2.7", "minimax-m3":
+		caps.MaxContextLength = 1_000_000
+	}
+
+	return caps, true
+}
+
+// WireFormat returns the wire format for the given model on the Go provider.
+func (p *OpenCodeGoProvider) WireFormat(modelID string) core.WireFormat {
+	if isAnthropicNativeGo(modelID) {
+		return core.WireFormatAnthropic
+	}
+	return core.WireFormatOpenAIChat
+}
+
+// isAnthropicNativeGo returns true for Go provider models that require the
+// Anthropic Messages endpoint rather than the OpenAI Chat Completions endpoint.
+func isAnthropicNativeGo(modelID string) bool {
+	return modelID == "qwen3.7-max"
+}
+
+// RoundTripName returns the model ID to use in the upstream request.
+func (p *OpenCodeGoProvider) RoundTripName(model config.ModelConfig) string {
+	return model.ModelID
+}
+
+// StreamIdleTimeout returns the maximum gap between bytes on an active stream.
+func (p *OpenCodeGoProvider) StreamIdleTimeout(model config.ModelConfig) time.Duration { + const fallback = 5 * time.Minute + cfg := p.atomic.Get() + ms := cfg.OpenCodeGo.StreamTimeoutMs + if ms <= 0 { + ms = cfg.OpenCodeGo.TimeoutMs + } + if ms <= 0 { + return fallback + } + return time.Duration(ms) * time.Millisecond +} + +// Execute sends a non-streaming request and returns the response. +func (p *OpenCodeGoProvider) Execute(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) { + switch p.WireFormat(model.ModelID) { + case core.WireFormatAnthropic: + return p.executeAnthropic(ctx, req, model) + default: + return p.executeOpenAI(ctx, req, model) + } +} + +// Stream sends a streaming request and returns an io.ReadCloser for SSE events. +func (p *OpenCodeGoProvider) Stream(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) { + switch p.WireFormat(model.ModelID) { + case core.WireFormatAnthropic: + return p.streamAnthropic(ctx, req, model) + default: + return p.streamOpenAI(ctx, req, model) + } +} + +// ── OpenAI Chat Completions ──────────────────────────────────────────── + +func (p *OpenCodeGoProvider) executeOpenAI(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) { + cfg := p.atomic.Get() + endpoint := cfg.OpenCodeGo.BaseURL + apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys()) + + openaiReq := transformer.TransformRequestFromNormalized(req, model) + streamFalse := false + openaiReq.Stream = &streamFalse + + start := time.Now() + resp, err := p.doRequest(ctx, endpoint, apiKey, openaiReq, false) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var chatResp types.ChatCompletionResponse + if err := json.Unmarshal(body, &chatResp); err != nil { + return nil, 
fmt.Errorf("failed to unmarshal response: %w", err) + } + + normResp := transformer.OpenAIResponseToNormalized(&chatResp, model.ModelID) + anthropicResp := core.DenormalizeResponse(normResp) + resultBody, err := json.Marshal(anthropicResp) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return &core.ExecuteResult{ + Body: resultBody, + ModelID: model.ModelID, + Latency: time.Since(start), + }, nil +} + +func (p *OpenCodeGoProvider) streamOpenAI(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) { + cfg := p.atomic.Get() + endpoint := cfg.OpenCodeGo.BaseURL + apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys()) + + openaiReq := transformer.TransformRequestFromNormalized(req, model) + streamTrue := true + openaiReq.Stream = &streamTrue + + resp, err := p.doRequest(ctx, endpoint, apiKey, openaiReq, true) + if err != nil { + return nil, err + } + + return resp.Body, nil +} + +// ── Anthropic Messages ──────────────────────────────────────────────── + +func (p *OpenCodeGoProvider) executeAnthropic(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) { + cfg := p.atomic.Get() + endpoint := cfg.OpenCodeGo.AnthropicBaseURL + apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys()) + + anthropicReq := transformer.NormalizedToAnthropic(req, model) + rawBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal anthropic request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(rawBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+apiKey) + httpReq.Header.Set("x-api-key", apiKey) + + start := time.Now() + resp, err := p.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("request failed: %w", 
err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	if resp.StatusCode >= http.StatusBadRequest {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	return &core.ExecuteResult{
+		Body:    body,
+		ModelID: model.ModelID,
+		Latency: time.Since(start),
+	}, nil
+}
+
+// streamAnthropic converts req to the Anthropic Messages format, POSTs it to
+// the configured OpenCodeGo Anthropic endpoint and returns the response body
+// for the caller to consume as a stream.
+//
+// The API key is sent both as a Bearer token and as x-api-key. On success the
+// caller owns the returned body and must close it. A response with status
+// >= 400 is read into the returned error and closed here.
+func (p *OpenCodeGoProvider) streamAnthropic(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeGo.AnthropicBaseURL
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	anthropicReq := transformer.NormalizedToAnthropic(req, model)
+	rawBody, err := json.Marshal(anthropicReq)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal anthropic request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(rawBody))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
+	httpReq.Header.Set("x-api-key", apiKey)
+	httpReq.Header.Set("Accept", "text/event-stream")
+
+	resp, err := p.httpClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+
+	// No deferred Close here: on success the body is handed to the caller.
+	if resp.StatusCode >= http.StatusBadRequest {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	return resp.Body, nil
+}
+
+// ── HTTP helpers ──────────────────────────────────────────────────────
+
+// doRequest marshals req to JSON and POSTs it to endpoint with a Bearer
+// Authorization header. When stream is true it also sets
+// Accept: text/event-stream.
+//
+// A response with status >= 400 is read into the returned error and its body
+// is closed. Otherwise the response is returned with its body still open and
+// the caller is responsible for closing it.
+func (p *OpenCodeGoProvider) doRequest(ctx context.Context, endpoint, apiKey string, req any, stream bool) (*http.Response, error) {
+	body, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
+	if stream {
+		httpReq.Header.Set("Accept", "text/event-stream")
+	}
+
+	resp, err := p.httpClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+
+	if resp.StatusCode >= http.StatusBadRequest {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	return resp, nil
+}
diff --git a/internal/provider/opencode_zen.go b/internal/provider/opencode_zen.go
new file mode 100644
index 0000000..0664514
--- /dev/null
+++ b/internal/provider/opencode_zen.go
@@ -0,0 +1,470 @@
+package provider
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/routatic/proxy/internal/config"
+	"github.com/routatic/proxy/internal/core"
+	"github.com/routatic/proxy/internal/transformer"
+	"github.com/routatic/proxy/pkg/types"
+)
+
+// OpenCodeZenProvider implements core.Provider for the OpenCode Zen backend.
+// Zen supports four wire formats determined by model ID: Anthropic (Claude,
+// Qwen), Responses (GPT), Gemini, and Chat Completions (everything else).
+type OpenCodeZenProvider struct {
+	baseProvider
+}
+
+// NewOpenCodeZenProvider creates a new OpenCodeZenProvider.
+func NewOpenCodeZenProvider(atomic *config.AtomicConfig) *OpenCodeZenProvider {
+	return &OpenCodeZenProvider{baseProvider: newBaseProvider(atomic)}
+}
+
+// Name returns the provider identifier.
+func (p *OpenCodeZenProvider) Name() string { return "opencode-zen" }
+
+// Capabilities returns provider-level capabilities.
+func (p *OpenCodeZenProvider) Capabilities() core.ProviderCapabilities {
+	// Provider-wide defaults; ModelCapabilities narrows them per model.
+	var caps core.ProviderCapabilities
+	caps.SupportsStreaming = true
+	caps.SupportsTools = true
+	caps.SupportsThinking = true
+	caps.SupportsImageInput = true
+	caps.MaxContextLength = 200_000
+	caps.DefaultMaxTokens = 4096
+	return caps
+}
+
+// ModelCapabilities starts from the provider defaults and adjusts them by
+// model-ID prefix. The second result is always true: unknown models simply
+// keep the provider defaults.
+func (p *OpenCodeZenProvider) ModelCapabilities(modelID string) (core.ProviderCapabilities, bool) {
+	hasPrefix := func(prefix string) bool { return strings.HasPrefix(modelID, prefix) }
+
+	caps := p.Capabilities()
+	switch {
+	case hasPrefix("claude-"):
+		caps.MaxContextLength = 200_000
+	case hasPrefix("gpt-"):
+		caps.MaxContextLength = 128_000
+		caps.SupportsThinking = false
+	case hasPrefix("gemini-"), hasPrefix("minimax-"), hasPrefix("deepseek-"):
+		// These three families share the same context window.
+		caps.MaxContextLength = 1_000_000
+	}
+	return caps, true
+}
+
+// WireFormat picks the upstream wire format for modelID on Zen. Anthropic is
+// checked first, then Gemini, then Responses; anything else falls back to
+// OpenAI Chat Completions.
+// This replaces the old client.ClassifyEndpoint function.
+func (p *OpenCodeZenProvider) WireFormat(modelID string) core.WireFormat {
+	if isZenAnthropicModel(modelID) {
+		return core.WireFormatAnthropic
+	}
+	if isGeminiModel(modelID) {
+		return core.WireFormatGemini
+	}
+	if isResponsesModel(modelID) {
+		return core.WireFormatOpenAIResponses
+	}
+	return core.WireFormatOpenAIChat
+}
+
+// isZenAnthropicModel reports whether modelID is served through the Anthropic
+// endpoint on Zen, i.e. it starts with "claude-" or "qwen".
+func isZenAnthropicModel(modelID string) bool {
+	return strings.HasPrefix(modelID, "claude-") || strings.HasPrefix(modelID, "qwen")
+}
+
+// isGeminiModel returns true for models using the Gemini endpoint.
+func isGeminiModel(modelID string) bool {
+	// Exact-match allow-list, not a prefix test.
+	return modelID == "gemini-3.5-flash" ||
+		modelID == "gemini-3.1-pro" ||
+		modelID == "gemini-3-flash"
+}
+
+// isResponsesModel reports whether modelID is one of the GPT models served
+// through the OpenAI Responses endpoint. Matching is exact.
+func isResponsesModel(modelID string) bool {
+	switch modelID {
+	case "gpt-5.5", "gpt-5.5-pro",
+		"gpt-5.4", "gpt-5.4-pro", "gpt-5.4-mini", "gpt-5.4-nano",
+		"gpt-5.3-codex", "gpt-5.3-codex-spark",
+		"gpt-5.2", "gpt-5.2-codex",
+		"gpt-5.1", "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini",
+		"gpt-5", "gpt-5-codex", "gpt-5-nano":
+		return true
+	}
+	return false
+}
+
+// RoundTripName returns the model ID sent upstream, which for Zen is the
+// configured model ID unchanged.
+func (p *OpenCodeZenProvider) RoundTripName(model config.ModelConfig) string {
+	return model.ModelID
+}
+
+// StreamIdleTimeout returns the longest allowed gap between bytes on an
+// active stream. It prefers OpenCodeZen.StreamTimeoutMs, then
+// OpenCodeZen.TimeoutMs, and falls back to five minutes when neither is
+// positive.
+func (p *OpenCodeZenProvider) StreamIdleTimeout(model config.ModelConfig) time.Duration {
+	cfg := p.atomic.Get()
+
+	timeoutMs := cfg.OpenCodeZen.StreamTimeoutMs
+	if timeoutMs <= 0 {
+		timeoutMs = cfg.OpenCodeZen.TimeoutMs
+	}
+	if timeoutMs > 0 {
+		return time.Duration(timeoutMs) * time.Millisecond
+	}
+	return 5 * time.Minute
+}
+
+// Execute performs a non-streaming request, dispatching on the wire format
+// of the target model.
+func (p *OpenCodeZenProvider) Execute(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) {
+	format := p.WireFormat(model.ModelID)
+	if format == core.WireFormatAnthropic {
+		return p.executeAnthropic(ctx, req, model)
+	}
+	if format == core.WireFormatOpenAIResponses {
+		return p.executeResponses(ctx, req, model)
+	}
+	if format == core.WireFormatGemini {
+		return p.executeGemini(ctx, req, model)
+	}
+	// Everything else speaks OpenAI Chat Completions.
+	return p.executeOpenAI(ctx, req, model)
+}
+
+// Stream sends a streaming request and returns an io.ReadCloser for SSE events.
+func (p *OpenCodeZenProvider) Stream(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	switch p.WireFormat(model.ModelID) {
+	case core.WireFormatAnthropic:
+		return p.streamAnthropic(ctx, req, model)
+	case core.WireFormatOpenAIResponses:
+		return p.streamResponses(ctx, req, model)
+	case core.WireFormatGemini:
+		return p.streamGemini(ctx, req, model)
+	default:
+		return p.streamOpenAI(ctx, req, model)
+	}
+}
+
+// ── OpenAI Chat Completions ────────────────────────────────────────────
+
+// executeOpenAI sends a non-streaming Chat Completions request and returns
+// the upstream reply re-encoded through the normalized response as the JSON
+// produced by core.DenormalizeResponse.
+func (p *OpenCodeZenProvider) executeOpenAI(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.BaseURL
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	openaiReq := transformer.TransformRequestFromNormalized(req, model)
+	streamFalse := false
+	openaiReq.Stream = &streamFalse
+
+	start := time.Now()
+	resp, err := p.doRequest(ctx, endpoint, apiKey, openaiReq, false)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	var chatResp types.ChatCompletionResponse
+	if err := json.Unmarshal(body, &chatResp); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+
+	normResp := transformer.OpenAIResponseToNormalized(&chatResp, model.ModelID)
+	anthropicResp := core.DenormalizeResponse(normResp)
+	resultBody, err := json.Marshal(anthropicResp)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal response: %w", err)
+	}
+
+	return &core.ExecuteResult{
+		Body:    resultBody,
+		ModelID: model.ModelID,
+		Latency: time.Since(start),
+	}, nil
+}
+
+// streamOpenAI sends a streaming Chat Completions request and returns the
+// raw upstream body. The caller must close it.
+func (p *OpenCodeZenProvider) streamOpenAI(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.BaseURL
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	openaiReq := transformer.TransformRequestFromNormalized(req, model)
+	streamTrue := true
+	openaiReq.Stream = &streamTrue
+
+	resp, err := p.doRequest(ctx, endpoint, apiKey, openaiReq, true)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp.Body, nil
+}
+
+// ── Anthropic Messages ────────────────────────────────────────────────
+
+// executeAnthropic sends a non-streaming Anthropic Messages request and
+// returns the upstream body unchanged.
+func (p *OpenCodeZenProvider) executeAnthropic(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) {
+	anthropicReq := transformer.NormalizedToAnthropic(req, model)
+	// Pin the stream flag like the other wire formats do, so the upstream
+	// reply is a JSON document even if req.Stream happens to be set.
+	streamFalse := false
+	anthropicReq.Stream = &streamFalse
+
+	start := time.Now()
+	resp, err := p.doAnthropicRequest(ctx, anthropicReq, false)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	return &core.ExecuteResult{
+		Body:    body,
+		ModelID: model.ModelID,
+		Latency: time.Since(start),
+	}, nil
+}
+
+// streamAnthropic sends a streaming Anthropic Messages request and returns
+// the raw upstream body. The caller must close it.
+func (p *OpenCodeZenProvider) streamAnthropic(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	anthropicReq := transformer.NormalizedToAnthropic(req, model)
+	// Pin the stream flag so the payload always matches the Accept header.
+	streamTrue := true
+	anthropicReq.Stream = &streamTrue
+
+	resp, err := p.doAnthropicRequest(ctx, anthropicReq, true)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp.Body, nil
+}
+
+// doAnthropicRequest POSTs anthropicReq to the configured Anthropic endpoint.
+// The API key is sent both as a Bearer token and as x-api-key; stream adds
+// Accept: text/event-stream.
+func (p *OpenCodeZenProvider) doAnthropicRequest(ctx context.Context, anthropicReq *types.MessageRequest, stream bool) (*http.Response, error) {
+	cfg := p.atomic.Get()
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	headers := map[string]string{
+		"Authorization": "Bearer " + apiKey,
+		"x-api-key":     apiKey,
+	}
+	if stream {
+		headers["Accept"] = "text/event-stream"
+	}
+	return p.postJSON(ctx, cfg.OpenCodeZen.AnthropicBaseURL, headers, anthropicReq)
+}
+
+// ── OpenAI Responses ──────────────────────────────────────────────────
+
+// executeResponses sends a non-streaming Responses request and returns the
+// upstream reply re-encoded through the normalized response.
+func (p *OpenCodeZenProvider) executeResponses(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.ResponsesBaseURL
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	responsesReq := transformer.NormalizedToResponses(req, model)
+	responsesReq.Stream = false
+
+	start := time.Now()
+	resp, err := p.doJSONRequest(ctx, endpoint, apiKey, responsesReq)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	var responsesResp types.ResponsesResponse
+	if err := json.Unmarshal(body, &responsesResp); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+
+	normResp := transformer.ResponsesToNormalized(&responsesResp, model.ModelID)
+	anthropicResp := core.DenormalizeResponse(normResp)
+	resultBody, err := json.Marshal(anthropicResp)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal response: %w", err)
+	}
+
+	return &core.ExecuteResult{
+		Body:    resultBody,
+		ModelID: model.ModelID,
+		Latency: time.Since(start),
+	}, nil
+}
+
+// streamResponses sends a streaming Responses request and returns the raw
+// upstream body. The caller must close it.
+func (p *OpenCodeZenProvider) streamResponses(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.ResponsesBaseURL
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	responsesReq := transformer.NormalizedToResponses(req, model)
+	responsesReq.Stream = true
+
+	// stream=true so the request advertises Accept: text/event-stream, the
+	// same as the Chat Completions and Anthropic streaming paths.
+	resp, err := p.doRequest(ctx, endpoint, apiKey, responsesReq, true)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp.Body, nil
+}
+
+// ── Gemini ────────────────────────────────────────────────────────────
+
+// executeGemini sends a non-streaming Gemini request to
+// GeminiBaseURL/<model ID> and returns the upstream reply re-encoded through
+// the normalized response.
+func (p *OpenCodeZenProvider) executeGemini(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (*core.ExecuteResult, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.GeminiBaseURL + "/" + model.ModelID
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	geminiReq := transformer.NormalizedToGemini(req, model)
+	geminiReq.Stream = false
+
+	start := time.Now()
+	resp, err := p.doJSONRequest(ctx, endpoint, apiKey, geminiReq)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	var geminiResp types.GeminiResponse
+	if err := json.Unmarshal(body, &geminiResp); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+
+	normResp := transformer.GeminiToNormalized(&geminiResp, model.ModelID)
+	anthropicResp := core.DenormalizeResponse(normResp)
+	resultBody, err := json.Marshal(anthropicResp)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal response: %w", err)
+	}
+
+	return &core.ExecuteResult{
+		Body:    resultBody,
+		ModelID: model.ModelID,
+		Latency: time.Since(start),
+	}, nil
+}
+
+// streamGemini sends a streaming Gemini request to GeminiBaseURL/<model ID>
+// and returns the raw upstream body. The caller must close it.
+func (p *OpenCodeZenProvider) streamGemini(ctx context.Context, req *core.NormalizedRequest, model config.ModelConfig) (io.ReadCloser, error) {
+	cfg := p.atomic.Get()
+	endpoint := cfg.OpenCodeZen.GeminiBaseURL + "/" + model.ModelID
+	apiKey := p.nextAPIKey(cfg.EffectiveAPIKeys())
+
+	geminiReq := transformer.NormalizedToGemini(req, model)
+	geminiReq.Stream = true
+
+	// stream=true so the request advertises Accept: text/event-stream, the
+	// same as the Chat Completions and Anthropic streaming paths.
+	resp, err := p.doRequest(ctx, endpoint, apiKey, geminiReq, true)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp.Body, nil
+}
+
+// ── HTTP helpers ──────────────────────────────────────────────────────
+
+// doRequest marshals req to JSON and POSTs it to endpoint with a Bearer
+// Authorization header. When stream is true it also sets
+// Accept: text/event-stream. See postJSON for the response contract.
+func (p *OpenCodeZenProvider) doRequest(ctx context.Context, endpoint, apiKey string, req any, stream bool) (*http.Response, error) {
+	headers := map[string]string{"Authorization": "Bearer " + apiKey}
+	if stream {
+		headers["Accept"] = "text/event-stream"
+	}
+	return p.postJSON(ctx, endpoint, headers, req)
+}
+
+// doJSONRequest is doRequest without the streaming Accept header.
+func (p *OpenCodeZenProvider) doJSONRequest(ctx context.Context, endpoint, apiKey string, req any) (*http.Response, error) {
+	return p.doRequest(ctx, endpoint, apiKey, req, false)
+}
+
+// postJSON marshals payload to JSON and POSTs it to endpoint with
+// Content-Type: application/json plus the given headers.
+//
+// A response with status >= 400 is read into the returned error and its body
+// is closed. Otherwise the response is returned with its body still open and
+// the caller is responsible for closing it.
+func (p *OpenCodeZenProvider) postJSON(ctx context.Context, endpoint string, headers map[string]string, payload any) (*http.Response, error) {
+	body, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	for name, value := range headers {
+		httpReq.Header.Set(name, value)
+	}
+
+	resp, err := p.httpClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+
+	if resp.StatusCode >= http.StatusBadRequest {
+		// Best effort: a failed read still yields the status code.
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	return resp, nil
+}
diff --git a/internal/provider/provider.go b/internal/provider/provider.go
new file mode 100644
index 0000000..683fc50
--- /dev/null
+++ b/internal/provider/provider.go
@@ -0,0 +1,48 @@
+// Package provider implements the core.Provider interface for all supported
+// upstream LLM providers.
+package provider
+
+import (
+	"net/http"
+	"sync/atomic"
+	"time"
+
+	"github.com/routatic/proxy/internal/config"
+)
+
+// baseProvider holds the HTTP client and key-rotation counter embedded by the
+// provider implementations in this package.
+type baseProvider struct {
+	atomic     *config.AtomicConfig
+	httpClient *http.Client
+	keyCounter atomic.Uint64
+}
+
+// newBaseProvider creates a baseProvider with its own HTTP transport tuned
+// for high-concurrency upstream calls. Each call builds a new transport, so
+// connection pools are per provider. The client has no Timeout set; request
+// lifetime is bounded only by the context passed with each request.
+func newBaseProvider(atomic *config.AtomicConfig) baseProvider {
+	transport := &http.Transport{
+		MaxIdleConns:        100,
+		MaxIdleConnsPerHost: 20,
+		IdleConnTimeout:     90 * time.Second,
+		MaxConnsPerHost:     50,
+		DisableKeepAlives:   false,
+		Proxy:               http.ProxyFromEnvironment,
+	}
+	return baseProvider{
+		atomic: atomic,
+		httpClient: &http.Client{
+			Transport: transport,
+		},
+	}
+}
+
+// nextAPIKey returns the next API key in round-robin order from the given pool.
+func (b *baseProvider) nextAPIKey(keys []string) string {
+	count := uint64(len(keys))
+	if count == 0 {
+		return ""
+	}
+	// Add returns the incremented value, so subtract one to start at index 0.
+	ticket := b.keyCounter.Add(1) - 1
+	return keys[ticket%count]
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 2aba183..2b96b53 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -13,8 +13,10 @@ import (
 
 	"github.com/routatic/proxy/internal/client"
 	"github.com/routatic/proxy/internal/config"
+	"github.com/routatic/proxy/internal/core"
 	"github.com/routatic/proxy/internal/handlers"
 	"github.com/routatic/proxy/internal/metrics"
+	"github.com/routatic/proxy/internal/provider"
 	"github.com/routatic/proxy/internal/router"
 	"github.com/routatic/proxy/internal/token"
 )
@@ -51,9 +53,15 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) {
 	modelRouter := router.NewModelRouter(atomic)
 	fallbackHandler := router.NewFallbackHandler(logger, 3, 30*time.Second)
 
+	// Register providers.
+	providerRegistry := core.NewProviderRegistry()
+	_ = providerRegistry.Register(provider.NewOpenCodeGoProvider(atomic))
+	_ = providerRegistry.Register(provider.NewOpenCodeZenProvider(atomic))
+
 	// Create handlers.
 	messagesHandler := handlers.NewMessagesHandler(
 		openCodeClient,
+		providerRegistry,
 		modelRouter,
 		fallbackHandler,
 		tokenCounter,
diff --git a/internal/transformer/normalized_bridge.go b/internal/transformer/normalized_bridge.go
new file mode 100644
index 0000000..a78afe2
--- /dev/null
+++ b/internal/transformer/normalized_bridge.go
@@ -0,0 +1,379 @@
+package transformer
+
+import (
+	"encoding/json"
+
+	"github.com/routatic/proxy/internal/config"
+	"github.com/routatic/proxy/internal/core"
+	"github.com/routatic/proxy/pkg/types"
+)
+
+// ── Request-side: NormalizedRequest → wire format ─────────────────────
+
+// TransformRequestFromNormalized converts a NormalizedRequest to OpenAI
+// ChatCompletionRequest by first reconstructing the Anthropic format and
+// running it through the existing TransformRequest pipeline.
+func TransformRequestFromNormalized(req *core.NormalizedRequest, model config.ModelConfig) *types.ChatCompletionRequest {
+	anthropicReq := normalizedToMessageRequest(req)
+	t := NewRequestTransformer()
+	openaiReq, err := t.TransformRequest(anthropicReq, model)
+	if err != nil {
+		// The Anthropic reconstruction should never fail for valid normalized
+		// requests, but if it does, return a minimal valid request so the
+		// upstream gets a usable payload rather than a nil pointer.
+		stream := req.Stream
+		maxTokens := req.MaxTokens
+		return &types.ChatCompletionRequest{
+			Model:     model.ModelID,
+			Messages:  []types.ChatMessage{{Role: "user", Content: types.TextContent(req.SystemPrompt + "\n" + joinMessageText(req.Messages))}},
+			Stream:    &stream,
+			MaxTokens: &maxTokens,
+		}
+	}
+	return openaiReq
+}
+
+// NormalizedToAnthropic converts a NormalizedRequest to an Anthropic MessageRequest.
+func NormalizedToAnthropic(req *core.NormalizedRequest, model config.ModelConfig) *types.MessageRequest {
+	anthropicReq := normalizedToMessageRequest(req)
+	// Override model ID with the config's model ID.
+	anthropicReq.Model = model.ModelID
+	return anthropicReq
+}
+
+// NormalizedToResponses converts a NormalizedRequest to a ResponsesRequest.
+//
+// The system prompt becomes a leading "developer" input. Tool calls carried
+// by a message are appended to its text as "[Tool: name(args)]". All string
+// content is JSON-encoded, so quotes, backslashes and newlines in prompts or
+// tool arguments produce valid payloads.
+func NormalizedToResponses(req *core.NormalizedRequest, model config.ModelConfig) *types.ResponsesRequest {
+	responsesReq := &types.ResponsesRequest{
+		Model: model.ModelID,
+	}
+
+	// System prompt becomes a "developer" role input.
+	if req.SystemPrompt != "" {
+		responsesReq.Input = append(responsesReq.Input, types.ResponsesInput{
+			Role:    "developer",
+			Content: normalizedJSONString(req.SystemPrompt),
+		})
+	}
+
+	// Convert messages.
+	for _, msg := range req.Messages {
+		input := types.ResponsesInput{Role: msg.Role}
+		content := msg.Content
+
+		// For assistant messages with tool calls, serialize as text.
+		for _, tc := range msg.ToolCalls {
+			content += "[Tool: " + tc.Name + "(" + tc.Arguments + ")]"
+		}
+
+		if content != "" {
+			input.Content = normalizedJSONString(content)
+		}
+		responsesReq.Input = append(responsesReq.Input, input)
+	}
+
+	// Convert tools.
+	for _, tool := range req.Tools {
+		responsesReq.Tools = append(responsesReq.Tools, types.ResponsesTool{
+			Type:        "function",
+			Name:        tool.Name,
+			Description: tool.Description,
+			Parameters:  tool.InputSchema,
+		})
+	}
+
+	return responsesReq
+}
+
+// NormalizedToGemini converts a NormalizedRequest to a GeminiRequest.
+//
+// Only message text is carried over. The "assistant" role is sent as "model",
+// the name Gemini uses for model turns.
+func NormalizedToGemini(req *core.NormalizedRequest, model config.ModelConfig) *types.GeminiRequest {
+	geminiReq := &types.GeminiRequest{
+		GenerationConfig: &types.GeminiGenerationConfig{
+			MaxOutputTokens: req.MaxTokens,
+		},
+	}
+
+	if req.Temperature != nil {
+		geminiReq.GenerationConfig.Temperature = *req.Temperature
+	}
+
+	// System prompt is prepended as a user message (Gemini has no system role).
+	var contents []types.GeminiContent
+	if req.SystemPrompt != "" {
+		contents = append(contents, types.GeminiContent{
+			Role:  "user",
+			Parts: []types.GeminiPart{{Text: req.SystemPrompt}},
+		})
+	}
+
+	// Convert messages.
+	for _, msg := range req.Messages {
+		role := msg.Role
+		if role == "assistant" {
+			// Gemini contents accept only "user" and "model" roles.
+			role = "model"
+		}
+		gc := types.GeminiContent{Role: role}
+		gc.Parts = append(gc.Parts, types.GeminiPart{Text: msg.Content})
+		contents = append(contents, gc)
+	}
+
+	geminiReq.Contents = contents
+
+	// Convert tools.
+	if len(req.Tools) > 0 {
+		functions := make([]types.GeminiFunctionDeclaration, 0, len(req.Tools))
+		for _, tool := range req.Tools {
+			functions = append(functions, types.GeminiFunctionDeclaration{
+				Name:        tool.Name,
+				Description: tool.Description,
+				Parameters:  tool.InputSchema,
+			})
+		}
+		geminiReq.Tools = []types.GeminiTool{
+			{FunctionDeclarations: functions},
+		}
+	}
+
+	return geminiReq
+}
+
+// ── Response-side: wire format → NormalizedResponse ───────────────────
+
+// OpenAIResponseToNormalized converts an OpenAI ChatCompletionResponse to NormalizedResponse.
+//
+// Every choice becomes one normalized message. StopReason reflects the last
+// choice; unknown finish reasons map to "end_turn".
+func OpenAIResponseToNormalized(openaiResp *types.ChatCompletionResponse, modelID string) *core.NormalizedResponse {
+	nr := &core.NormalizedResponse{
+		ID:    openaiResp.ID,
+		Model: modelID,
+	}
+
+	for _, choice := range openaiResp.Choices {
+		msg := choice.Message
+
+		nm := core.NormalizedMessage{Role: msg.Role}
+
+		// Extract text content.
+		if msg.Content != nil {
+			nm.Content = msg.ContentText()
+		}
+
+		// Extract reasoning content (pointer field).
+		if msg.ReasoningContent != nil {
+			nm.Thinking = *msg.ReasoningContent
+		}
+
+		// Extract tool calls.
+		for _, tc := range msg.ToolCalls {
+			nm.ToolCalls = append(nm.ToolCalls, core.NormalizedToolCall{
+				ID:        tc.ID,
+				Name:      tc.Function.Name,
+				Arguments: tc.Function.Arguments,
+			})
+		}
+
+		nr.Messages = append(nr.Messages, nm)
+
+		// Map finish reason.
+		switch choice.FinishReason {
+		case "stop":
+			nr.StopReason = "end_turn"
+		case "length":
+			nr.StopReason = "max_tokens"
+		case "tool_calls":
+			nr.StopReason = "tool_use"
+		default:
+			nr.StopReason = "end_turn"
+		}
+	}
+
+	// Map usage. UsageInfo is a value type; check if it was populated.
+	if openaiResp.Usage.PromptTokens > 0 || openaiResp.Usage.CompletionTokens > 0 {
+		nr.Usage = core.NormalizedUsage{
+			InputTokens:     openaiResp.Usage.PromptTokens,
+			OutputTokens:    openaiResp.Usage.CompletionTokens,
+			CacheReadTokens: openaiResp.Usage.PromptCacheHitTokens,
+			// NOTE(review): cache-miss tokens are reported as cache-creation
+			// tokens here; confirm that is the intended accounting.
+			CacheCreationTokens: openaiResp.Usage.PromptCacheMissTokens,
+		}
+	}
+
+	return nr
+}
+
+// ResponsesToNormalized converts an OpenAI ResponsesResponse to NormalizedResponse.
+//
+// "message" outputs contribute their output_text parts; "function_call"
+// outputs become assistant messages carrying one tool call. StopReason is
+// "tool_use" when at least one function call is present, "end_turn" otherwise.
+func ResponsesToNormalized(responsesResp *types.ResponsesResponse, modelID string) *core.NormalizedResponse {
+	nr := &core.NormalizedResponse{
+		ID:    responsesResp.ID,
+		Model: modelID,
+	}
+
+	sawToolCall := false
+	for _, output := range responsesResp.Output {
+		switch output.Type {
+		case "message":
+			nm := core.NormalizedMessage{Role: output.Role}
+			for _, c := range output.Content {
+				if c.Type == "output_text" {
+					nm.Content += c.Text
+				}
+			}
+			nr.Messages = append(nr.Messages, nm)
+		case "function_call":
+			sawToolCall = true
+			nm := core.NormalizedMessage{
+				Role: "assistant",
+				ToolCalls: []core.NormalizedToolCall{
+					{
+						ID:        output.CallID,
+						Name:      output.Name,
+						Arguments: output.Arguments,
+					},
+				},
+			}
+			nr.Messages = append(nr.Messages, nm)
+		}
+	}
+
+	// Mirror the Chat Completions mapping, where tool calls stop the turn
+	// with "tool_use".
+	nr.StopReason = "end_turn"
+	if sawToolCall {
+		nr.StopReason = "tool_use"
+	}
+
+	nr.Usage = core.NormalizedUsage{
+		InputTokens:  responsesResp.Usage.InputTokens,
+		OutputTokens: responsesResp.Usage.OutputTokens,
+	}
+
+	return nr
+}
+
+// GeminiToNormalized converts a GeminiResponse to NormalizedResponse.
+//
+// Only the first candidate is used and only its text parts are kept. The
+// Gemini "model" role is reported as "assistant", matching the role used by
+// the other response converters.
+func GeminiToNormalized(geminiResp *types.GeminiResponse, modelID string) *core.NormalizedResponse {
+	nr := &core.NormalizedResponse{
+		Model: modelID,
+	}
+
+	if len(geminiResp.Candidates) > 0 {
+		candidate := geminiResp.Candidates[0]
+
+		role := candidate.Content.Role
+		if role == "model" {
+			role = "assistant"
+		}
+		nm := core.NormalizedMessage{Role: role}
+
+		for _, part := range candidate.Content.Parts {
+			if part.Text != "" {
+				nm.Content += part.Text
+			}
+		}
+
+		nr.Messages = append(nr.Messages, nm)
+
+		switch candidate.FinishReason {
+		case "STOP":
+			nr.StopReason = "end_turn"
+		case "MAX_TOKENS":
+			nr.StopReason = "max_tokens"
+		default:
+			nr.StopReason = "end_turn"
+		}
+	}
+
+	if geminiResp.UsageMetadata != nil {
+		nr.Usage = core.NormalizedUsage{
+			InputTokens:  geminiResp.UsageMetadata.PromptTokenCount,
+			OutputTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
+		}
+	}
+
+	return nr
+}
+
+// ── Helpers ───────────────────────────────────────────────────────────
+
+// normalizedJSONString returns s encoded as a JSON string literal, with
+// quotes, backslashes and control characters escaped.
+func normalizedJSONString(s string) json.RawMessage {
+	encoded, err := json.Marshal(s)
+	if err != nil {
+		// Not expected for a string value; keep the result valid JSON anyway.
+		return json.RawMessage(`""`)
+	}
+	return encoded
+}
+
+// normalizedToMessageRequest reconstructs an Anthropic MessageRequest from a
+// NormalizedRequest. This is used as input to the existing TransformRequest
+// pipeline.
+func normalizedToMessageRequest(req *core.NormalizedRequest) *types.MessageRequest {
+	anthropicReq := &types.MessageRequest{
+		Model:     req.Model,
+		MaxTokens: req.MaxTokens,
+	}
+
+	// Set system prompt.
+	if req.SystemPrompt != "" {
+		anthropicReq.System = normalizedJSONString(req.SystemPrompt)
+	}
+
+	// Set stream.
+	if req.Stream {
+		t := true
+		anthropicReq.Stream = &t
+	}
+
+	// Set temperature.
+	if req.Temperature != nil {
+		anthropicReq.Temperature = req.Temperature
+	}
+
+	// Set thinking.
+	if req.ReasoningEffort != "" || req.ThinkingBudget > 0 {
+		// NOTE(review): the reasoning effort is passed through as the
+		// thinking "type"; confirm the downstream consumer expects that.
+		tc := map[string]any{
+			"type":          req.ReasoningEffort,
+			"budget_tokens": req.ThinkingBudget,
+		}
+		if b, err := json.Marshal(tc); err == nil {
+			anthropicReq.Thinking = b
+		}
+	}
+
+	// Convert messages.
+	for _, nm := range req.Messages {
+		msg := types.Message{Role: nm.Role}
+
+		var blocks []types.ContentBlock
+		if nm.Content != "" {
+			blocks = append(blocks, types.ContentBlock{Type: "text", Text: nm.Content})
+		}
+		if nm.Thinking != "" {
+			blocks = append(blocks, types.ContentBlock{Type: "thinking", Thinking: nm.Thinking})
+		}
+		for _, tc := range nm.ToolCalls {
+			// Empty arguments are not valid JSON and would make the block
+			// list below fail to marshal, dropping the whole message.
+			args := tc.Arguments
+			if args == "" {
+				args = "{}"
+			}
+			blocks = append(blocks, types.ContentBlock{
+				Type:  "tool_use",
+				ID:    tc.ID,
+				Name:  tc.Name,
+				Input: []byte(args),
+			})
+		}
+		if nm.ToolCallID != "" {
+			// NOTE(review): a non-empty Content is emitted both as the text
+			// block above and as this tool_result; confirm that is intended.
+			content, _ := json.Marshal(nm.Content)
+			blocks = append(blocks, types.ContentBlock{
+				Type:      "tool_result",
+				ToolUseID: nm.ToolCallID,
+				Content:   content,
+			})
+		}
+
+		if len(blocks) > 0 {
+			b, _ := json.Marshal(blocks)
+			msg.Content = b
+		} else {
+			msg.Content = json.RawMessage(`""`)
+		}
+
+		anthropicReq.Messages = append(anthropicReq.Messages, msg)
+	}
+
+	// Convert tools.
+	for _, nt := range req.Tools {
+		anthropicReq.Tools = append(anthropicReq.Tools, types.Tool{
+			Name:        nt.Name,
+			Description: nt.Description,
+			InputSchema: nt.InputSchema,
+		})
+	}
+
+	return anthropicReq
+}
+
+// joinMessageText concatenates the content of all messages for use as a
+// fallback when the transform pipeline fails.
+func joinMessageText(messages []core.NormalizedMessage) string {
+	// One "role: content" line per message; messages with no content are
+	// skipped and lines are separated by a single newline.
+	var joined []byte
+	for _, msg := range messages {
+		if msg.Content == "" {
+			continue
+		}
+		if len(joined) > 0 {
+			joined = append(joined, '\n')
+		}
+		joined = append(joined, msg.Role...)
+		joined = append(joined, ": "...)
+		joined = append(joined, msg.Content...)
+	}
+	return string(joined)
+}
From 0e1bc8ba52bdaf2dee48761e607f15a715c34bb6 Mon Sep 17 00:00:00 2001
From: samuel tuyizere
Date: Fri, 19 Jun 2026 12:34:29 +0200
Subject: [PATCH 2/7] feat: Add routing policy engine and implement model override and scenario policies

---
 .github/workflows/release.yml |   8 +-
 internal/handlers/messages.go |   7 +-
 internal/router/policy.go     | 174 ++++++++++++++++++++++++++++++++++
 3 files changed, 183 insertions(+), 6 deletions(-)
 create mode 100644 internal/router/policy.go

diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 42a219a..25803fb 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -248,6 +248,12 @@ jobs:
             type=semver,pattern={{major}},value=${{ needs.release.outputs.tag }}
             type=raw,value=latest
 
+      - name: Strip v prefix from version
+        id: version
+        env:
+          TAG: ${{ needs.release.outputs.tag }}
+        run: echo "value=${TAG#v}" >> "$GITHUB_OUTPUT"
+
       - uses: docker/build-push-action@v6
         with:
           context: .
@@ -255,7 +261,7 @@ jobs: push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} - build-args: VERSION=${{ needs.release.outputs.tag | trimPrefix 'v' }} + build-args: VERSION=${{ steps.version.outputs.value }} cache-from: type=gha cache-to: type=gha,mode=max diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 06d4aa6..99b9b42 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -333,6 +333,7 @@ func (h *MessagesHandler) handleStreaming( w http.ResponseWriter, r *http.Request, anthropicReq *types.MessageRequest, + normalizedReq *core.NormalizedRequest, modelChain []config.ModelConfig, rawBody json.RawMessage, ) { @@ -436,9 +437,6 @@ func (h *MessagesHandler) handleStreaming( // Try new provider-based dispatch first. if prov, ok := h.providerRegistry.Get(model.Provider); ok { - normalizedReq := core.NormalizeRequest(anthropicReq) - normalizedReq.Stream = true - caps, ok := prov.ModelCapabilities(model.ModelID) if !ok || !caps.SupportsStreaming { h.logger.Warn("model does not support streaming", "model", model.ModelID, "provider", model.Provider) @@ -818,6 +816,7 @@ func (h *MessagesHandler) handleNonStreaming( w http.ResponseWriter, r *http.Request, anthropicReq *types.MessageRequest, + normalizedReq *core.NormalizedRequest, modelChain []config.ModelConfig, rawBody json.RawMessage, ) { @@ -830,8 +829,6 @@ func (h *MessagesHandler) handleNonStreaming( func(ctx context.Context, model config.ModelConfig) ([]byte, error) { // Try new provider-based dispatch first. 
if prov, ok := h.providerRegistry.Get(model.Provider); ok { - normalizedReq := core.NormalizeRequest(anthropicReq) - normalizedReq.Stream = false execResult, execErr := prov.Execute(ctx, normalizedReq, model) if execErr != nil { return nil, execErr diff --git a/internal/router/policy.go b/internal/router/policy.go new file mode 100644 index 0000000..5b078dd --- /dev/null +++ b/internal/router/policy.go @@ -0,0 +1,174 @@ +package router + +import ( + "fmt" + + "github.com/routatic/proxy/internal/config" + "github.com/routatic/proxy/internal/core" +) + +// EvaluationContext carries all information needed to evaluate routing policies. +type EvaluationContext struct { + Request *core.NormalizedRequest + TokenCount int + AvailableModels []config.ModelConfig + History []RouteDecision +} + +// RouteDecision records a routing decision for observability. +type RouteDecision struct { + PolicyName string + ModelID string + Provider string + Reason string + Weight int +} + +// Policy evaluates a routing strategy and selects a model chain. +type Policy interface { + // Name returns the policy identifier. + Name() string + + // Evaluate examines the context and returns the model chain to try, the + // decision explanation, or an error if no model matches this policy. + Evaluate(ctx *EvaluationContext) ([]config.ModelConfig, RouteDecision, error) +} + +// PolicyEngine composes multiple policies with ordered evaluation. Policies +// are evaluated in registration order; the first policy that returns a +// non-empty chain wins. +type PolicyEngine struct { + policies []Policy +} + +// NewPolicyEngine creates a policy engine with the default set of policies: +// 1. ModelOverridePolicy — check model_overrides config entries +// 2. RespectRequestModelPolicy — check respect_requested_model config +// 3. 
ScenarioPolicy — scenario-based routing (existing DetectScenario logic) +func NewPolicyEngine() *PolicyEngine { + return &PolicyEngine{} +} + +// AddPolicy appends a policy to the evaluation chain. +func (eng *PolicyEngine) AddPolicy(p Policy) { + eng.policies = append(eng.policies, p) +} + +// Evaluate runs each policy in order and returns the first successful result. +func (eng *PolicyEngine) Evaluate(ctx *EvaluationContext) ([]config.ModelConfig, RouteDecision, error) { + for _, p := range eng.policies { + chain, decision, err := p.Evaluate(ctx) + if err != nil { + continue + } + if len(chain) > 0 { + return chain, decision, nil + } + } + return nil, RouteDecision{}, fmt.Errorf("no policy could route the request") +} + +// EvaluateDryRun returns all policy decisions without executing. Useful for +// debugging and the dry-run endpoint. +func (eng *PolicyEngine) EvaluateDryRun(ctx *EvaluationContext) []RouteDecision { + var decisions []RouteDecision + for _, p := range eng.policies { + _, decision, err := p.Evaluate(ctx) + if err != nil { + decisions = append(decisions, RouteDecision{ + PolicyName: p.Name(), + Reason: err.Error(), + }) + continue + } + decisions = append(decisions, decision) + } + return decisions +} + +// ── ModelOverridePolicy ─────────────────────────────────────────────── + +// ModelOverridePolicy checks whether the requested model has an entry in +// model_overrides. If so, it uses that override as the primary and appends +// the default fallback chain. +type ModelOverridePolicy struct { + router *ModelRouter +} + +// NewModelOverridePolicy creates a model override policy. +func NewModelOverridePolicy(router *ModelRouter) *ModelOverridePolicy { + return &ModelOverridePolicy{router: router} +} + +// Name returns the policy identifier. +func (p *ModelOverridePolicy) Name() string { return "model_override" } + +// Evaluate checks model_overrides for the requested model. 
+func (p *ModelOverridePolicy) Evaluate(ctx *EvaluationContext) ([]config.ModelConfig, RouteDecision, error) { + requestedModel := ctx.Request.Model + if requestedModel == "" { + return nil, RouteDecision{}, fmt.Errorf("no model in request") + } + + result, ok := p.router.RouteWithOverride(requestedModel) + if !ok { + return nil, RouteDecision{}, fmt.Errorf("no override for %q", requestedModel) + } + + return result.GetModelChain(), RouteDecision{ + PolicyName: "model_override", + ModelID: result.Primary.ModelID, + Provider: result.Primary.Provider, + Reason: fmt.Sprintf("matched model_override for %q", requestedModel), + }, nil +} + +// ── ScenarioPolicy ──────────────────────────────────────────────────── + +// ScenarioPolicy runs scenario-based routing using the existing DetectScenario +// logic. It handles both streaming and non-streaming paths. +type ScenarioPolicy struct { + router *ModelRouter +} + +// NewScenarioPolicy creates a scenario policy. +func NewScenarioPolicy(router *ModelRouter) *ScenarioPolicy { + return &ScenarioPolicy{router: router} +} + +// Name returns the policy identifier. +func (p *ScenarioPolicy) Name() string { return "scenario" } + +// Evaluate runs scenario detection and returns the model chain. +func (p *ScenarioPolicy) Evaluate(ctx *EvaluationContext) ([]config.ModelConfig, RouteDecision, error) { + // Build router messages from the normalized request. 
+ var messages []MessageContent + systemText := ctx.Request.SystemPrompt + if systemText != "" { + messages = append(messages, MessageContent{Role: "system", Content: systemText}) + } + for _, msg := range ctx.Request.Messages { + messages = append(messages, MessageContent{Role: msg.Role, Content: msg.Content}) + } + + isStreaming := ctx.Request.Stream + var result RouteResult + var err error + + if isStreaming && !p.router.IsStreamingScenarioRoutingEnabled() { + result, err = p.router.RouteForStreaming(messages, ctx.TokenCount, "") + } else { + result, err = p.router.Route(messages, ctx.TokenCount, "") + } + + if err != nil { + return nil, RouteDecision{}, fmt.Errorf("scenario routing failed: %w", err) + } + + return result.GetModelChain(), RouteDecision{ + PolicyName: "scenario", + ModelID: result.Primary.ModelID, + Provider: result.Primary.Provider, + Reason: fmt.Sprintf("scenario=%s: %s", result.Scenario, result.Scenario), + }, nil +} From 657fef09eae7ad0115650a568551448df50c8634 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 12:41:07 +0200 Subject: [PATCH 3/7] style: Consistently format error messages and struct fields for improved readability --- internal/core/errors.go | 12 ++++++------ internal/core/normalize.go | 4 ++-- internal/core/normalized.go | 8 ++++---- internal/handlers/messages.go | 22 +++++++++++----------- internal/router/policy.go | 6 +++--- internal/transformer/normalized_bridge.go | 6 +++--- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/internal/core/errors.go b/internal/core/errors.go index ae4e32e..70662a0 100644 --- a/internal/core/errors.go +++ b/internal/core/errors.go @@ -4,12 +4,12 @@ import "errors" // Sentinel errors for common provider and routing failures. 
var ( - ErrModelNotFound = errors.New("model not found") - ErrProviderNotFound = errors.New("provider not found") - ErrUnsupportedCapability = errors.New("capability not supported by model") - ErrRateLimited = errors.New("rate limited by provider") - ErrStreamIdle = errors.New("upstream stream idle") - ErrClientDisconnected = errors.New("client disconnected") + ErrModelNotFound = errors.New("model not found") + ErrProviderNotFound = errors.New("provider not found") + ErrUnsupportedCapability = errors.New("capability not supported by model") + ErrRateLimited = errors.New("rate limited by provider") + ErrStreamIdle = errors.New("upstream stream idle") + ErrClientDisconnected = errors.New("client disconnected") ) // NormalizedError wraps a provider error with structured context. diff --git a/internal/core/normalize.go b/internal/core/normalize.go index fb08789..6b8b3c9 100644 --- a/internal/core/normalize.go +++ b/internal/core/normalize.go @@ -89,8 +89,8 @@ func DenormalizeResponse(nr *NormalizedResponse) *types.MessageResponse { Type: "message", Model: nr.Model, Usage: types.Usage{ - InputTokens: nr.Usage.InputTokens, - OutputTokens: nr.Usage.OutputTokens, + InputTokens: nr.Usage.InputTokens, + OutputTokens: nr.Usage.OutputTokens, CacheCreationInputTokens: nr.Usage.CacheCreationTokens, CacheReadInputTokens: nr.Usage.CacheReadTokens, }, diff --git a/internal/core/normalized.go b/internal/core/normalized.go index 780eb6a..8f7e6a1 100644 --- a/internal/core/normalized.go +++ b/internal/core/normalized.go @@ -4,11 +4,11 @@ package core // All wire formats (Anthropic, OpenAI, Responses, Gemini) map to and from // this representation. 
type NormalizedMessage struct { - Role string // "user", "assistant", "system", "tool" - Content string // Concatenated text content + Role string // "user", "assistant", "system", "tool" + Content string // Concatenated text content ToolCalls []NormalizedToolCall // Present on assistant messages - ToolCallID string // Present on tool-result messages - Thinking string // Reasoning/thinking content (assistant only) + ToolCallID string // Present on tool-result messages + Thinking string // Reasoning/thinking content (assistant only) } // NormalizedToolCall represents a tool invocation in the internal format. diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 99b9b42..aa41c6b 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -26,20 +26,20 @@ import ( // MessagesHandler handles /v1/messages requests. type MessagesHandler struct { - client *client.OpenCodeClient // kept for backward compat during migration - providerRegistry *core.ProviderRegistry // new: provider dispatch - modelRouter *router.ModelRouter - fallbackHandler *router.FallbackHandler - streamProxy *StreamProxy // new: SSE proxy by wire format + client *client.OpenCodeClient // kept for backward compat during migration + providerRegistry *core.ProviderRegistry // new: provider dispatch + modelRouter *router.ModelRouter + fallbackHandler *router.FallbackHandler + streamProxy *StreamProxy // new: SSE proxy by wire format requestTransformer *transformer.RequestTransformer responseTransformer *transformer.ResponseTransformer streamHandler *transformer.StreamHandler - tokenCounter *token.Counter - logger *slog.Logger - rateLimiter *middleware.RateLimiter - requestDedup *middleware.RequestDeduplicator - requestIDGen *middleware.RequestIDGenerator - metrics *metrics.Metrics + tokenCounter *token.Counter + logger *slog.Logger + rateLimiter *middleware.RateLimiter + requestDedup *middleware.RequestDeduplicator + requestIDGen 
*middleware.RequestIDGenerator + metrics *metrics.Metrics } // responseWriter wraps http.ResponseWriter to track if headers were written. diff --git a/internal/router/policy.go b/internal/router/policy.go index 5b078dd..185e5a6 100644 --- a/internal/router/policy.go +++ b/internal/router/policy.go @@ -9,10 +9,10 @@ import ( // EvaluationContext carries all information needed to evaluate routing policies. type EvaluationContext struct { - Request *core.NormalizedRequest - TokenCount int + Request *core.NormalizedRequest + TokenCount int AvailableModels []config.ModelConfig - History []RouteDecision + History []RouteDecision } // RouteDecision records a routing decision for observability. diff --git a/internal/transformer/normalized_bridge.go b/internal/transformer/normalized_bridge.go index a78afe2..045ec13 100644 --- a/internal/transformer/normalized_bridge.go +++ b/internal/transformer/normalized_bridge.go @@ -24,9 +24,9 @@ func TransformRequestFromNormalized(req *core.NormalizedRequest, model config.Mo stream := req.Stream maxTokens := req.MaxTokens return &types.ChatCompletionRequest{ - Model: model.ModelID, - Messages: []types.ChatMessage{{Role: "user", Content: types.TextContent(req.SystemPrompt + "\n" + joinMessageText(req.Messages))}}, - Stream: &stream, + Model: model.ModelID, + Messages: []types.ChatMessage{{Role: "user", Content: types.TextContent(req.SystemPrompt + "\n" + joinMessageText(req.Messages))}}, + Stream: &stream, MaxTokens: &maxTokens, } } From 2a2c4f4150599d254e9737fbb547a2e3c44e9fa5 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 12:52:04 +0200 Subject: [PATCH 4/7] fix: Ensure early returns in tests for better error handling and clarity --- internal/handlers/streaming.go | 6 +++--- internal/transformer/request_test.go | 3 +++ internal/transformer/stream_test.go | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/handlers/streaming.go b/internal/handlers/streaming.go index 8e00867..a4d2f72 100644 
--- a/internal/handlers/streaming.go +++ b/internal/handlers/streaming.go @@ -40,12 +40,12 @@ func (sp *StreamProxy) ProxyStream( switch wireFormat { case core.WireFormatAnthropic: return sp.proxyAnthropicPassthroughStream(w, body, idleTimeout, clientCtx, cancel) - default: - return sp.proxyOpenAIStream(w, body, modelID, clientCtx, idleTimeout, cancel) case core.WireFormatOpenAIResponses: return sp.handler.ProxyResponsesStream(w, body, modelID, clientCtx, idleTimeout, cancel) case core.WireFormatGemini: return sp.handler.ProxyGeminiStream(w, body, modelID, clientCtx, idleTimeout, cancel) + default: + return sp.proxyOpenAIStream(w, body, modelID, clientCtx, idleTimeout, cancel) } } @@ -71,7 +71,7 @@ func (sp *StreamProxy) proxyAnthropicPassthroughStream( clientCtx context.Context, cancel context.CancelFunc, ) error { - defer body.Close() + defer func() { _ = body.Close() }() defer cancel() buf := make([]byte, 4096) diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 580b32d..3edafc3 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -87,6 +87,7 @@ func TestTransformRequestRoundTripReasoning(t *testing.T) { } if assistantMsg == nil { t.Fatal("assistant message not found in transformed request") + return } // Step 5: Verify reasoning_content is preserved. 
@@ -968,6 +969,7 @@ func TestTransformRequestDeepSeekPlaceholderWithThinkingHistory(t *testing.T) { } if toolCallAssistant == nil { t.Fatal("no assistant message with tool_calls found") + return } if toolCallAssistant.ReasoningContent == nil { t.Fatal("ReasoningContent = nil, want non-nil placeholder for DeepSeek with thinking history") @@ -1037,6 +1039,7 @@ func TestTransformRequestDeepSeekPlaceholderForTextOnlyAssistant(t *testing.T) { } if textOnlyAssistant == nil { t.Fatal("expected two assistant messages in transformed request, found fewer") + return } if len(textOnlyAssistant.ToolCalls) != 0 { t.Fatalf("text-only assistant message unexpectedly had tool_calls: %+v", textOnlyAssistant.ToolCalls) diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index 3b33462..07a039b 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -234,6 +234,7 @@ func TestProxyStream_UsageOnlyChunk(t *testing.T) { } if usage == nil { t.Fatalf("no usage event found in stream: %+v", events) + return } // Per Anthropic spec, input_tokens excludes cache reads AND cache // creations. 
Upstream prompt_tokens=123 split as 100 hit + 23 miss From 179e6e88cd5fe73872375ef5327bd4c5bc7b3b0e Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 12:54:46 +0200 Subject: [PATCH 5/7] fix: Add early returns in tests for better error handling and clarity --- internal/transformer/request_test.go | 1 + internal/transformer/stream_test.go | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 3edafc3..a8a1350 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -1171,6 +1171,7 @@ func TestTransformRequestExtractsThinkingFromToolUseBlock(t *testing.T) { } if assistantMsg == nil { t.Fatal("no assistant message in transformed request") + return } if assistantMsg.ReasoningContent == nil { t.Fatal("ReasoningContent = nil, want non-nil (thinking on tool_use must round-trip)") diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index 07a039b..60c496d 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -281,6 +281,7 @@ func TestProxyStream_PartialCacheTokensStreaming(t *testing.T) { } if usage == nil { t.Fatalf("no usage event found in stream: %+v", events) + return } // 100 - 60 - 30 = 10 tokens are neither cached nor newly cached. 
if got, want := usage.InputTokens, 10; got != want { @@ -337,6 +338,7 @@ func TestProxyStream_NoDuplicateMessageDelta(t *testing.T) { } if totalUsage == nil { t.Fatalf("no usage found in stream: %+v", events) + return } if got, want := totalUsage.InputTokens, 100; got != want { t.Errorf("InputTokens = %d, want %d", got, want) @@ -990,10 +992,12 @@ func TestProxyStream_NoUsageFallback(t *testing.T) { if messageDeltaEvent == nil { t.Fatalf("expected message_delta event, got none: %+v", events) + return } if messageDeltaEvent.Usage == nil { t.Fatal("expected message_delta event to have non-nil Usage, but it was nil") + return } if messageDeltaEvent.Usage.InputTokens != 0 || messageDeltaEvent.Usage.OutputTokens != 0 { @@ -1061,6 +1065,7 @@ func TestProxyStream_EOFFallbackStopReasonToolUse(t *testing.T) { } if msgDelta == nil { t.Fatalf("expected message_delta event, got none: %+v", events) + return } if msgDelta.Delta == nil || msgDelta.Delta.StopReason != "tool_use" { t.Errorf("stop_reason = %q, want tool_use (stream ended mid-tool-call)", msgDelta.Delta.StopReason) From 011cca9ae151500ecfb2f9f57e9cf5e6649832c5 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 13:06:51 +0200 Subject: [PATCH 6/7] fix: Ensure config directory exists before writing PID file in routatic-proxy --- .gitignore | 1 + cmd/routatic-proxy/main.go | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/.gitignore b/.gitignore index 422b76b..b401ea4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ bin/ configs/config.json .tmp/ oc-go-cc +routatic-proxy diff --git a/cmd/routatic-proxy/main.go b/cmd/routatic-proxy/main.go index e40dc20..b2fc137 100644 --- a/cmd/routatic-proxy/main.go +++ b/cmd/routatic-proxy/main.go @@ -115,6 +115,14 @@ func serveCmd() *cobra.Command { return err } } else { + // Ensure config directory exists before writing PID file. 
+ paths, err := daemon.DefaultPaths() + if err != nil { + return err + } + if err := paths.EnsureConfigDir(); err != nil { + return err + } // Write PID file for foreground mode. if err := daemon.WritePID(pidPath, os.Getpid()); err != nil { return fmt.Errorf("failed to write PID file: %w", err) From 763f0adb142b69ea901bc8172d9a695e92bbd944 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Fri, 19 Jun 2026 13:27:48 +0200 Subject: [PATCH 7/7] feat: Enhance tool result handling in normalization process --- internal/core/normalize.go | 6 ++++-- internal/core/normalized.go | 17 ++++++++++++----- internal/transformer/normalized_bridge.go | 11 ++++++++++- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/internal/core/normalize.go b/internal/core/normalize.go index 6b8b3c9..a280428 100644 --- a/internal/core/normalize.go +++ b/internal/core/normalize.go @@ -57,8 +57,10 @@ func NormalizeRequest(anthropicReq *types.MessageRequest) *NormalizedRequest { Arguments: string(block.Input), }) case "tool_result": - nm.ToolCallID = block.ToolUseID - nm.Content += block.TextContent() + nm.ToolResults = append(nm.ToolResults, NormalizedToolResult{ + ToolCallID: block.ToolUseID, + Content: block.TextContent(), + }) case "thinking": nm.Thinking += block.Thinking case "image": diff --git a/internal/core/normalized.go b/internal/core/normalized.go index 8f7e6a1..a1f15f5 100644 --- a/internal/core/normalized.go +++ b/internal/core/normalized.go @@ -1,14 +1,21 @@ package core +// NormalizedToolResult represents a single tool result in the normalized format. +type NormalizedToolResult struct { + ToolCallID string + Content string +} + // NormalizedMessage is a single message in the internal canonical format. // All wire formats (Anthropic, OpenAI, Responses, Gemini) map to and from // this representation. 
type NormalizedMessage struct { - Role string // "user", "assistant", "system", "tool" - Content string // Concatenated text content - ToolCalls []NormalizedToolCall // Present on assistant messages - ToolCallID string // Present on tool-result messages - Thinking string // Reasoning/thinking content (assistant only) + Role string // "user", "assistant", "system", "tool" + Content string // Concatenated text content + ToolCalls []NormalizedToolCall // Present on assistant messages + ToolResults []NormalizedToolResult // Present on user messages with tool results + ToolCallID string // Deprecated: use ToolResults instead. Kept for backward compat. + Thinking string // Reasoning/thinking content (assistant only) } // NormalizedToolCall represents a tool invocation in the internal format. diff --git a/internal/transformer/normalized_bridge.go b/internal/transformer/normalized_bridge.go index 045ec13..39ebce9 100644 --- a/internal/transformer/normalized_bridge.go +++ b/internal/transformer/normalized_bridge.go @@ -332,7 +332,16 @@ func normalizedToMessageRequest(req *core.NormalizedRequest) *types.MessageReque Input: []byte(tc.Arguments), }) } - if nm.ToolCallID != "" { + if len(nm.ToolResults) > 0 { + for _, tr := range nm.ToolResults { + content, _ := json.Marshal(tr.Content) + blocks = append(blocks, types.ContentBlock{ + Type: "tool_result", + ToolUseID: tr.ToolCallID, + Content: content, + }) + } + } else if nm.ToolCallID != "" { content, _ := json.Marshal(nm.Content) blocks = append(blocks, types.ContentBlock{ Type: "tool_result",