diff --git a/.gitignore b/.gitignore index 6d33b11..0f94e1f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ bin/ .env .tmp/ oc-go-cc +tmp/ diff --git a/cmd/oc-go-cc/main.go b/cmd/oc-go-cc/main.go index 6298866..332cd4e 100644 --- a/cmd/oc-go-cc/main.go +++ b/cmd/oc-go-cc/main.go @@ -9,6 +9,7 @@ import ( "path/filepath" "github.com/spf13/cobra" + "oc-go-cc/internal/buildinfo" "oc-go-cc/internal/config" "oc-go-cc/internal/daemon" "oc-go-cc/internal/server" @@ -19,9 +20,6 @@ const ( pidFileName = "oc-go-cc.pid" ) -// Version is set at build time via -ldflags "-X main.version=...". -var version = "dev" - func main() { rootCmd := &cobra.Command{ Use: appName, @@ -31,7 +29,7 @@ subscription with Claude Code. It intercepts Claude Code's Anthropic API request transforms them to OpenAI format, and forwards them to OpenCode Go. Configuration is stored at ~/.config/oc-go-cc/config.json`, - Version: version, + Version: buildinfo.Version, } // Add subcommands. @@ -89,7 +87,7 @@ func serveCmd() *cobra.Command { if !daemonize { if pid, err := daemon.GetPID(pidPath); err == nil { // Check if process is still running. - if daemon.IsProcessRunning(pid) { + if daemon.IsProcessRunning(pid) && daemon.IsAppProcess(pid, appName) { return fmt.Errorf("server is already running (PID %d)", pid) } // Stale PID file, clean up. @@ -144,7 +142,18 @@ func serveCmd() *cobra.Command { }() } - fmt.Printf("Starting %s v%s\n", appName, version) + slog.Info("starting proxy", + "binary", buildinfo.BinaryPath(), + "version", buildinfo.Version, + "build_time", buildinfo.BuildTime, + "pid", buildinfo.PID(), + "listen", fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + ) + + fmt.Printf("Starting %s v%s\n", appName, buildinfo.Version) + fmt.Printf("Binary: %s\n", buildinfo.BinaryPath()) + fmt.Printf("Build time: %s\n", buildinfo.BuildTime) + fmt.Printf("PID: %d\n", buildinfo.PID()) fmt.Printf("Listening on %s:%d\n", cfg.Host, cfg.Port) fmt.Printf("Forwarding to: %s\n", cfg.OpenCodeGo.BaseURL) fmt.Println() @@ -178,6 +187,15 @@ func stopCmd() *cobra.Command { return fmt.Errorf("server is not running (no PID file)") } + if !daemon.IsProcessRunning(pid) { + _ = os.Remove(pidPath) + return fmt.Errorf("server is not running (stale PID file)") + } + if !daemon.IsAppProcess(pid, appName) { + _ = os.Remove(pidPath) + return fmt.Errorf("server is not running (PID %d belongs to another process)", pid) + } + if err := daemon.StopProcess(pid); err != nil { return fmt.Errorf("failed to stop server: %w", err) } @@ -207,6 +225,11 @@ func statusCmd() *cobra.Command { _ = os.Remove(pidPath) return nil } + if !daemon.IsAppProcess(pid, appName) { + fmt.Printf("Server is not running (PID %d belongs to another process)\n", pid) + _ = os.Remove(pidPath) + return nil + } fmt.Printf("Server is running (PID %d)\n", pid) return nil diff --git a/configs/config.example.json b/configs/config.example.json index c3ded67..535e296 100644 --- a/configs/config.example.json +++ b/configs/config.example.json @@ -3,102 +3,178 @@ "host": "127.0.0.1", "port": 3456, "hot_reload": false, - "enable_streaming_scenario_routing": false, + "enable_streaming_scenario_routing": true, "respect_requested_model": false, - "models": { - "background": { - "provider": "opencode-go", - "model_id": "qwen3.5-plus", - "temperature": 0.5, - "max_tokens": 2048 - }, "default": { - "provider": "opencode-go", - "model_id": "kimi-k2.6", - "temperature": 0.7, - "max_tokens": 4096 - }, - "long_context": { - "provider": "opencode-go", - "model_id": "minimax-m2.5", - "temperature": 0.7, - "max_tokens": 16384, - "context_threshold": 80000 - }, - "deepseek-v4-pro": { "provider": "opencode-go", "model_id": "deepseek-v4-pro", - "temperature": 0.7, + "supports_vision": false, + "temperature": 0.1, "max_tokens": 8192, + "max_output_tokens": 8192, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, "reasoning_effort": "max", - "thinking": { - "type": "enabled" - } + "thinking": { "type": "enabled" } }, - "deepseek-v4-flash": { + "fast": { "provider": "opencode-go", "model_id": "deepseek-v4-flash", - "temperature": 0.7, + "supports_vision": false, + "temperature": 0.1, "max_tokens": 4096, + "max_output_tokens": 4096, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, "reasoning_effort": "max", - "thinking": { - "type": "enabled" - } + "thinking": { "type": "enabled" } + }, + "background": { + "provider": "opencode-go", + "model_id": "deepseek-v4-flash", + "supports_vision": false, + "temperature": 0.1, + "max_tokens": 4096, + "max_output_tokens": 4096, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, + "reasoning_effort": "max", + "thinking": { "type": "enabled" } }, "think": { "provider": "opencode-go", - "model_id": "glm-5", - "temperature": 0.7, - "max_tokens": 8192 + "model_id": "deepseek-v4-pro", + "supports_vision": false, + "temperature": 0.1, + "max_tokens": 8192, + "max_output_tokens": 8192, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, + "reasoning_effort": "max", + "thinking": { "type": "enabled" } }, "complex": { "provider": "opencode-go", - "model_id": "glm-5.1", - "temperature": 0.7, - "max_tokens": 4096 + "model_id": "deepseek-v4-pro", + "supports_vision": false, + "temperature": 0.1, + "max_tokens": 8192, + "max_output_tokens": 8192, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, + "reasoning_effort": "max", + "thinking": { "type": "enabled" } }, - "fast": { + "long_context": { + "provider": "opencode-go", + "model_id": "deepseek-v4-pro", + "supports_vision": false, + "temperature": 0.1, + "max_tokens": 16384, + "max_output_tokens": 8192, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, + "context_threshold": 80000, + "reasoning_effort": "max", + "thinking": { "type": "enabled" } + }, + "vision": { "provider": "opencode-go", "model_id": "qwen3.6-plus", - "temperature": 0.7, - "max_tokens": 4096 + "supports_vision": true, + "temperature": 0.1, + "max_tokens": 8192, + "max_output_tokens": 8192, + "context_window": 1000000, + "context_margin": 8192, + "supports_tools": true, + "thinking": { "type": "enabled" } + }, + "vision_complex": { + "provider": "opencode-go", + "model_id": "kimi-k2.6", + "supports_vision": true, + "temperature": 0.1, + "max_tokens": 8192, + "max_output_tokens": 8192, + "context_window": 256000, + "context_margin": 8192, + "supports_tools": true, + "thinking": { "type": "enabled" } + }, + "vision_long_context": { + "provider": "opencode-go", + "model_id": "kimi-k2.6", + "supports_vision": true, + "temperature": 0.1, + "max_tokens": 16384, + "max_output_tokens": 8192, + "context_window": 256000, + "context_margin": 8192, + "supports_tools": true, + "context_threshold": 80000, + "thinking": { "type": "enabled" } } }, - "fallbacks": { - "background": [ + "default": [ { "provider": "opencode-go", "model_id": "qwen3.6-plus" }, - { "provider": "opencode-go", "model_id": "minimax-m2.5" } + { "provider": "opencode-go", "model_id": "kimi-k2.6" }, + { "provider": "opencode-go", "model_id": "mimo-v2.5-pro" } ], - "default": [ - { "provider": "opencode-go", "model_id": "mimo-v2-pro" }, + "fast": [ + { "provider": "opencode-go", "model_id": "qwen3.5-plus" }, + { "provider": "opencode-go", "model_id": "minimax-m2.5" }, { "provider": "opencode-go", "model_id": "qwen3.6-plus" } ], - "long_context": [ - { "provider": "opencode-go", "model_id": "minimax-m2.7" }, - { "provider": "opencode-go", "model_id": "kimi-k2.6" } + "background": [ + { "provider": "opencode-go", "model_id": "qwen3.5-plus" }, + { "provider": "opencode-go", "model_id": "minimax-m2.5" }, + { "provider": "opencode-go", "model_id": "qwen3.6-plus" } ], "think": [ { "provider": "opencode-go", "model_id": "kimi-k2.6" }, - { "provider": "opencode-go", "model_id": "mimo-v2-pro" } + { "provider": "opencode-go", "model_id": "mimo-v2.5-pro" }, + { "provider": "opencode-go", "model_id": "glm-5.1" } ], "complex": [ - { "provider": "opencode-go", "model_id": "glm-5" }, - { "provider": "opencode-go", "model_id": "kimi-k2.6" } + { "provider": "opencode-go", "model_id": "mimo-v2.5-pro" }, + { "provider": "opencode-go", "model_id": "kimi-k2.6" }, + { "provider": "opencode-go", "model_id": "qwen3.6-plus" } ], - "fast": [ - { "provider": "opencode-go", "model_id": "qwen3.5-plus" }, - { "provider": "opencode-go", "model_id": "minimax-m2.5" } + "long_context": [ + { "provider": "opencode-go", "model_id": "qwen3.6-plus" }, + { "provider": "opencode-go", "model_id": "mimo-v2.5-pro" }, + { "provider": "opencode-go", "model_id": "minimax-m2.7" } + ], + "vision": [ + { "provider": "opencode-go", "model_id": "kimi-k2.6", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "qwen3.5-plus", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "kimi-k2.5", "supports_vision": true } + ], + "vision_complex": [ + { "provider": "opencode-go", "model_id": "qwen3.6-plus", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "qwen3.5-plus", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "kimi-k2.5", "supports_vision": true } + ], + "vision_long_context": [ + { "provider": "opencode-go", "model_id": "kimi-k2.6", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "qwen3.5-plus", "supports_vision": true }, + { "provider": "opencode-go", "model_id": "kimi-k2.5", "supports_vision": true } ] }, - "opencode_go": { "base_url": "https://opencode.ai/zen/go/v1/chat/completions", "anthropic_base_url": "https://opencode.ai/zen/go/v1/messages", "timeout_ms": 300000 }, - "logging": { "level": "info", "requests": true diff --git a/internal/buildinfo/buildinfo.go b/internal/buildinfo/buildinfo.go new file mode 100644 index 0000000..f83074a --- /dev/null +++ b/internal/buildinfo/buildinfo.go @@ -0,0 +1,18 @@ +package buildinfo + +import "os" + +var Version = "dev" +var BuildTime = "unknown" + +func BinaryPath() string { + path, err := os.Executable() + if err != nil { + return "unknown" + } + return path +} + +func PID() int { + return os.Getpid() +} diff --git a/internal/config/config.go b/internal/config/config.go index 285a960..7fa8969 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,8 +23,13 @@ type ModelConfig struct { ModelID string `json:"model_id"` Temperature float64 `json:"temperature"` MaxTokens int `json:"max_tokens"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + ContextWindow int `json:"context_window,omitempty"` + ContextMargin int `json:"context_margin,omitempty"` ContextThreshold int `json:"context_threshold"` ReasoningEffort string `json:"reasoning_effort"` + SupportsVision bool `json:"supports_vision,omitempty"` + SupportsTools *bool `json:"supports_tools,omitempty"` Thinking json.RawMessage `json:"thinking,omitempty"` } diff --git a/internal/config/loader.go b/internal/config/loader.go index bcd4650..8c75f69 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -147,5 +147,22 @@ func validate(cfg *Config) error { if cfg.APIKey == "" { return fmt.Errorf("api_key is required (set via config file or OC_GO_CC_API_KEY env var)") } + if len(cfg.Models) > 0 { + for _, scenario := range []string{"vision", "vision_complex", "vision_long_context"} { + model, ok := cfg.Models[scenario] + if !ok { + return fmt.Errorf("%s model is required when models are configured", scenario) + } + if !model.SupportsVision { + return fmt.Errorf("%s model %s must support vision", scenario, model.ModelID) + } + for _, fallback := range cfg.Fallbacks[scenario] { + fallback = ResolveModelConfig(fallback) + if !fallback.SupportsVision { + return fmt.Errorf("%s fallback model %s must support vision", scenario, fallback.ModelID) + } + } + } + } return nil } diff --git a/internal/config/model_registry.go b/internal/config/model_registry.go new file mode 100644 index 0000000..a2233a1 --- /dev/null +++ b/internal/config/model_registry.go @@ -0,0 +1,56 @@ +package config + +const DefaultContextMargin = 8192 + +type ModelMetadata struct { + ContextWindow int + MaxOutputTokens int + SupportsVision bool + SupportsTools bool +} + +var modelMetadata = map[string]ModelMetadata{ + "deepseek-v4-pro": {ContextWindow: 1000000, MaxOutputTokens: 8192, SupportsVision: false, SupportsTools: true}, + "deepseek-v4-flash": {ContextWindow: 1000000, MaxOutputTokens: 4096, SupportsVision: false, SupportsTools: true}, + "kimi-k2.6": {ContextWindow: 256000, MaxOutputTokens: 8192, SupportsVision: true, SupportsTools: true}, + "kimi-k2.5": {ContextWindow: 256000, MaxOutputTokens: 8192, SupportsVision: true, SupportsTools: true}, + "mimo-v2.5-pro": {ContextWindow: 1000000, MaxOutputTokens: 16384, SupportsVision: false, SupportsTools: true}, + "mimo-v2.5": {ContextWindow: 1000000, MaxOutputTokens: 8192, SupportsVision: false, SupportsTools: true}, + "minimax-m2.7": {ContextWindow: 200000, MaxOutputTokens: 8192, SupportsVision: false, SupportsTools: true}, + "minimax-m2.5": {ContextWindow: 200000, MaxOutputTokens: 4096, SupportsVision: false, SupportsTools: true}, + "qwen3.6-plus": {ContextWindow: 1000000, MaxOutputTokens: 8192, SupportsVision: true, SupportsTools: true}, + "qwen3.5-plus": {ContextWindow: 1000000, MaxOutputTokens: 8192, SupportsVision: true, SupportsTools: true}, + "glm-5.1": {ContextWindow: 200000, MaxOutputTokens: 8192, SupportsVision: false, SupportsTools: true}, + "glm-5": {ContextWindow: 200000, MaxOutputTokens: 8192, SupportsVision: false, SupportsTools: true}, +} + +func ResolveModelConfig(model ModelConfig) ModelConfig { + if meta, ok := modelMetadata[model.ModelID]; ok { + if model.ContextWindow == 0 { + model.ContextWindow = meta.ContextWindow + } + if model.MaxOutputTokens == 0 { + model.MaxOutputTokens = meta.MaxOutputTokens + } + if !model.SupportsVision { + model.SupportsVision = meta.SupportsVision + } + if model.SupportsTools == nil { + v := meta.SupportsTools + model.SupportsTools = &v + } + } + if model.ContextMargin == 0 { + model.ContextMargin = DefaultContextMargin + } + if model.SupportsTools == nil { + v := true + model.SupportsTools = &v + } + return model +} + +func SupportsTools(model ModelConfig) bool { + model = ResolveModelConfig(model) + return model.SupportsTools == nil || *model.SupportsTools +} diff --git a/internal/daemon/background.go b/internal/daemon/background.go index af6a3c4..7d5e8fc 100644 --- a/internal/daemon/background.go +++ b/internal/daemon/background.go @@ -24,7 +24,7 @@ func ForkIntoBackground(opts BackgroundOpts) error { return fmt.Errorf("cannot create config directory: %w", err) } if pid, err := GetPID(paths.PIDFile); err == nil { - if IsProcessRunning(pid) { + if IsProcessRunning(pid) && IsAppProcess(pid, AppName) { return fmt.Errorf("server is already running (PID %d)", pid) } _ = os.Remove(paths.PIDFile) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 8376c25..bde7a9b 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -69,6 +69,12 @@ func TestIsProcessRunning_CurrentProcess(t *testing.T) { } } +func TestIsAppProcess_CurrentTestProcessIsNotOCGoCC(t *testing.T) { + if IsAppProcess(os.Getpid(), AppName) { + t.Errorf("current test process should not be reported as %s", AppName) + } +} + func TestIsProcessRunning_NonexistentPID(t *testing.T) { // PID 1 is typically init — but on some systems it may not exist. // Use an almost-certainly-invalid PID instead. diff --git a/internal/daemon/process_unix.go b/internal/daemon/process_unix.go index 6f02400..29c48b9 100644 --- a/internal/daemon/process_unix.go +++ b/internal/daemon/process_unix.go @@ -5,6 +5,8 @@ package daemon import ( "fmt" "os" + "path/filepath" + "strings" "syscall" ) @@ -19,6 +21,16 @@ func IsProcessRunning(pid int) bool { return err == nil } +func IsAppProcess(pid int, appName string) bool { + exe, err := os.Readlink(filepath.Join("/proc", fmt.Sprintf("%d", pid), "exe")) + if err != nil { + return false + } + base := strings.ToLower(filepath.Base(exe)) + app := strings.ToLower(appName) + return base == app || strings.HasPrefix(base, app+"-") +} + // StopProcess sends SIGTERM to a process and waits for it to exit. func StopProcess(pid int) error { process, err := os.FindProcess(pid) diff --git a/internal/daemon/process_windows.go b/internal/daemon/process_windows.go index 8500a79..18e23b2 100644 --- a/internal/daemon/process_windows.go +++ b/internal/daemon/process_windows.go @@ -5,7 +5,11 @@ package daemon import ( "fmt" "os" + "path/filepath" + "strings" "syscall" + + "golang.org/x/sys/windows" ) const windowsSynchronize = 0x00100000 @@ -22,6 +26,24 @@ func IsProcessRunning(pid int) bool { return err == nil && event == syscall.WAIT_TIMEOUT } +func IsAppProcess(pid int, appName string) bool { + handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pid)) + if err != nil { + return false + } + defer func() { _ = windows.CloseHandle(handle) }() + + buf := make([]uint16, windows.MAX_PATH) + size := uint32(len(buf)) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &size); err != nil { + return false + } + + base := strings.TrimSuffix(strings.ToLower(filepath.Base(windows.UTF16ToString(buf[:size]))), ".exe") + app := strings.TrimSuffix(strings.ToLower(appName), ".exe") + return base == app || strings.HasPrefix(base, app+"-") +} + // StopProcess terminates a process on Windows. // Unlike the Unix implementation which sends SIGTERM for graceful shutdown, // this uses process.Kill() (TerminateProcess) which immediately terminates the diff --git a/internal/handlers/health.go b/internal/handlers/health.go index b25cc65..6417e7c 100644 --- a/internal/handlers/health.go +++ b/internal/handlers/health.go @@ -4,8 +4,10 @@ import ( "encoding/json" "net/http" + "oc-go-cc/internal/buildinfo" "oc-go-cc/internal/metrics" "oc-go-cc/internal/router" + "oc-go-cc/internal/status" "oc-go-cc/internal/token" "oc-go-cc/pkg/types" ) @@ -15,14 +17,16 @@ type HealthHandler struct { tokenCounter *token.Counter fallbackHandler *router.FallbackHandler metrics *metrics.Metrics + statusStore *status.Store } // NewHealthHandler creates a new health handler. -func NewHealthHandler(tokenCounter *token.Counter, fallbackHandler *router.FallbackHandler, metrics *metrics.Metrics) *HealthHandler { +func NewHealthHandler(tokenCounter *token.Counter, fallbackHandler *router.FallbackHandler, metrics *metrics.Metrics, statusStore *status.Store) *HealthHandler { return &HealthHandler{ tokenCounter: tokenCounter, fallbackHandler: fallbackHandler, metrics: metrics, + statusStore: statusStore, } } @@ -38,8 +42,12 @@ func (h *HealthHandler) HandleHealth(w http.ResponseWriter, r *http.Request) { } response := map[string]interface{}{ - "status": "ok", - "service": "oc-go-cc", + "status": "ok", + "service": "oc-go-cc", + "version": buildinfo.Version, + "build_time": buildinfo.BuildTime, + "pid": buildinfo.PID(), + "binary": buildinfo.BinaryPath(), "metrics": map[string]interface{}{ "requests_received": snapshot.RequestsReceived, "requests_success": snapshot.RequestsSuccess, @@ -61,6 +69,21 @@ func (h *HealthHandler) HandleHealth(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(response) } +func (h *HealthHandler) HandleStatusline(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + if h.statusStore == nil { + _ = json.NewEncoder(w).Encode(status.NewStore(0).Snapshot()) + return + } + _ = json.NewEncoder(w).Encode(h.statusStore.Snapshot()) +} + // HandleCountTokens handles POST /v1/messages/count_tokens. func (h *HealthHandler) HandleCountTokens(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { diff --git a/internal/handlers/health_test.go b/internal/handlers/health_test.go index b30dc51..24be364 100644 --- a/internal/handlers/health_test.go +++ b/internal/handlers/health_test.go @@ -39,6 +39,27 @@ func TestHandleCountTokensSupportsAnthropicContentBlocks(t *testing.T) { } } +func TestHandleHealthIncludesBuildInfo(t *testing.T) { + handler := newTestHealthHandler(t) + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + + handler.HandleHealth(recorder, req) + + if got, want := recorder.Code, http.StatusOK; got != want { + t.Fatalf("status = %d, want %d; body: %s", got, want, recorder.Body.String()) + } + var response map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &response); err != nil { + t.Fatalf("response is invalid JSON: %v", err) + } + for _, key := range []string{"version", "build_time", "pid", "binary"} { + if _, ok := response[key]; !ok { + t.Fatalf("health response missing %s: %s", key, recorder.Body.String()) + } + } +} + func TestHandleCountTokensIncludesSystemToolsAndThinking(t *testing.T) { handler := newTestHealthHandler(t) @@ -87,5 +108,5 @@ func newTestHealthHandler(t *testing.T) *HealthHandler { if err != nil { t.Fatalf("NewCounter() error = %v", err) } - return NewHealthHandler(counter, nil, metrics.New()) + return NewHealthHandler(counter, nil, metrics.New(), nil) } diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 39a7182..8d5137d 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -3,12 +3,15 @@ package handlers import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" "log/slog" "net/http" "strings" + "sync/atomic" "time" "oc-go-cc/internal/client" @@ -16,6 +19,7 @@ import ( "oc-go-cc/internal/metrics" "oc-go-cc/internal/middleware" "oc-go-cc/internal/router" + "oc-go-cc/internal/status" "oc-go-cc/internal/token" "oc-go-cc/internal/transformer" "oc-go-cc/pkg/types" @@ -35,6 +39,8 @@ type MessagesHandler struct { requestDedup *middleware.RequestDeduplicator requestIDGen *middleware.RequestIDGenerator metrics *metrics.Metrics + statusStore *status.Store + statusSeq atomic.Uint64 } // responseWriter wraps http.ResponseWriter to track if headers were written. @@ -71,6 +77,7 @@ func NewMessagesHandler( fallbackHandler *router.FallbackHandler, tokenCounter *token.Counter, metrics *metrics.Metrics, + statusStore *status.Store, ) *MessagesHandler { return &MessagesHandler{ client: openCodeClient, @@ -85,6 +92,7 @@ func NewMessagesHandler( requestDedup: middleware.NewRequestDeduplicator(500 * time.Millisecond), requestIDGen: middleware.NewRequestIDGenerator(), metrics: metrics, + statusStore: statusStore, } } @@ -153,21 +161,24 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) // Build message content for routing and token counting. var routerMessages []router.MessageContent var tokenMessages []token.MessageContent - systemText := anthropicReq.SystemText() + systemText, err := systemAndToolsTokenText(anthropicReq.SystemText(), anthropicReq.Tools) + if err != nil { + h.sendError(w, http.StatusBadRequest, "failed to process tools", err) + return + } for _, msg := range anthropicReq.Messages { blocks := msg.ContentBlocks() content := extractTextFromBlocks(blocks) mc := router.MessageContent{ - Role: msg.Role, - Content: content, + Role: msg.Role, + Content: content, + HasImage: blocksHaveImage(blocks), + ImageHashes: imageHashesFromBlocks(blocks), } routerMessages = append(routerMessages, mc) - tokenMessages = append(tokenMessages, token.MessageContent{ - Role: msg.Role, - Content: content, - }) } + tokenMessages = tokenMessagesFromAnthropic(anthropicReq.Messages) // Count tokens. tokenCount, err := h.tokenCounter.CountMessages(systemText, tokenMessages) @@ -198,14 +209,28 @@ func (h *MessagesHandler) HandleMessages(w http.ResponseWriter, r *http.Request) "scenario", routeResult.Scenario, "model", routeResult.Primary.ModelID, "tokens", tokenCount, + "latest_user_has_image", router.AnalyzeRequestFacts(routerMessages).LatestUserHasImage, + "any_historical_image", router.AnalyzeRequestFacts(routerMessages).AnyHistoricalImage, + "latest_text_visual_intent", router.AnalyzeRequestFacts(routerMessages).LatestTextVisualIntent, + "needs_vision", router.AnalyzeRequestFacts(routerMessages).NeedsVision, + "supports_vision", routeResult.Primary.SupportsVision, ) // Build fallback chain. + facts := router.AnalyzeRequestFacts(routerMessages) modelChain := routeResult.GetModelChain() + capacity, err := router.FilterByCapacity(modelChain, tokenCount, anthropicReq.MaxTokens, facts.NeedsVision, len(anthropicReq.Tools) > 0) + if err != nil { + h.updateStatus(requestID, isStreaming, routeResult, capacity) + h.sendError(w, http.StatusBadRequest, "no eligible model for request context", err) + return + } + modelChain = capacity.Models + h.updateStatus(requestID, isStreaming, routeResult, capacity) if isStreaming { // Streaming: use ProxyStream for real-time SSE transformation - h.handleStreaming(w, r, &anthropicReq, modelChain, rawBody) + h.handleStreaming(w, r, &anthropicReq, modelChain, rawBody, routeResult.Scenario) } else { // Non-streaming: execute with fallback and return full response h.handleNonStreaming(w, r, &anthropicReq, modelChain, rawBody) @@ -219,6 +244,7 @@ func (h *MessagesHandler) handleStreaming( anthropicReq *types.MessageRequest, modelChain []config.ModelConfig, rawBody json.RawMessage, + scenario router.Scenario, ) { // Each fallback attempt needs its own context with timeout. // Don't share r.Context() across fallbacks - when Claude Code retries, @@ -285,7 +311,7 @@ func (h *MessagesHandler) handleStreaming( if client.IsAnthropicModel(model.ModelID) { // For MiniMax models, send raw Anthropic request to Anthropic endpoint // But we need to replace the model name in the raw body - modelBody := replaceModelInRawBody(rawBody, model.ModelID) + modelBody := sanitizeAnthropicRawBody(rawBody, model) if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID); err != nil { cancel() // Check if this was a client disconnect @@ -311,6 +337,44 @@ func (h *MessagesHandler) handleStreaming( continue } + if isVisionScenario(scenario) { + streamFalse := false + openaiReq.Stream = &streamFalse + openaiReq.StreamOptions = nil + chatResp, err := h.client.ChatCompletionNonStreaming(ctx, model.ModelID, openaiReq) + if err != nil { + cancel() + h.logger.Warn("vision non-streaming request failed", "model", model.ModelID, "error", err) + continue + } + anthropicResp, err := h.responseTransformer.TransformResponse(chatResp, model.ModelID) + if err != nil { + cancel() + h.logger.Warn("vision response transform failed", "model", model.ModelID, "error", err) + continue + } + visible := visibleTextLength(anthropicResp) + if visible == 0 && !hasToolUseContent(anthropicResp) { + cancel() + h.logger.Warn("vision response had no visible output", "model", model.ModelID, "empty_visible_stream", true, "visible_text_deltas", 0) + continue + } + if err := h.streamHandler.EmitMessageResponse(rw, anthropicResp); err != nil { + cancel() + if err == transformer.ErrClientDisconnected { + h.logger.Info("client disconnected during synthesized vision stream") + return + } + h.logger.Warn("vision stream synthesis failed", "model", model.ModelID, "error", err) + continue + } + cancel() + latency := time.Since(streamStart) + h.metrics.RecordSuccess(model.ModelID, latency) + h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency, "visible_text_deltas", visible) + return + } + // Get streaming body from upstream streamBody, err := h.client.GetStreamingBody(ctx, model.ModelID, openaiReq) if err != nil { @@ -387,6 +451,51 @@ func replaceModelInRawBody(rawBody json.RawMessage, modelID string) json.RawMess return rawBody } +func sanitizeAnthropicRawBody(rawBody json.RawMessage, model config.ModelConfig) json.RawMessage { + var req types.MessageRequest + if err := json.Unmarshal(rawBody, &req); err != nil { + return replaceModelInRawBody(rawBody, model.ModelID) + } + req.Model = model.ModelID + if model.MaxTokens > 0 { + req.MaxTokens = model.MaxTokens + } + if model.SupportsVision { + body, err := json.Marshal(req) + if err != nil { + return replaceModelInRawBody(rawBody, model.ModelID) + } + return body + } + for i := range req.Messages { + blocks := req.Messages[i].ContentBlocks() + if len(blocks) == 0 { + continue + } + sanitized := make([]types.ContentBlock, 0, len(blocks)) + changed := false + for _, block := range blocks { + if block.Type == "image" { + changed = true + sanitized = append(sanitized, types.ContentBlock{Type: "text", Text: "[Image omitted for text-only model]"}) + continue + } + sanitized = append(sanitized, block) + } + if changed { + content, err := json.Marshal(sanitized) + if err == nil { + req.Messages[i].Content = content + } + } + } + body, err := json.Marshal(req) + if err != nil { + return replaceModelInRawBody(rawBody, model.ModelID) + } + return body +} + // handleAnthropicStreaming sends a raw Anthropic request to the Anthropic endpoint. func (h *MessagesHandler) handleAnthropicStreaming( ctx context.Context, @@ -493,6 +602,7 @@ func (h *MessagesHandler) executeAnthropicRequest( rawBody json.RawMessage, model config.ModelConfig, ) ([]byte, error) { + rawBody = sanitizeAnthropicRawBody(rawBody, model) // Send raw Anthropic request to Anthropic endpoint resp, err := h.client.SendAnthropicRequest(ctx, rawBody, false) if err != nil { @@ -558,6 +668,92 @@ func extractTextFromBlocks(blocks []types.ContentBlock) string { return content } +func isVisionScenario(s router.Scenario) bool { + return s == router.ScenarioVision || s == router.ScenarioVisionComplex || s == router.ScenarioVisionLongContext +} + +func visibleTextLength(resp *types.MessageResponse) int { + if resp == nil { + return 0 + } + total := 0 + for _, block := range resp.Content { + if block.Type == "text" { + total += len(block.Text) + } + } + return total +} + +func hasToolUseContent(resp *types.MessageResponse) bool { + if resp == nil { + return false + } + for _, block := range resp.Content { + if block.Type == "tool_use" { + return true + } + } + return false +} + +func blocksHaveImage(blocks []types.ContentBlock) bool { + for _, block := range blocks { + if block.Type == "image" && block.Source != nil { + return true + } + } + return false +} + +func imageHashesFromBlocks(blocks []types.ContentBlock) []string { + var hashes []string + for _, block := range blocks { + if block.Type != "image" || block.Source == nil { + continue + } + source := block.Source.Type + "\x00" + block.Source.MediaType + "\x00" + block.Source.Data + "\x00" + block.Source.URL + sum := sha256.Sum256([]byte(source)) + hashes = append(hashes, hex.EncodeToString(sum[:])) + } + return hashes +} + +func (h *MessagesHandler) updateStatus(requestID string, streaming bool, routeResult router.RouteResult, capacity router.CapacityDecision) { + if h.statusStore == nil { + return + } + seq := h.statusSeq.Add(1) + modelID := routeResult.Primary.ModelID + contextWindow := capacity.ContextWindow + if len(capacity.Models) > 0 { + modelID = capacity.Models[0].ModelID + contextWindow = capacity.Models[0].ContextWindow + } + pct := 0 + if contextWindow > 0 { + pct = int((float64(capacity.InputTokens) / float64(contextWindow)) * 100) + } + h.statusStore.Update(seq, status.Snapshot{ + Request: status.RequestSnapshot{ + RequestID: requestID, + Streaming: streaming, + }, + Routing: status.RoutingSnapshot{ + Scenario: string(routeResult.Scenario), + ModelID: modelID, + }, + Context: status.ContextSnapshot{ + InputTokens: capacity.InputTokens, + MaxTokens: contextWindow, + Percent: pct, + }, + Models: status.ModelsSnapshot{ + SkippedFallbacks: capacity.Skipped, + }, + }) +} + // sendError sends an error response in Anthropic format. // Safe to call multiple times - subsequent calls are no-ops. func (h *MessagesHandler) sendError(w http.ResponseWriter, statusCode int, message string, err error) { diff --git a/internal/handlers/token_count.go b/internal/handlers/token_count.go index e70c5a8..4a80e7f 100644 --- a/internal/handlers/token_count.go +++ b/internal/handlers/token_count.go @@ -12,9 +12,11 @@ import ( func tokenMessagesFromAnthropic(messages []types.Message) []token.MessageContent { tokenMessages := make([]token.MessageContent, 0, len(messages)) for _, msg := range messages { + blocks := msg.ContentBlocks() tokenMessages = append(tokenMessages, token.MessageContent{ - Role: msg.Role, - Content: extractTokenTextFromBlocks(msg.ContentBlocks()), + Role: msg.Role, + Content: extractTokenTextFromBlocks(blocks), + ExtraTokens: imageTokenEstimate(blocks), }) } return tokenMessages @@ -72,3 +74,13 @@ func extractTokenTextFromBlocks(blocks []types.ContentBlock) string { } return content.String() } + +func imageTokenEstimate(blocks []types.ContentBlock) int { + total := 0 + for _, block := range blocks { + if block.Type == "image" { + total += 1500 + } + } + return total +} diff --git a/internal/router/capacity.go b/internal/router/capacity.go new file mode 100644 index 0000000..010f27d --- /dev/null +++ b/internal/router/capacity.go @@ -0,0 +1,83 @@ +package router + +import ( + "fmt" + + "oc-go-cc/internal/config" +) + +const minimumOutputTokens = 256 + +type SkippedModel struct { + ModelID string `json:"model_id"` + Reason string `json:"reason"` +} + +type CapacityDecision struct { + Models []config.ModelConfig + Skipped []SkippedModel + InputTokens int + RequestedMaxTokens int + SelectedMaxTokens int + ContextWindow int + ContextMargin int + NeedsVision bool + NeedsTools bool +} + +func FilterByCapacity(chain []config.ModelConfig, inputTokens int, requestedMaxTokens int, needsVision bool, needsTools bool) (CapacityDecision, error) { + decision := CapacityDecision{ + InputTokens: inputTokens, + RequestedMaxTokens: requestedMaxTokens, + NeedsVision: needsVision, + NeedsTools: needsTools, + } + + for _, raw := range chain { + model := config.ResolveModelConfig(raw) + if needsVision && !model.SupportsVision { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "vision_not_supported"}) + continue + } + if needsTools && !config.SupportsTools(model) { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "tools_not_supported"}) + continue + } + + sentMax := clampOutputTokens(model, inputTokens, requestedMaxTokens) + if sentMax < minimumOutputTokens { + decision.Skipped = append(decision.Skipped, SkippedModel{ModelID: model.ModelID, Reason: "context_window_exceeded"}) + continue + } + model.MaxTokens = sentMax + if len(decision.Models) == 0 { + decision.SelectedMaxTokens = sentMax + decision.ContextWindow = model.ContextWindow + decision.ContextMargin = model.ContextMargin + } + decision.Models = append(decision.Models, model) + } + + if len(decision.Models) == 0 { + return decision, fmt.Errorf("no eligible model for request capacity") + } + return decision, nil +} + +func clampOutputTokens(model config.ModelConfig, inputTokens int, requestedMaxTokens int) int { + limit := model.MaxTokens + if requestedMaxTokens > 0 && (limit == 0 || requestedMaxTokens < limit) { + limit = requestedMaxTokens + } + if model.MaxOutputTokens > 0 && (limit == 0 || model.MaxOutputTokens < limit) { + limit = model.MaxOutputTokens + } + if model.ContextWindow <= 0 { + return limit + } + remaining := model.ContextWindow - inputTokens - model.ContextMargin + if limit == 0 || remaining < limit { + limit = remaining + } + return limit +} diff --git a/internal/router/capacity_test.go b/internal/router/capacity_test.go new file mode 100644 index 0000000..6e6dc57 --- /dev/null +++ b/internal/router/capacity_test.go @@ -0,0 +1,56 @@ +package router + +import ( + "testing" + + "oc-go-cc/internal/config" +) + +func TestFilterByCapacitySkipsPrimaryAndUsesEligibleFallback(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "glm-5.1", MaxTokens: 8192}, + {Provider: "opencode-go", ModelID: "deepseek-v4-pro", MaxTokens: 8192}, + } + + decision, err := FilterByCapacity(chain, 250000, 8192, false, false) + if err != nil { + t.Fatalf("FilterByCapacity() error = %v", err) + } + if got, want := decision.Models[0].ModelID, "deepseek-v4-pro"; got != want { + t.Fatalf("selected model = %s, want %s", got, want) + } + if len(decision.Skipped) != 1 || decision.Skipped[0].Reason != "context_window_exceeded" { + t.Fatalf("skipped = %+v, want context skip", decision.Skipped) + } +} + +func TestFilterByCapacityRejectsVisionFallbackToTextModel(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "deepseek-v4-pro", MaxTokens: 8192}, + } + + decision, err := FilterByCapacity(chain, 1000, 8192, true, false) + if err == nil { + t.Fatal("FilterByCapacity() error = nil, want error") + } + if len(decision.Models) != 0 { + t.Fatalf("eligible models = %+v, want none", decision.Models) + } + if len(decision.Skipped) != 1 || decision.Skipped[0].Reason != "vision_not_supported" { + t.Fatalf("skipped = %+v, want vision skip", decision.Skipped) + } +} + +func TestFilterByCapacityClampsMaxTokens(t *testing.T) { + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6", MaxTokens: 16384}, + } + + decision, err := FilterByCapacity(chain, 240000, 16384, true, false) + if err != nil { + t.Fatalf("FilterByCapacity() error = %v", err) + } + if got, want := decision.Models[0].MaxTokens, 256000-240000-config.DefaultContextMargin; got != want { + t.Fatalf("max_tokens = %d, want %d", got, want) + } +} diff --git a/internal/router/model_router.go b/internal/router/model_router.go index 231f6e7..5d1e7dd 100644 --- a/internal/router/model_router.go +++ b/internal/router/model_router.go @@ -28,9 +28,9 @@ type RouteResult struct { // resolveRequestedModel checks if the user-specified model should override // scenario-based routing. Returns the route result and true if it matched, // or zero value and false if scenario routing should proceed normally. -func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel string) (RouteResult, bool) { +func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel string, needsVision bool) (RouteResult, bool, error) { if !cfg.RespectRequestedModel || requestedModel == "" { - return RouteResult{}, false + return RouteResult{}, false, nil } // Look up the requested model in config to inherit its settings @@ -46,22 +46,29 @@ func (r *ModelRouter) resolveRequestedModel(cfg *config.Config, requestedModel s primary.MaxTokens = def.MaxTokens } } + primary = config.ResolveModelConfig(primary) + if needsVision && !primary.SupportsVision { + return RouteResult{}, false, fmt.Errorf("requested model %s does not support vision", primary.ModelID) + } - fallbacks := cfg.Fallbacks["default"] + fallbacks := normalizeModels(cfg.Fallbacks["default"]) return RouteResult{ Primary: primary, Fallbacks: fallbacks, Scenario: ScenarioDefault, - }, true + }, true, nil } // Route determines which model to use for a request. // If respect_requested_model is enabled and requestedModel is provided, it overrides scenario-based routing. func (r *ModelRouter) Route(messages []MessageContent, tokenCount int, requestedModel string) (RouteResult, error) { cfg := r.atomic.Get() + facts := AnalyzeRequestFacts(messages) - if result, ok := r.resolveRequestedModel(cfg, requestedModel); ok { + if result, ok, err := r.resolveRequestedModel(cfg, requestedModel, facts.NeedsVision); err != nil { + return RouteResult{}, err + } else if ok { return result, nil } @@ -71,18 +78,35 @@ func (r *ModelRouter) Route(messages []MessageContent, tokenCount int, requested // Get primary model for scenario primary, ok := cfg.Models[string(result.Scenario)] if !ok { + if isVisionScenario(result.Scenario) { + return RouteResult{}, fmt.Errorf("vision scenario %s is not configured", result.Scenario) + } // Fall back to default if scenario model not configured primary, ok = cfg.Models["default"] if !ok { return RouteResult{}, fmt.Errorf("no default model configured") } } + primary = config.ResolveModelConfig(primary) + if isVisionScenario(result.Scenario) && !primary.SupportsVision { + return RouteResult{}, fmt.Errorf("vision scenario %s primary model %s does not support vision", result.Scenario, primary.ModelID) + } // Get fallbacks for scenario - fallbacks := cfg.Fallbacks[string(result.Scenario)] + fallbacks := normalizeModels(cfg.Fallbacks[string(result.Scenario)]) if len(fallbacks) == 0 { + if isVisionScenario(result.Scenario) { + return RouteResult{}, fmt.Errorf("vision scenario %s has no configured vision fallbacks", result.Scenario) + } // Fall back to default fallbacks - fallbacks = cfg.Fallbacks["default"] + fallbacks = normalizeModels(cfg.Fallbacks["default"]) + } + if isVisionScenario(result.Scenario) { + for _, fallback := range fallbacks { + if !fallback.SupportsVision { + return RouteResult{}, fmt.Errorf("vision scenario %s fallback model %s does not support vision", result.Scenario, fallback.ModelID) + } + } } return RouteResult{ @@ -110,8 +134,9 @@ func (rr *RouteResult) GetModelChain() []config.ModelConfig { // If respect_requested_model is enabled and requestedModel is provided, it overrides scenario-based routing. func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount int, requestedModel string) RouteResult { cfg := r.atomic.Get() + facts := AnalyzeRequestFacts(messages) - if result, ok := r.resolveRequestedModel(cfg, requestedModel); ok { + if result, ok, err := r.resolveRequestedModel(cfg, requestedModel, facts.NeedsVision); err == nil && ok { return result } @@ -121,6 +146,9 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in // Get primary model for scenario primary, ok := cfg.Models[string(result.Scenario)] if !ok { + if isVisionScenario(result.Scenario) { + return RouteResult{Scenario: result.Scenario} + } // Fall back to fast scenario if not configured primary, ok = cfg.Models["fast"] if !ok { @@ -128,12 +156,17 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in primary = cfg.Models["default"] } } + primary = config.ResolveModelConfig(primary) // Get fallbacks for scenario - fallbacks := cfg.Fallbacks[string(result.Scenario)] + fallbacks := normalizeModels(cfg.Fallbacks[string(result.Scenario)]) if len(fallbacks) == 0 { - // Fall back to fast fallbacks - fallbacks = cfg.Fallbacks["fast"] + if isVisionScenario(result.Scenario) { + fallbacks = nil + } else { + // Fall back to fast fallbacks + fallbacks = normalizeModels(cfg.Fallbacks["fast"]) + } } return RouteResult{ @@ -142,3 +175,18 @@ func (r *ModelRouter) RouteForStreaming(messages []MessageContent, tokenCount in Scenario: result.Scenario, } } + +func isVisionScenario(s Scenario) bool { + return s == ScenarioVision || s == ScenarioVisionComplex || s == ScenarioVisionLongContext +} + +func normalizeModels(models []config.ModelConfig) []config.ModelConfig { + if len(models) == 0 { + return nil + } + out := make([]config.ModelConfig, 0, len(models)) + for _, model := range models { + out = append(out, config.ResolveModelConfig(model)) + } + return out +} diff --git a/internal/router/model_router_test.go b/internal/router/model_router_test.go index af749aa..3177f7d 100644 --- a/internal/router/model_router_test.go +++ b/internal/router/model_router_test.go @@ -221,7 +221,10 @@ func TestResolveRequestedModel_UsesFallbacks(t *testing.T) { router := NewModelRouter(newTestAtomicConfig(cfg)) - result, ok := router.resolveRequestedModel(cfg, "kimi-k2.6") + result, ok, err := router.resolveRequestedModel(cfg, "kimi-k2.6", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !ok { t.Fatal("expected resolveRequestedModel to match") } @@ -232,3 +235,58 @@ func TestResolveRequestedModel_UsesFallbacks(t *testing.T) { t.Errorf("expected first fallback qwen3.5-plus, got %s", result.Fallbacks[0].ModelID) } } + +func TestRoute_VisionScenarioRequiresConfiguredVisionModel(t *testing.T) { + cfg := &config.Config{ + Models: map[string]config.ModelConfig{ + "default": {ModelID: "deepseek-v4-pro"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{ModelID: "deepseek-v4-flash"}}, + }, + } + router := NewModelRouter(newTestAtomicConfig(cfg)) + _, err := router.Route([]MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + }, 100, "") + if err == nil { + t.Fatal("expected missing vision model to return an error") + } +} + +func TestRoute_VisionFallbacksMustSupportVision(t *testing.T) { + cfg := &config.Config{ + Models: map[string]config.ModelConfig{ + "default": {ModelID: "deepseek-v4-pro"}, + "vision": {ModelID: "qwen3.6-plus", SupportsVision: true}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "vision": {{ModelID: "deepseek-v4-pro", SupportsVision: false}}, + }, + } + router := NewModelRouter(newTestAtomicConfig(cfg)) + _, err := router.Route([]MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + }, 100, "") + if err == nil { + t.Fatal("expected text-only vision fallback to return an error") + } +} + +func TestRoute_RespectRequestedTextModelRejectsVisionRequest(t *testing.T) { + cfg := &config.Config{ + RespectRequestedModel: true, + Models: map[string]config.ModelConfig{ + "default": {ModelID: "deepseek-v4-pro"}, + "deepseek-v4-pro": {ModelID: "deepseek-v4-pro", SupportsVision: false}, + "vision": {ModelID: "qwen3.6-plus", SupportsVision: true}, + }, + } + router := NewModelRouter(newTestAtomicConfig(cfg)) + _, err := router.Route([]MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + }, 100, "deepseek-v4-pro") + if err == nil { + t.Fatal("expected requested text-only model to reject vision request") + } +} diff --git a/internal/router/scenarios.go b/internal/router/scenarios.go index ba3c5af..1117a42 100644 --- a/internal/router/scenarios.go +++ b/internal/router/scenarios.go @@ -11,12 +11,15 @@ import ( type Scenario string const ( - ScenarioDefault Scenario = "default" - ScenarioBackground Scenario = "background" - ScenarioThink Scenario = "think" - ScenarioComplex Scenario = "complex" - ScenarioLongContext Scenario = "long_context" - ScenarioFast Scenario = "fast" + ScenarioDefault Scenario = "default" + ScenarioBackground Scenario = "background" + ScenarioThink Scenario = "think" + ScenarioComplex Scenario = "complex" + ScenarioLongContext Scenario = "long_context" + ScenarioFast Scenario = "fast" + ScenarioVision Scenario = "vision" + ScenarioVisionComplex Scenario = "vision_complex" + ScenarioVisionLongContext Scenario = "vision_long_context" ) // ScenarioResult contains the detected scenario and token count. @@ -28,8 +31,19 @@ type ScenarioResult struct { // MessageContent represents a single message in a conversation. type MessageContent struct { - Role string - Content string + Role string + Content string + HasImage bool + ImageHashes []string +} + +type RequestFacts struct { + LatestUserText string + LatestUserHasImage bool + AnyHistoricalImage bool + LatestTextVisualIntent bool + LatestTextComplexIntent bool + NeedsVision bool } // DetectScenario analyzes a request to determine which model to use. @@ -42,9 +56,17 @@ type MessageContent struct { // // For streaming requests, consider using RouteForStreaming() to prefer faster models. func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Config) ScenarioResult { + facts := AnalyzeRequestFacts(messages) // 1. Check for long context first (most important) threshold := getLongContextThreshold(cfg) if tokenCount > threshold { + if facts.NeedsVision { + return ScenarioResult{ + Scenario: ScenarioVisionLongContext, + TokenCount: tokenCount, + Reason: fmt.Sprintf("image request token count %d exceeds threshold %d", tokenCount, threshold), + } + } return ScenarioResult{ Scenario: ScenarioLongContext, TokenCount: tokenCount, @@ -52,8 +74,24 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } } + if facts.NeedsVision { + if facts.LatestTextComplexIntent { + return ScenarioResult{ + Scenario: ScenarioVisionComplex, + TokenCount: tokenCount, + Reason: "complex image request detected", + } + } + return ScenarioResult{ + Scenario: ScenarioVision, + TokenCount: tokenCount, + Reason: "simple image request detected", + } + } + // 2. Check for complex tasks (architectural OR tool-related) - if hasComplexPattern(messages) { + latestUser := latestUserMessages(messages) + if hasComplexPattern(latestUser) { return ScenarioResult{ Scenario: ScenarioComplex, TokenCount: tokenCount, @@ -62,7 +100,7 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } // 3. Check for thinking/reasoning patterns - if hasThinkingPattern(messages) { + if hasThinkingPattern(latestUser) { return ScenarioResult{ Scenario: ScenarioThink, TokenCount: tokenCount, @@ -87,6 +125,106 @@ func DetectScenario(messages []MessageContent, tokenCount int, cfg *config.Confi } } +func hasImage(messages []MessageContent) bool { + for _, msg := range messages { + if msg.HasImage { + return true + } + } + return false +} + +func latestUserHasImage(messages []MessageContent) bool { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return messages[i].HasImage + } + } + return false +} + +func latestImageUserMessages(messages []MessageContent) []MessageContent { + for i := len(messages) - 1; i >= 0; i-- { + msg := messages[i] + if msg.Role == "user" && msg.HasImage { + return []MessageContent{msg} + } + } + return messages +} + +func AnalyzeRequestFacts(messages []MessageContent) RequestFacts { + facts := RequestFacts{} + latestIdx := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + latestIdx = i + break + } + } + if latestIdx == -1 { + return facts + } + + latest := messages[latestIdx] + facts.LatestUserText = latest.Content + facts.LatestUserHasImage = latest.HasImage && imageHashesAreNewForLatest(messages, latestIdx) + facts.LatestTextVisualIntent = hasVisualIntent(latest.Content) + facts.LatestTextComplexIntent = hasComplexPattern([]MessageContent{latest}) || hasThinkingPattern([]MessageContent{latest}) + + for i, msg := range messages { + if i != latestIdx && msg.Role == "user" && msg.HasImage { + facts.AnyHistoricalImage = true + break + } + } + + facts.NeedsVision = facts.LatestUserHasImage || (facts.AnyHistoricalImage && facts.LatestTextVisualIntent) + return facts +} + +func imageHashesAreNewForLatest(messages []MessageContent, latestIdx int) bool { + latest := messages[latestIdx] + if len(latest.ImageHashes) == 0 { + return latest.HasImage + } + seen := map[string]bool{} + for i := 0; i < latestIdx; i++ { + for _, hash := range messages[i].ImageHashes { + seen[hash] = true + } + } + for _, hash := range latest.ImageHashes { + if !seen[hash] { + return true + } + } + return false +} + +func latestUserMessages(messages []MessageContent) []MessageContent { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return []MessageContent{messages[i]} + } + } + return nil +} + +func hasVisualIntent(content string) bool { + visualKeywords := []string{ + "image", "screenshot", "screen", "schermata", "immagine", "foto", + "allegato", "[image", "vedi", "visual", "ui", "layout", + } + lower := strings.ToLower(content) + for _, kw := range visualKeywords { + if strings.Contains(lower, kw) { + return true + } + } + return false +} + // hasComplexPattern looks for complex operations that need more capable models. // This includes tool-based operations (executing functions, writing/editing files, etc.) func hasComplexPattern(messages []MessageContent) bool { @@ -96,6 +234,7 @@ func hasComplexPattern(messages []MessageContent) bool { "complex", "difficult", "challenging", "optimize", "performance", "efficiency", "design pattern", "best practice", + "bug", "debug", "error", "exception", "stack trace", // Tool-related keywords indicate complex operations "execute", "run command", "bash", "shell", "implement", "build", "create", "add feature", @@ -195,11 +334,19 @@ func getLongContextThreshold(cfg *config.Config) int { // For streaming, we prioritize fast TTFT (time-to-first-token) over capability. // This may return a less capable model but one that streams faster. func RouteForStreaming(messages []MessageContent, tokenCount int, cfg *config.Config) ScenarioResult { + facts := AnalyzeRequestFacts(messages) // For streaming, use simpler models that have better TTFT // Complex models (GLM, Kimi) are too slow for streaming with many tools threshold := getLongContextThreshold(cfg) if tokenCount > threshold { + if facts.NeedsVision { + return ScenarioResult{ + Scenario: ScenarioVisionLongContext, + TokenCount: tokenCount, + Reason: fmt.Sprintf("high token count image request (%d > %d)", tokenCount, threshold), + } + } model := "long_context" if cfg != nil { if lc, ok := cfg.Models["long_context"]; ok && lc.ModelID != "" { @@ -213,7 +360,23 @@ func RouteForStreaming(messages []MessageContent, tokenCount int, cfg *config.Co } } - if hasComplexPattern(messages) || hasThinkingPattern(messages) { + if facts.NeedsVision { + if facts.LatestTextComplexIntent { + return ScenarioResult{ + Scenario: ScenarioVisionComplex, + TokenCount: tokenCount, + Reason: "complex image request detected", + } + } + return ScenarioResult{ + Scenario: ScenarioVision, + TokenCount: tokenCount, + Reason: "simple image request detected", + } + } + + latestUser := latestUserMessages(messages) + if hasComplexPattern(latestUser) || hasThinkingPattern(latestUser) { // Complex request but streaming - downgrade to faster model // GLM-5 and Kimi are too slow for streaming with complex prompts return ScenarioResult{ diff --git a/internal/router/scenarios_test.go b/internal/router/scenarios_test.go index 998ef3a..e2758dc 100644 --- a/internal/router/scenarios_test.go +++ b/internal/router/scenarios_test.go @@ -113,6 +113,118 @@ func TestDetectScenario_LongContextTakesPriority(t *testing.T) { } } +func TestDetectScenario_VisionSimpleRequest(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Describe this screen", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionComplexRequest(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and find the bug", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVisionComplex { + t.Errorf("Expected ScenarioVisionComplex, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionUsesLatestImageRequestComplexity(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this previous architecture"}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Cosa vedi?", HasImage: true}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_ReturnsToTextRoutingAfterImageTurn(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and find the bug", HasImage: true}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Refactor this code"}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioComplex { + t.Errorf("Expected ScenarioComplex, got %s", result.Scenario) + } +} + +func TestDetectScenario_ReturnsToTextRoutingWhenLatestTurnHasHistoricalImageOnly(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "ci sei?", HasImage: true, ImageHashes: []string{"img1"}}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioDefault { + t.Errorf("Expected ScenarioDefault, got %s", result.Scenario) + } +} + +func TestDetectScenario_UsesHistoricalImageWhenLatestTextHasVisualIntent(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "cosa vedi nello screenshot?", HasImage: true, ImageHashes: []string{"img1"}}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioVision { + t.Errorf("Expected ScenarioVision, got %s", result.Scenario) + } +} + +func TestDetectScenario_DebugWithoutVisualIntentStaysTextComplex(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true, ImageHashes: []string{"img1"}}, + {Role: "assistant", Content: "Vedo una schermata."}, + {Role: "user", Content: "debug questo codice", HasImage: false}, + } + result := DetectScenario(messages, 100, mockConfig()) + if result.Scenario != ScenarioComplex { + t.Errorf("Expected ScenarioComplex, got %s", result.Scenario) + } +} + +func TestRouteForStreaming_ReturnsToFastAfterImageTurn(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Cosa vedi?", HasImage: true}, + {Role: "assistant", Content: "Done"}, + {Role: "user", Content: "Hello"}, + } + result := RouteForStreaming(messages, 100, mockConfig()) + if result.Scenario != ScenarioFast { + t.Errorf("Expected ScenarioFast, got %s", result.Scenario) + } +} + +func TestDetectScenario_VisionLongContextTakesPriorityOverVisionComplex(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Analyze this screenshot and refactor the code", HasImage: true}, + } + result := DetectScenario(messages, 70000, mockConfig()) + if result.Scenario != ScenarioVisionLongContext { + t.Errorf("Expected ScenarioVisionLongContext, got %s", result.Scenario) + } +} + +func TestRouteForStreaming_VisionComplexKeepsVisionComplexScenario(t *testing.T) { + messages := []MessageContent{ + {Role: "user", Content: "Find the bug in this screenshot", HasImage: true}, + } + result := RouteForStreaming(messages, 100, mockConfig()) + if result.Scenario != ScenarioVisionComplex { + t.Errorf("Expected ScenarioVisionComplex, got %s", result.Scenario) + } +} + func TestRouteForStreaming_RespectsConfiguredThreshold(t *testing.T) { messages := []MessageContent{ {Role: "user", Content: "Hello"}, diff --git a/internal/server/server.go b/internal/server/server.go index 0c96909..5e8bf57 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,6 +16,7 @@ import ( "oc-go-cc/internal/handlers" "oc-go-cc/internal/metrics" "oc-go-cc/internal/router" + "oc-go-cc/internal/status" "oc-go-cc/internal/token" ) @@ -50,6 +51,7 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { openCodeClient := client.NewOpenCodeClient(atomic) modelRouter := router.NewModelRouter(atomic) fallbackHandler := router.NewFallbackHandler(logger, 3, 30*time.Second) + statusStore := status.NewStore(10 * time.Second) // Create handlers. messagesHandler := handlers.NewMessagesHandler( @@ -58,8 +60,9 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { fallbackHandler, tokenCounter, metrics, + statusStore, ) - healthHandler := handlers.NewHealthHandler(tokenCounter, fallbackHandler, metrics) + healthHandler := handlers.NewHealthHandler(tokenCounter, fallbackHandler, metrics, statusStore) // Setup router. mux := http.NewServeMux() @@ -68,6 +71,7 @@ func NewServer(atomic *config.AtomicConfig) (*Server, error) { mux.HandleFunc("/v1/messages", messagesHandler.HandleMessages) mux.HandleFunc("/v1/messages/count_tokens", healthHandler.HandleCountTokens) mux.HandleFunc("/health", healthHandler.HandleHealth) + mux.HandleFunc("/statusline", healthHandler.HandleStatusline) // Create HTTP server. addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) diff --git a/internal/status/store.go b/internal/status/store.go new file mode 100644 index 0000000..44b8239 --- /dev/null +++ b/internal/status/store.go @@ -0,0 +1,105 @@ +package status + +import ( + "sync" + "time" + + "oc-go-cc/internal/buildinfo" + "oc-go-cc/internal/router" +) + +type Snapshot struct { + SchemaVersion int `json:"schema_version"` + UpdatedAt string `json:"updated_at"` + AgeMS int64 `json:"age_ms"` + Source string `json:"source"` + Stale bool `json:"stale"` + Proxy ProxySnapshot `json:"proxy"` + Request RequestSnapshot `json:"request"` + Routing RoutingSnapshot `json:"routing"` + Context ContextSnapshot `json:"context"` + Models ModelsSnapshot `json:"models"` +} + +type ProxySnapshot struct { + Version string `json:"version"` + PID int `json:"pid"` + Binary string `json:"binary"` +} + +type RequestSnapshot struct { + RequestID string `json:"request_id"` + Streaming bool `json:"streaming"` +} + +type RoutingSnapshot struct { + Scenario string `json:"scenario"` + ModelID string `json:"model_id"` +} + +type ContextSnapshot struct { + InputTokens int `json:"input_tokens"` + MaxTokens int `json:"max_tokens"` + Percent int `json:"pct"` +} + +type ModelsSnapshot struct { + SkippedFallbacks []router.SkippedModel `json:"skipped_fallbacks,omitempty"` +} + +type Store struct { + mu sync.RWMutex + seq uint64 + updated time.Time + snapshot Snapshot + ttl time.Duration +} + +func NewStore(ttl time.Duration) *Store { + if ttl <= 0 { + ttl = 10 * time.Second + } + return &Store{ttl: ttl} +} + +func (s *Store) Update(seq uint64, snap Snapshot) { + s.mu.Lock() + defer s.mu.Unlock() + if seq < s.seq { + return + } + now := time.Now().UTC() + s.seq = seq + s.updated = now + snap.SchemaVersion = 1 + snap.UpdatedAt = now.Format(time.RFC3339Nano) + snap.AgeMS = 0 + snap.Source = "proxy" + snap.Proxy = ProxySnapshot{ + Version: buildinfo.Version, + PID: buildinfo.PID(), + Binary: buildinfo.BinaryPath(), + } + s.snapshot = snap +} + +func (s *Store) Snapshot() Snapshot { + s.mu.RLock() + defer s.mu.RUnlock() + snap := s.snapshot + if s.updated.IsZero() { + snap.SchemaVersion = 1 + snap.Source = "empty" + snap.Stale = true + snap.Proxy = ProxySnapshot{ + Version: buildinfo.Version, + PID: buildinfo.PID(), + Binary: buildinfo.BinaryPath(), + } + return snap + } + age := time.Since(s.updated) + snap.AgeMS = age.Milliseconds() + snap.Stale = age > s.ttl + return snap +} diff --git a/internal/token/counter.go b/internal/token/counter.go index 323e0bb..15de0d1 100644 --- a/internal/token/counter.go +++ b/internal/token/counter.go @@ -52,8 +52,9 @@ func (c *Counter) CountTokens(text string) (int, error) { // MessageContent represents a single message in a conversation. type MessageContent struct { - Role string - Content string + Role string + Content string + ExtraTokens int } // CountMessages counts tokens in a message array. @@ -76,6 +77,7 @@ func (c *Counter) CountMessages(system string, messages []MessageContent) (int, return 0, err } total += msgTokens + 5 // Per-message overhead + total += msg.ExtraTokens } return total, nil diff --git a/internal/token/counter_test.go b/internal/token/counter_test.go index 151c0d0..11691a4 100644 --- a/internal/token/counter_test.go +++ b/internal/token/counter_test.go @@ -3,6 +3,7 @@ package token import ( "os" "path/filepath" + "runtime" "testing" ) @@ -52,6 +53,9 @@ func TestDefaultCacheDir(t *testing.T) { } func TestDefaultCacheDir_HomeDirFallback(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows UserHomeDir does not depend only on HOME") + } // When UserHomeDir fails (HOME unset), fall back to temp dir. t.Setenv("TIKTOKEN_CACHE_DIR", "") t.Setenv("DATA_GYM_CACHE_DIR", "") diff --git a/internal/transformer/request.go b/internal/transformer/request.go index 7684f6b..f39f865 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -47,7 +47,7 @@ func (t *RequestTransformer) TransformRequest( model config.ModelConfig, ) (*types.ChatCompletionRequest, error) { // Transform messages - messages, err := t.transformMessages(anthropicReq, model.ModelID) + messages, err := t.transformMessages(anthropicReq, model) if err != nil { return nil, fmt.Errorf("failed to transform messages: %w", err) } @@ -159,7 +159,7 @@ func HasThinkingBlocks(messages []types.Message) bool { } // transformMessages converts Anthropic messages to OpenAI format. -func (t *RequestTransformer) transformMessages(anthropicReq *types.MessageRequest, modelID string) ([]types.ChatMessage, error) { +func (t *RequestTransformer) transformMessages(anthropicReq *types.MessageRequest, model config.ModelConfig) ([]types.ChatMessage, error) { hasThinking := HasThinkingBlocks(anthropicReq.Messages) var result []types.ChatMessage @@ -171,24 +171,12 @@ func (t *RequestTransformer) transformMessages(anthropicReq *types.MessageReques Role: "system", Content: systemText, } - // Try to extract cache_control from system array blocks - if len(anthropicReq.System) > 0 { - var blocks []types.SystemContentBlock - if err := json.Unmarshal(anthropicReq.System, &blocks); err == nil { - for _, b := range blocks { - if b.Type == "text" && b.CacheControl != nil { - systemMsg.CacheControl = b.CacheControl - break - } - } - } - } result = append(result, systemMsg) } // Transform each message for _, msg := range anthropicReq.Messages { - openaiMsgs, err := t.transformMessage(msg, modelID, hasThinking) + openaiMsgs, err := t.transformMessage(msg, model, hasThinking) if err != nil { return nil, err } @@ -200,14 +188,14 @@ func (t *RequestTransformer) transformMessages(anthropicReq *types.MessageReques // transformMessage converts a single Anthropic message to one or more OpenAI messages. // Tool_use and tool_result require special handling to map to OpenAI's function calling format. -func (t *RequestTransformer) transformMessage(msg types.Message, modelID string, hasThinkingInHistory bool) ([]types.ChatMessage, error) { +func (t *RequestTransformer) transformMessage(msg types.Message, model config.ModelConfig, hasThinkingInHistory bool) ([]types.ChatMessage, error) { blocks := msg.ContentBlocks() switch msg.Role { case "user": - return t.transformUserMessage(blocks) + return t.transformUserMessage(blocks, model.SupportsVision) case "assistant": - return t.transformAssistantMessage(blocks, modelID, hasThinkingInHistory) + return t.transformAssistantMessage(blocks, model.ModelID, hasThinkingInHistory) default: // Fallback: concatenate all text var text string @@ -221,14 +209,19 @@ func (t *RequestTransformer) transformMessage(msg types.Message, modelID string, } // transformUserMessage converts a user message with potential tool_result blocks. -func (t *RequestTransformer) transformUserMessage(blocks []types.ContentBlock) ([]types.ChatMessage, error) { +func (t *RequestTransformer) transformUserMessage(blocks []types.ContentBlock, supportsVision bool) ([]types.ChatMessage, error) { var result []types.ChatMessage var textParts []string + var contentParts []types.ContentPart for _, block := range blocks { switch block.Type { case "text": textParts = append(textParts, block.Text) + contentParts = append(contentParts, types.ContentPart{ + Type: "text", + Text: block.Text, + }) case "tool_result": // In OpenAI, tool results are separate messages with role "tool" toolContent := block.TextContent() @@ -238,13 +231,34 @@ func (t *RequestTransformer) transformUserMessage(blocks []types.ContentBlock) ( ToolCallID: block.GetToolID(), }) case "image": - // Images not supported in text-only models, skip - textParts = append(textParts, "[Image]") + if !supportsVision { + textParts = append(textParts, "[Image omitted for text-only model]") + contentParts = append(contentParts, types.ContentPart{ + Type: "text", + Text: "[Image omitted for text-only model]", + }) + continue + } + if block.Source != nil && block.Source.Type == "base64" && block.Source.MediaType != "" && block.Source.Data != "" { + contentParts = append(contentParts, types.ContentPart{ + Type: "image_url", + ImageURL: &types.ImageURL{ + URL: "data:" + block.Source.MediaType + ";base64," + block.Source.Data, + }, + }) + } else if block.Source != nil && block.Source.URL != "" { + contentParts = append(contentParts, types.ContentPart{ + Type: "image_url", + ImageURL: &types.ImageURL{ + URL: block.Source.URL, + }, + }) + } } } // If there's text content, add it as a user message - if len(textParts) > 0 { + if len(contentParts) > 0 { text := "" for _, p := range textParts { text += p @@ -253,7 +267,20 @@ func (t *RequestTransformer) transformUserMessage(blocks []types.ContentBlock) ( // immediately after the assistant message that emitted tool_calls. // If the Anthropic user turn also includes free-form text, emit it as // a subsequent user message after all tool results. - userMsg := types.ChatMessage{Role: "user", Content: text} + content := interface{}(text) + if len(contentParts) > 0 { + hasImage := false + for _, p := range contentParts { + if p.Type == "image_url" { + hasImage = true + break + } + } + if hasImage { + content = contentParts + } + } + userMsg := types.ChatMessage{Role: "user", Content: content} result = append(result, userMsg) } diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 486644d..9af505e 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -3,6 +3,7 @@ package transformer import ( "bytes" "encoding/json" + "strings" "testing" "oc-go-cc/internal/config" @@ -350,7 +351,7 @@ func TestTransformRequestStripsReasoningEffortWhenNoThinkingHistory(t *testing.T } } -func TestTransformRequestPreservesSystemCacheControl(t *testing.T) { +func TestTransformRequestOmitsAnthropicSystemCacheControlForOpenAIModels(t *testing.T) { transformer := NewRequestTransformer() req := &types.MessageRequest{ @@ -380,11 +381,8 @@ func TestTransformRequestPreservesSystemCacheControl(t *testing.T) { if got, want := systemMsg.Content, "You are helpful"; got != want { t.Fatalf("Messages[0].Content = %q, want %q", got, want) } - if systemMsg.CacheControl == nil { - t.Fatal("Messages[0].CacheControl = nil, want non-nil") - } - if got, want := systemMsg.CacheControl.Type, "ephemeral"; got != want { - t.Fatalf("Messages[0].CacheControl.Type = %q, want %q", got, want) + if systemMsg.CacheControl != nil { + t.Fatalf("Messages[0].CacheControl = %v, want nil", systemMsg.CacheControl) } } @@ -796,6 +794,100 @@ func TestTransformRequestExtractsThinkingFromToolUseBlock(t *testing.T) { } } +func TestTransformRequestPreservesUserImageBlocksAsOpenAIContentParts(t *testing.T) { + req := &types.MessageRequest{ + Model: "claude-3-5-sonnet", + MaxTokens: 1024, + Messages: []types.Message{ + { + Role: "user", + Content: json.RawMessage(`[ + {"type":"text","text":"Describe this screen"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo="}} + ]`), + }, + }, + } + + transformer := NewRequestTransformer() + openaiReq, err := transformer.TransformRequest(req, config.ModelConfig{ + Provider: "opencode-go", + ModelID: "qwen3.6-plus", + SupportsVision: true, + }) + if err != nil { + t.Fatalf("TransformRequest: %v", err) + } + + if len(openaiReq.Messages) != 1 { + t.Fatalf("Messages = %d, want 1", len(openaiReq.Messages)) + } + + contentJSON, err := json.Marshal(openaiReq.Messages[0].Content) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + var parts []map[string]any + if err := json.Unmarshal(contentJSON, &parts); err != nil { + t.Fatalf("content = %s, want OpenAI content parts: %v", contentJSON, err) + } + if len(parts) != 2 { + t.Fatalf("content parts = %d, want 2: %s", len(parts), contentJSON) + } + if got, want := parts[0]["type"], "text"; got != want { + t.Fatalf("first part type = %v, want %s", got, want) + } + if got, want := parts[0]["text"], "Describe this screen"; got != want { + t.Fatalf("first part text = %v, want %s", got, want) + } + + imageURL, ok := parts[1]["image_url"].(map[string]any) + if !ok { + t.Fatalf("second part image_url missing: %s", contentJSON) + } + if got, want := imageURL["url"], "data:image/png;base64,iVBORw0KGgo="; got != want { + t.Fatalf("image url = %v, want %s", got, want) + } +} + +func TestTransformRequest_TextOnlyModelOmitsImageURL(t *testing.T) { + req := &types.MessageRequest{ + Model: "claude-3-5-sonnet", + MaxTokens: 1024, + Messages: []types.Message{ + { + Role: "user", + Content: json.RawMessage(`[ + {"type":"text","text":"ci sei?"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo="}} + ]`), + }, + }, + } + + transformer := NewRequestTransformer() + openaiReq, err := transformer.TransformRequest(req, config.ModelConfig{ + Provider: "opencode-go", + ModelID: "deepseek-v4-pro", + SupportsVision: false, + }) + if err != nil { + t.Fatalf("TransformRequest: %v", err) + } + + body, err := json.Marshal(openaiReq) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + if strings.Contains(string(body), "image_url") || strings.Contains(string(body), "data:image") { + t.Fatalf("text-only request leaked image content: %s", body) + } + if !strings.Contains(string(body), "[Image omitted for text-only model]") { + t.Fatalf("text-only request missing image placeholder: %s", body) + } +} + func mustJSONBytes(t *testing.T, v any) json.RawMessage { t.Helper() b, err := json.Marshal(v) diff --git a/internal/transformer/response.go b/internal/transformer/response.go index f9a0bdb..d17a34b 100644 --- a/internal/transformer/response.go +++ b/internal/transformer/response.go @@ -26,6 +26,40 @@ func NewResponseTransformer() *ResponseTransformer { return &ResponseTransformer{} } +func contentAsString(content interface{}) string { + switch v := content.(type) { + case string: + return v + case []types.ContentPart: + var text string + for _, part := range v { + if part.Type == "text" { + text += part.Text + } + } + return text + case []interface{}: + var text string + for _, part := range v { + if m, ok := part.(map[string]interface{}); ok { + if partText, ok := m["text"].(string); ok { + text += partText + } + } + } + return text + case map[string]interface{}: + if text, ok := v["text"].(string); ok { + return text + } + return "" + case nil: + return "" + default: + return "" + } +} + // TransformResponse converts an OpenAI ChatCompletionResponse to Anthropic MessageResponse. func (t *ResponseTransformer) TransformResponse( openaiResp *types.ChatCompletionResponse, @@ -104,10 +138,11 @@ func (t *ResponseTransformer) transformContent(msg types.ChatMessage) ([]types.C } // Handle text content. - if msg.Content != "" { + textContent := contentAsString(msg.Content) + if textContent != "" { blocks = append(blocks, types.ContentBlock{ Type: "text", - Text: msg.Content, + Text: textContent, }) } diff --git a/internal/transformer/response_test.go b/internal/transformer/response_test.go index dfffebf..32afe25 100644 --- a/internal/transformer/response_test.go +++ b/internal/transformer/response_test.go @@ -202,6 +202,45 @@ func TestTransformResponseNoReasoningContent(t *testing.T) { } } +func TestTransformResponseExtractsTextFromContentParts(t *testing.T) { + transformer := NewResponseTransformer() + + resp := &types.ChatCompletionResponse{ + ID: "chatcmpl_parts", + Object: "chat.completion", + Created: 1234567890, + Model: "qwen3.6-plus", + Choices: []types.Choice{ + { + Index: 0, + Message: types.ChatMessage{ + Role: "assistant", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "Vedo uno screenshot."}, + }, + }, + FinishReason: "stop", + }, + }, + Usage: types.UsageInfo{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + } + + anthropicResp, err := transformer.TransformResponse(resp, "qwen3.6-plus") + if err != nil { + t.Fatalf("TransformResponse() error = %v", err) + } + if got, want := len(anthropicResp.Content), 1; got != want { + t.Fatalf("len(Content) = %d, want %d", got, want) + } + if got, want := anthropicResp.Content[0].Text, "Vedo uno screenshot."; got != want { + t.Fatalf("Content[0].Text = %q, want %q", got, want) + } +} + func TestTransformResponseWithCacheTokens(t *testing.T) { transformer := NewResponseTransformer() diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go index 47e15ea..b714583 100644 --- a/internal/transformer/stream.go +++ b/internal/transformer/stream.go @@ -30,6 +30,102 @@ func NewStreamHandler() *StreamHandler { } } +func (h *StreamHandler) EmitMessageResponse(w http.ResponseWriter, resp *types.MessageResponse) error { + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("streaming not supported by response writer") + } + if resp == nil { + return fmt.Errorf("nil message response") + } + msgStart := types.MessageEvent{ + Type: "message_start", + Message: resp, + } + if err := writeSSEEvent(w, msgStart); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + + for i, block := range resp.Content { + idx := i + startBlock := block + switch block.Type { + case "text": + startBlock.Text = "" + case "thinking": + startBlock.Thinking = "" + case "tool_use": + startBlock.Input = json.RawMessage(`{}`) + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &startBlock, + }); err != nil { + return ErrClientDisconnected + } + switch block.Type { + case "text": + if block.Text != "" { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "text_delta", Text: block.Text}, + }); err != nil { + return ErrClientDisconnected + } + } + case "thinking": + if block.Thinking != "" { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "thinking_delta", Thinking: block.Thinking}, + }); err != nil { + return ErrClientDisconnected + } + } + case "tool_use": + if len(block.Input) > 0 { + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &types.Delta{Type: "input_json_delta", PartialJSON: string(block.Input)}, + }); err != nil { + return ErrClientDisconnected + } + } + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "content_block_stop", + Index: &idx, + }); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + } + + stopReason := resp.StopReason + if stopReason == "" { + stopReason = "end_turn" + } + if err := writeSSEEvent(w, types.MessageEvent{ + Type: "message_delta", + Delta: &types.Delta{ + StopReason: stopReason, + }, + Usage: &resp.Usage, + }); err != nil { + return ErrClientDisconnected + } + if err := writeSSEEvent(w, types.MessageEvent{Type: "message_stop"}); err != nil { + return ErrClientDisconnected + } + flusher.Flush() + return nil +} + // ProxyStream takes an OpenAI streaming response and writes Anthropic-format SSE to the writer. // It reads OpenAI ChatCompletionChunk SSE events and transforms them into Anthropic MessageEvent SSE events. // The clientCtx is used to detect client disconnection and abort early. @@ -389,7 +485,8 @@ func (h *StreamHandler) processSSELine( } // Handle text content deltas - if choice.Delta.Content != "" { + textDelta := contentAsString(choice.Delta.Content) + if textDelta != "" { if !*contentStarted { // If reasoning was already started, close it first if *reasoningStarted { @@ -416,7 +513,7 @@ func (h *StreamHandler) processSSELine( delta := types.Delta{ Type: "text_delta", - Text: choice.Delta.Content, + Text: textDelta, } event := types.MessageEvent{ Type: "content_block_delta", diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index a8526b9..1eabf6b 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -63,6 +63,42 @@ func parseSSEEvents(t *testing.T, raw string) []types.MessageEvent { return events } +func TestEmitMessageResponse_SynthesizesAnthropicSSE(t *testing.T) { + handler := NewStreamHandler() + w := newMockResponseWriter() + resp := &types.MessageResponse{ + ID: "msg_test", + Type: "message", + Role: "assistant", + Model: "qwen3.6-plus", + StopReason: "end_turn", + Content: []types.ContentBlock{ + {Type: "text", Text: "Vedo uno screenshot."}, + }, + Usage: types.Usage{InputTokens: 10, OutputTokens: 4}, + } + + if err := handler.EmitMessageResponse(w, resp); err != nil { + t.Fatalf("EmitMessageResponse error: %v", err) + } + events := parseSSEEvents(t, w.buf.String()) + if len(events) != 6 { + t.Fatalf("events = %d, want 6: %+v", len(events), events) + } + if events[0].Type != "message_start" { + t.Fatalf("event[0] = %s, want message_start", events[0].Type) + } + if events[2].Type != "content_block_delta" || events[2].Delta.Type != "text_delta" { + t.Fatalf("event[2] = %+v, want text_delta", events[2]) + } + if got, want := events[2].Delta.Text, "Vedo uno screenshot."; got != want { + t.Fatalf("text delta = %q, want %q", got, want) + } + if events[4].Type != "message_delta" || events[5].Type != "message_stop" { + t.Fatalf("tail events = %+v %+v, want message_delta/message_stop", events[4], events[5]) + } +} + func TestProxyStream_ReasoningContentFastPath(t *testing.T) { handler := NewStreamHandler() w := newMockResponseWriter() @@ -209,6 +245,36 @@ func TestProxyStream_TextOnlyStillWorks(t *testing.T) { } } +func TestProxyStream_ContentArrayTextDelta(t *testing.T) { + handler := NewStreamHandler() + w := newMockResponseWriter() + body := sseLines( + `{"choices":[{"delta":{"content":[{"type":"text","text":"Vedo uno screenshot."}]}}]}`, + `{"choices":[{"delta":{},"finish_reason":"stop"}]}`, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := handler.ProxyStream(w, body, "qwen3.6-plus", ctx); err != nil { + t.Fatalf("ProxyStream error: %v", err) + } + + events := parseSSEEvents(t, w.buf.String()) + if len(events) != 6 { + t.Fatalf("expected 6 events, got %d: %+v", len(events), events) + } + if events[1].Type != "content_block_start" || events[1].ContentBlock == nil || events[1].ContentBlock.Type != "text" { + t.Errorf("event[1] = %+v, want content_block_start(text)", events[1]) + } + if events[2].Type != "content_block_delta" || events[2].Delta.Type != "text_delta" { + t.Errorf("event[2] = %+v, want content_block_delta(text_delta)", events[2]) + } + if got, want := events[2].Delta.Text, "Vedo uno screenshot."; got != want { + t.Errorf("event[2].Delta.Text = %q, want %q", got, want) + } +} + func TestProxyStream_UsageOnlyChunk(t *testing.T) { handler := NewStreamHandler() w := newMockResponseWriter() diff --git a/pkg/types/anthropic.go b/pkg/types/anthropic.go index cfbe446..5758ba2 100644 --- a/pkg/types/anthropic.go +++ b/pkg/types/anthropic.go @@ -166,6 +166,7 @@ type ImageSource struct { Type string `json:"type"` MediaType string `json:"media_type"` Data string `json:"data"` + URL string `json:"url,omitempty"` } // Tool represents a tool definition for function calling. diff --git a/pkg/types/openai.go b/pkg/types/openai.go index ea23569..8463b4d 100644 --- a/pkg/types/openai.go +++ b/pkg/types/openai.go @@ -30,7 +30,7 @@ type StreamOptions struct { // ChatMessage represents a single message in the conversation. type ChatMessage struct { Role string `json:"role"` - Content string `json:"content"` + Content interface{} `json:"content"` ReasoningContent *string `json:"reasoning_content,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` Name string `json:"name,omitempty"` @@ -38,6 +38,16 @@ type ChatMessage struct { CacheControl *CacheControl `json:"cache_control,omitempty"` } +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +type ImageURL struct { + URL string `json:"url"` +} + // ToolCall represents a function call made by the model. // Index is only present in streaming deltas — it identifies which tool call // position this delta belongs to within the tool_calls array.