diff --git a/README.md b/README.md index 2dccfa7..94b7f00 100644 --- a/README.md +++ b/README.md @@ -329,6 +329,7 @@ Config file: `~/.opencodereview/config.json` | `llm.auth_token` | string | `sk-xxxxxxx` | | `llm.model` | string | `claude-opus-4-6` | | `llm.use_anthropic` | boolean | `true` \| `false` | +| `llm.use_max_completion_tokens` | boolean | `true` = use `max_completion_tokens`, `false` = use `max_tokens` (default) | | `language` | string | `English` \| `Chinese` (default: Chinese) | | `telemetry.enabled` | boolean | `true` \| `false` | | `telemetry.exporter` | string | `console` \| `otlp` | @@ -345,6 +346,7 @@ Environment variables take precedence over the config file. | `OCR_LLM_TOKEN` | API key / auth token | | `OCR_LLM_MODEL` | Model name | | `OCR_USE_ANTHROPIC` | `true` = Anthropic, `false` = OpenAI | +| `OCR_USE_MAX_COMPLETION_TOKENS` | `true` = use `max_completion_tokens` param instead of `max_tokens` | ## Telemetry diff --git a/README.zh-CN.md b/README.zh-CN.md index 04b2b3a..b4d5a42 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -319,6 +319,7 @@ OCR 通过四层优先级链解析评审规则。每层采用首次匹配原则 | `llm.auth_token` | string | `sk-xxxxxxx` | | `llm.model` | string | `claude-opus-4-6` | | `llm.use_anthropic` | boolean | `true` \| `false` | +| `llm.use_max_completion_tokens` | boolean | `true` = 使用 `max_completion_tokens`,`false` = 使用 `max_tokens`(默认) | | `language` | string | `English` \| `Chinese`(默认:Chinese) | | `telemetry.enabled` | boolean | `true` \| `false` | | `telemetry.exporter` | string | `console` \| `otlp` | @@ -335,6 +336,7 @@ OCR 通过四层优先级链解析评审规则。每层采用首次匹配原则 | `OCR_LLM_TOKEN` | API 密钥 / 认证令牌 | | `OCR_LLM_MODEL` | 模型名称 | | `OCR_USE_ANTHROPIC` | `true` = Anthropic,`false` = OpenAI | +| `OCR_USE_MAX_COMPLETION_TOKENS` | `true` = 使用 `max_completion_tokens` 参数代替 `max_tokens` | ## 遥测 diff --git a/cmd/opencodereview/config_cmd.go b/cmd/opencodereview/config_cmd.go index 14bd4ac..ea96731 100644 --- a/cmd/opencodereview/config_cmd.go +++ b/cmd/opencodereview/config_cmd.go @@ -77,11 +77,12 @@ type Config struct { } type LlmConfig struct { - URL string `json:"url,omitempty"` - AuthToken string `json:"auth_token,omitempty"` - Model string `json:"model,omitempty"` - UseAnthropic *bool `json:"use_anthropic,omitempty"` // nil = default true; false = OpenAI protocol - ExtraBody map[string]any `json:"extra_body,omitempty"` + URL string `json:"url,omitempty"` + AuthToken string `json:"auth_token,omitempty"` + Model string `json:"model,omitempty"` + UseAnthropic *bool `json:"use_anthropic,omitempty"` // nil = default true; false = OpenAI protocol + UseMaxCompletionTokens *bool `json:"use_max_completion_tokens,omitempty"` // nil = default false; true = use max_completion_tokens + ExtraBody map[string]any `json:"extra_body,omitempty"` } // TelemetryConfig holds telemetry-specific settings. @@ -137,6 +138,12 @@ func setConfigValue(cfg *Config, key, value string) error { return fmt.Errorf("invalid boolean for llm.use_anthropic: %w", err) } cfg.Llm.UseAnthropic = &b + case "llm.use_max_completion_tokens", "llm.UseMaxCompletionTokens": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid boolean for llm.use_max_completion_tokens: %w", err) + } + cfg.Llm.UseMaxCompletionTokens = &b case "language", "Language": cfg.Language = value case "telemetry.enabled", "telemetry.Enabled": @@ -166,7 +173,7 @@ func setConfigValue(cfg *Config, key, value string) error { } cfg.Llm.ExtraBody = m default: - return fmt.Errorf("unknown config key: %s\nSupported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging", key) + return fmt.Errorf("unknown config key: %s\nSupported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.use_max_completion_tokens, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging", key) } return nil } diff --git a/cmd/opencodereview/config_cmd_test.go b/cmd/opencodereview/config_cmd_test.go new file mode 100644 index 0000000..3d15795 --- /dev/null +++ b/cmd/opencodereview/config_cmd_test.go @@ -0,0 +1,146 @@ +package main + +import ( + "testing" +) + +func boolVal(b bool) *bool { return &b } + +func TestSetConfigValue_StringKeys(t *testing.T) { + tests := []struct { + key string + value string + checkFn func(*Config) string + }{ + {"llm.url", "https://api.example.com/v1", func(c *Config) string { return c.Llm.URL }}, + {"llm.URL", "https://api.example.com/v2", func(c *Config) string { return c.Llm.URL }}, + {"llm.auth_token", "sk-token-123", func(c *Config) string { return c.Llm.AuthToken }}, + {"llm.AuthToken", "sk-token-456", func(c *Config) string { return c.Llm.AuthToken }}, + {"llm.model", "gpt-4", func(c *Config) string { return c.Llm.Model }}, + {"llm.Model", "claude-opus-4-6", func(c *Config) string { return c.Llm.Model }}, + {"language", "English", func(c *Config) string { return c.Language }}, + {"Language", "Chinese", func(c *Config) string { return c.Language }}, + {"telemetry.exporter", "otlp", func(c *Config) string { return c.Telemetry.Exporter }}, + {"telemetry.Exporter", "console", func(c *Config) string { return c.Telemetry.Exporter }}, + {"telemetry.otlp_endpoint", "localhost:4317", func(c *Config) string { return c.Telemetry.OTLPEndpoint }}, + {"telemetry.OTLPEndpoint", "collector:4317", func(c *Config) string { return c.Telemetry.OTLPEndpoint }}, + } + + for _, tt := range tests { + t.Run(tt.key+"="+tt.value, func(t *testing.T) { + cfg := &Config{} + if err := setConfigValue(cfg, tt.key, tt.value); err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := tt.checkFn(cfg) + if got != tt.value { + t.Errorf("got %q, want %q", got, tt.value) + } + }) + } +} + +func TestSetConfigValue_BoolKeys(t *testing.T) { + tests := []struct { + key string + value string + want bool + checkFn func(*Config) bool + }{ + {"llm.use_anthropic", "true", true, func(c *Config) bool { return *c.Llm.UseAnthropic }}, + {"llm.UseAnthropic", "false", false, func(c *Config) bool { return *c.Llm.UseAnthropic }}, + {"llm.use_max_completion_tokens", "true", true, func(c *Config) bool { return *c.Llm.UseMaxCompletionTokens }}, + {"llm.UseMaxCompletionTokens", "false", false, func(c *Config) bool { return *c.Llm.UseMaxCompletionTokens }}, + {"telemetry.enabled", "true", true, func(c *Config) bool { return c.Telemetry.Enabled }}, + {"telemetry.Enabled", "false", false, func(c *Config) bool { return c.Telemetry.Enabled }}, + {"telemetry.content_logging", "true", true, func(c *Config) bool { return c.Telemetry.ContentLog }}, + {"telemetry.ContentLog", "false", false, func(c *Config) bool { return c.Telemetry.ContentLog }}, + } + + for _, tt := range tests { + t.Run(tt.key+"="+tt.value, func(t *testing.T) { + cfg := &Config{} + if err := setConfigValue(cfg, tt.key, tt.value); err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := tt.checkFn(cfg) + if got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetConfigValue_BoolKeys_InvalidValue(t *testing.T) { + keys := []string{ + "llm.use_anthropic", + "llm.use_max_completion_tokens", + "telemetry.enabled", + "telemetry.content_logging", + } + + for _, key := range keys { + t.Run(key, func(t *testing.T) { + cfg := &Config{} + err := setConfigValue(cfg, key, "not-a-bool") + if err == nil { + t.Fatal("expected error for invalid boolean value") + } + }) + } +} + +func TestSetConfigValue_ExtraBody(t *testing.T) { + tests := []struct { + name string + key string + value string + wantErr bool + }{ + {"valid JSON", "llm.extra_body", `{"thinking":{"type":"disabled"}}`, false}, + {"alias key", "llm.ExtraBody", `{"key":"value"}`, false}, + {"invalid JSON", "llm.extra_body", `not json`, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{} + err := setConfigValue(cfg, tt.key, tt.value) + if tt.wantErr { + if err == nil { + t.Fatal("expected error for invalid JSON") + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Llm.ExtraBody == nil { + t.Fatal("expected non-nil ExtraBody") + } + } + }) + } +} + +func TestSetConfigValue_UnknownKey(t *testing.T) { + cfg := &Config{} + err := setConfigValue(cfg, "unknown.key", "value") + if err == nil { + t.Fatal("expected error for unknown key") + } +} + +func TestSetConfigValue_TelemetryInitialized(t *testing.T) { + cfg := &Config{} + if cfg.Telemetry != nil { + t.Fatal("telemetry should be nil initially") + } + + if err := setConfigValue(cfg, "telemetry.enabled", "true"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Telemetry == nil { + t.Fatal("telemetry should be initialized after setting telemetry key") + } +} diff --git a/cmd/opencodereview/flags.go b/cmd/opencodereview/flags.go index 8a73d6d..42b5a87 100644 --- a/cmd/opencodereview/flags.go +++ b/cmd/opencodereview/flags.go @@ -254,5 +254,5 @@ Examples: ocr config set language English ocr config set telemetry.enabled true -Supported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging`) +Supported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.use_max_completion_tokens, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging`) } diff --git a/go.mod b/go.mod index 16e5cf7..77bec62 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/open-code-review/open-code-review go 1.25.0 require ( + github.com/anthropics/anthropic-sdk-go v1.46.0 github.com/bmatcuk/doublestar/v4 v4.10.0 + github.com/openai/openai-go/v3 v3.38.0 github.com/pkoukk/tiktoken-go v0.1.8 go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 @@ -17,6 +19,8 @@ require ( ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.2 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect @@ -24,14 +28,24 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/grpc v1.80.0 // indirect google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1f8f50e..ead072a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,11 @@ +github.com/anthropics/anthropic-sdk-go v1.46.0 h1:yl3n+el5ZfNgiCtQ7zQ7s/NXxB11YbrKXdc3uLPNWlU= +github.com/anthropics/anthropic-sdk-go v1.46.0/go.mod h1:bx5vWuHFuGPkELH8Z4KUiNSohFnUwScdpTyr+50myPo= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs= github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= +github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -8,6 +14,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -21,12 +29,39 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/openai/openai-go/v3 v3.38.0 h1:Kre0Fz9mPUxtWjRB/CoNBHflp9W7FkztOm/XMDr/A3E= +github.com/openai/openai-go/v3 v3.38.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.1/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= @@ -55,6 +90,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= @@ -69,5 +106,10 @@ google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ca8d676..f0b45b6 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -661,9 +661,11 @@ func (a *Agent) executePlanPhase(ctx context.Context, newPath, rawDiff, changeFi } rec.SetResponse(resp, time.Since(startTime)) if resp.Usage != nil { - atomic.AddInt64(&a.totalTokensUsed, int64(resp.Usage.TotalTokens)) - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + input := int64(resp.Usage.PromptTokens + resp.Usage.CacheReadTokens) + output := int64(resp.Usage.CompletionTokens + resp.Usage.CacheWriteTokens) + atomic.AddInt64(&a.totalTokensUsed, input+output) + atomic.AddInt64(&a.totalInputTokens, input) + atomic.AddInt64(&a.totalOutputTokens, output) } fmt.Fprintf(stdout.Writer(), "[ocr] Plan completed for %s\n", newPath) return resp.Content(), nil @@ -740,11 +742,13 @@ func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message } rec.SetResponse(resp, duration) // Record LLM metrics with token info from API response usage field. - totalTokens := int64(0) + var totalTokens int64 if resp.Usage != nil { - totalTokens = resp.Usage.TotalTokens - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + input := int64(resp.Usage.PromptTokens + resp.Usage.CacheReadTokens) + output := int64(resp.Usage.CompletionTokens + resp.Usage.CacheWriteTokens) + totalTokens = input + output + atomic.AddInt64(&a.totalInputTokens, input) + atomic.AddInt64(&a.totalOutputTokens, output) } telemetry.RecordLLMRequest(ctx, a.args.Model, duration, totalTokens, "ok") atomic.AddInt64(&a.totalTokensUsed, totalTokens) diff --git a/internal/diff/relocation_test.go b/internal/diff/relocation_test.go index 8b3a81b..8a2d706 100644 --- a/internal/diff/relocation_test.go +++ b/internal/diff/relocation_test.go @@ -24,7 +24,7 @@ func (m *mockLLMClient) CompletionsWithCtx(_ context.Context, req llm.ChatReques return m.response, m.err } -func (m *mockLLMClient) StreamCompletion(req llm.ChatRequest, cb func(chunk []byte) error) error { +func (m *mockLLMClient) StreamCompletion(_ context.Context, req llm.ChatRequest, cb func(chunk []byte) error) error { return m.err } diff --git a/internal/llm/client.go b/internal/llm/client.go index 4948ffa..f8fc34b 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -1,20 +1,21 @@ // Package llm provides LLM client interfaces supporting multiple protocols. // Supported protocols: Anthropic Messages API, OpenAI Chat Completions API. +// Implementations use the official SDKs: github.com/openai/openai-go and github.com/anthropics/anthropic-sdk-go. package llm import ( - "bufio" - "bytes" "context" "encoding/json" "fmt" - "io" - "math/rand" "net/http" "strings" "sync" "time" + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + "github.com/openai/openai-go/v3" + openaioption "github.com/openai/openai-go/v3/option" tiktoken "github.com/pkoukk/tiktoken-go" "github.com/open-code-review/open-code-review/internal/stdout" @@ -36,7 +37,7 @@ func userAgent(provider string) string { type LLMClient interface { Completions(req ChatRequest) (*ChatResponse, error) CompletionsWithCtx(ctx context.Context, req ChatRequest) (*ChatResponse, error) - StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error + StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error } // --- Shared data types --- @@ -186,11 +187,12 @@ type FunctionDef struct { // ClientConfig holds configuration for connecting to an LLM service. type ClientConfig struct { - URL string // Full API endpoint URL - APIKey string // Bearer token / API key - Model string // Default model override - Timeout time.Duration // Request timeout - ExtraBody map[string]any // Vendor-specific fields merged into every request body + URL string // Full API endpoint URL + APIKey string // Bearer token / API key + Model string // Default model override + Timeout time.Duration // Request timeout + ExtraBody map[string]any // Vendor-specific fields merged into every request body + UseMaxCompletionTokens bool // use max_completion_tokens instead of max_tokens } // --- Factory --- @@ -199,10 +201,11 @@ type ClientConfig struct { // protocol: "anthropic" -> AnthropicClient, anything else -> OpenAIClient. func NewLLMClient(ep ResolvedEndpoint) LLMClient { cfg := ClientConfig{ - URL: ep.URL, - APIKey: ep.Token, - Model: ep.Model, - ExtraBody: ep.ExtraBody, + URL: ep.URL, + APIKey: ep.Token, + Model: ep.Model, + ExtraBody: ep.ExtraBody, + UseMaxCompletionTokens: ep.UseMaxCompletionTokens, } if ep.Protocol == "anthropic" { return NewAnthropicClient(cfg) @@ -276,29 +279,43 @@ func encodingForModel(modelName string) string { } } +// ChatRequest represents the payload for a chat completion call. +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Tools []ToolDef `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` +} + // --- OpenAIClient --- -// OpenAIClient sends requests to an OpenAI-compatible chat completion API. +// OpenAIClient sends requests to an OpenAI-compatible chat completion API via the official SDK. type OpenAIClient struct { cfg ClientConfig - client *http.Client + client openai.Client } -// NewOpenAIClient creates a new OpenAI-compatible LLM client. +// NewOpenAIClient creates a new OpenAI-compatible LLM client using the official SDK. func NewOpenAIClient(cfg ClientConfig) *OpenAIClient { if cfg.Timeout <= 0 { cfg.Timeout = 5 * time.Minute } - baseURL := strings.TrimRight(cfg.URL, "/") - if !strings.HasSuffix(baseURL, "/chat/completions") { - cfg.URL = baseURL + "/chat/completions" - } - return &OpenAIClient{ - cfg: cfg, - client: &http.Client{ - Timeout: cfg.Timeout, - }, + + // Normalize URL: strip /chat/completions suffix since SDK appends it automatically. + baseURL := normalizeOpenAIBaseURL(cfg.URL) + + opts := []openaioption.RequestOption{ + openaioption.WithAPIKey(cfg.APIKey), + openaioption.WithBaseURL(baseURL), + openaioption.WithHeader("User-Agent", userAgent("")), + openaioption.WithMaxRetries(maxRetries), + openaioption.WithRequestTimeout(cfg.Timeout), } + + client := openai.NewClient(opts...) + return &OpenAIClient{cfg: cfg, client: client} } // NewClient is kept as an alias for backward compatibility during transition. @@ -306,16 +323,6 @@ func NewClient(cfg ClientConfig) *OpenAIClient { return NewOpenAIClient(cfg) } -// ChatRequest represents the payload for a chat completion call. -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Tools []ToolDef `json:"tools,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` -} - // Completions sends a chat completion request and returns the parsed response. func (c *OpenAIClient) Completions(req ChatRequest) (*ChatResponse, error) { return c.CompletionsWithCtx(context.Background(), req) @@ -328,16 +335,19 @@ func (c *OpenAIClient) CompletionsWithCtx(ctx context.Context, req ChatRequest) model = c.cfg.Model } - var result *ChatResponse - err := c.withRetryCtx(ctx, func() error { - resp, err := c.doRequestCtx(ctx, model, req) - if err != nil { - return err - } - result = resp - return nil - }) - return result, err + params := c.buildParams(model, req) + reqOpts := c.buildExtraBodyOptions() + + // Capture raw HTTP response for headers + var httpResp *http.Response + reqOpts = append(reqOpts, openaioption.WithResponseInto(&httpResp)) + + completion, err := c.client.Chat.Completions.New(ctx, params, reqOpts...) + if err != nil { + return nil, fmt.Errorf("API error: %w", err) + } + + return c.convertResponse(completion, httpResp), nil } // GeneralRequest sends a simple chat request without or with optional tool calls. @@ -355,150 +365,229 @@ func (c *OpenAIClient) GeneralRequestWithCtx(ctx context.Context, messages []Mes } // StreamCompletion initiates a streaming chat completion. The callback is invoked per chunk. -func (c *OpenAIClient) StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error { - req.Stream = true - +func (c *OpenAIClient) StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error { model := req.Model if model == "" { model = c.cfg.Model } - return c.withRetry(func() error { - body := make(map[string]any) - b, _ := json.Marshal(req) - json.Unmarshal(b, &body) - body["model"] = model - for k, v := range c.cfg.ExtraBody { - body[k] = v - } + params := c.buildParams(model, req) + // Enable usage reporting in stream so the final chunk contains token stats. + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + } + reqOpts := c.buildExtraBodyOptions() - payload, _ := json.Marshal(body) - httpReq, err := http.NewRequest(http.MethodPost, c.cfg.URL, bytes.NewReader(payload)) - if err != nil { - return fmt.Errorf("create request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey) - httpReq.Header.Set("Accept", "text/event-stream") - httpReq.Header.Set("User-Agent", userAgent("")) - - resp, err := c.client.Do(httpReq) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() + stream := c.client.Chat.Completions.NewStreaming(ctx, params, reqOpts...) + defer stream.Close() - if isRetryableStatus(resp.StatusCode) { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + for stream.Next() { + chunk := stream.Current() + // Use RawJSON to preserve the original API response format for downstream consumers. + raw := chunk.RawJSON() + if raw == "" { + continue } - if resp.StatusCode >= 400 { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("API error %d: %s (non-retryable)", resp.StatusCode, string(bodyBytes)) + if err := cb([]byte(raw)); err != nil { + return err } + } + return stream.Err() +} - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") { - continue - } - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - if err := cb([]byte(data)); err != nil { - return err - } +// buildParams converts internal ChatRequest to OpenAI SDK params. +func (c *OpenAIClient) buildParams(model string, req ChatRequest) openai.ChatCompletionNewParams { + params := openai.ChatCompletionNewParams{ + Model: model, + Messages: convertMessagesToOpenAI(req.Messages), + } + + if len(req.Tools) > 0 { + params.Tools = convertToolsToOpenAI(req.Tools) + } + + if req.MaxTokens > 0 { + if c.cfg.UseMaxCompletionTokens { + params.MaxCompletionTokens = openai.Int(int64(req.MaxTokens)) + } else { + params.MaxTokens = openai.Int(int64(req.MaxTokens)) } - return scanner.Err() - }) -} + } -// doRequest builds and sends a non-streaming completion request, returning the parsed response. -func (c *OpenAIClient) doRequest(model string, req ChatRequest) (*ChatResponse, error) { - return c.doRequestCtx(context.Background(), model, req) + if req.Temperature != nil { + params.Temperature = openai.Float(*req.Temperature) + } + + return params } -// doRequestCtx builds and sends a non-streaming completion request with context support. -func (c *OpenAIClient) doRequestCtx(ctx context.Context, model string, req ChatRequest) (*ChatResponse, error) { - if model == "" { - model = c.cfg.Model +// buildExtraBodyOptions creates request options for ExtraBody fields. +func (c *OpenAIClient) buildExtraBodyOptions() []openaioption.RequestOption { + if len(c.cfg.ExtraBody) == 0 { + return nil } - req.Model = model - payload, err := mergeExtraBody(req, c.cfg.ExtraBody) - if err != nil { - return nil, fmt.Errorf("marshal request body: %w", err) + opts := make([]openaioption.RequestOption, 0, len(c.cfg.ExtraBody)) + for k, v := range c.cfg.ExtraBody { + opts = append(opts, openaioption.WithJSONSet(k, v)) } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.URL, bytes.NewReader(payload)) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) + return opts +} + +// convertResponse maps the SDK ChatCompletion to our internal ChatResponse. +func (c *OpenAIClient) convertResponse(comp *openai.ChatCompletion, httpResp *http.Response) *ChatResponse { + if comp == nil { + return &ChatResponse{} } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey) - httpReq.Header.Set("User-Agent", userAgent("")) + choices := make([]Choice, 0, len(comp.Choices)) + for _, ch := range comp.Choices { + var content *string + if ch.Message.Content != "" { + s := ch.Message.Content + content = &s + } - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + toolCalls := make([]ToolCall, 0, len(ch.Message.ToolCalls)) + for _, tc := range ch.Message.ToolCalls { + toolCalls = append(toolCalls, ToolCall{ + ID: tc.ID, + Type: "function", + Function: FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + choices = append(choices, Choice{ + Message: ResponseMessage{ + Role: string(ch.Message.Role), + Content: content, + ToolCalls: toolCalls, + }, + FinishReason: string(ch.FinishReason), + }) } - defer resp.Body.Close() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response body: %w", err) + var usage *UsageInfo + if comp.Usage.TotalTokens > 0 || comp.Usage.PromptTokens > 0 || comp.Usage.CompletionTokens > 0 { + cachedTokens := comp.Usage.PromptTokensDetails.CachedTokens + usage = &UsageInfo{ + TotalTokens: comp.Usage.TotalTokens, + PromptTokens: comp.Usage.PromptTokens - cachedTokens, + CompletionTokens: comp.Usage.CompletionTokens, + CacheReadTokens: cachedTokens, + } + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.CacheReadTokens + } } - if resp.StatusCode >= 400 { - detail := extractErrorMessage(bodyBytes) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, detail) + var headers http.Header + if httpResp != nil { + headers = httpResp.Header } - var apiResp struct { - ID string `json:"id"` - Model string `json:"model"` - Choices []Choice `json:"choices"` + return &ChatResponse{ + ID: comp.ID, + Model: comp.Model, + Choices: choices, + Headers: headers, + Usage: usage, } - if err := json.Unmarshal(bodyBytes, &apiResp); err != nil { - return nil, fmt.Errorf("decode response: %w", err) +} + +// convertMessagesToOpenAI maps internal Message slice to OpenAI SDK message params. +func convertMessagesToOpenAI(messages []Message) []openai.ChatCompletionMessageParamUnion { + result := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) + for _, msg := range messages { + switch msg.Role { + case "system": + result = append(result, openai.SystemMessage(msg.ExtractText())) + case "user": + result = append(result, openai.UserMessage(msg.ExtractText())) + case "assistant": + if len(msg.ToolCalls) > 0 { + toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }, + }) + } + text := msg.ExtractText() + param := openai.ChatCompletionAssistantMessageParam{ + ToolCalls: toolCalls, + } + if text != "" { + param.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: openai.String(text), + } + } + result = append(result, openai.ChatCompletionMessageParamUnion{ + OfAssistant: ¶m, + }) + } else { + result = append(result, openai.AssistantMessage(msg.ExtractText())) + } + case "tool": + result = append(result, openai.ToolMessage(msg.ExtractText(), msg.ToolCallID)) + } } + return result +} - return &ChatResponse{ - ID: apiResp.ID, - Model: apiResp.Model, - Choices: apiResp.Choices, - Headers: resp.Header, - Usage: resolveUsage(bodyBytes), - }, nil +// convertToolsToOpenAI maps internal ToolDef slice to OpenAI SDK tool params. +func convertToolsToOpenAI(tools []ToolDef) []openai.ChatCompletionToolUnionParam { + result := make([]openai.ChatCompletionToolUnionParam, 0, len(tools)) + for _, t := range tools { + result = append(result, openai.ChatCompletionFunctionTool(openai.FunctionDefinitionParam{ + Name: t.Function.Name, + Description: openai.String(t.Function.Description), + Parameters: openai.FunctionParameters(t.Function.Parameters), + })) + } + return result } -// --- AnthropicClient --- +// normalizeOpenAIBaseURL strips the /chat/completions suffix since the SDK appends it. +func normalizeOpenAIBaseURL(rawURL string) string { + u := strings.TrimRight(rawURL, "/") + u = strings.TrimSuffix(u, "/chat/completions") + return u +} -const anthropicVersion = "2023-06-01" +// --- AnthropicClient --- -// AnthropicClient implements the Anthropic Messages API. +// AnthropicClient implements the Anthropic Messages API via the official SDK. type AnthropicClient struct { cfg ClientConfig - client *http.Client + client anthropic.Client } -// NewAnthropicClient creates a new Anthropic Messages API client. +// NewAnthropicClient creates a new Anthropic Messages API client using the official SDK. func NewAnthropicClient(cfg ClientConfig) *AnthropicClient { if cfg.Timeout <= 0 { cfg.Timeout = 5 * time.Minute } - if !strings.HasSuffix(cfg.URL, "/v1/messages") && !strings.HasSuffix(cfg.URL, "/v1/messages/") { - baseURL := strings.TrimRight(cfg.URL, "/") - if !strings.HasSuffix(baseURL, "/v1/messages") { - cfg.URL = baseURL + "/v1/messages" - } - } - return &AnthropicClient{ - cfg: cfg, - client: &http.Client{ - Timeout: cfg.Timeout, - }, + + // Normalize URL: strip /v1/messages suffix since SDK appends it automatically. + baseURL := normalizeAnthropicBaseURL(cfg.URL) + + opts := []anthropicoption.RequestOption{ + anthropicoption.WithAPIKey(cfg.APIKey), + anthropicoption.WithBaseURL(baseURL), + anthropicoption.WithHeader("User-Agent", userAgent("claude")), + anthropicoption.WithMaxRetries(maxRetries), + anthropicoption.WithRequestTimeout(cfg.Timeout), } + + client := anthropic.NewClient(opts...) + return &AnthropicClient{cfg: cfg, client: client} } // Completions sends a chat completion request and returns the parsed response. @@ -513,343 +602,108 @@ func (c *AnthropicClient) CompletionsWithCtx(ctx context.Context, req ChatReques model = c.cfg.Model } - var result *ChatResponse - err := c.withRetryCtx(ctx, func() error { - resp, err := c.doRequestCtx(ctx, model, req) - if err != nil { - return err - } - result = resp - return nil - }) - return result, err -} + params := c.buildParams(model, req) + reqOpts := c.buildExtraBodyOptions() -// StreamCompletion initiates a streaming chat completion using SSE. The callback -// is invoked per chunk with raw JSON data stripped of the "data: " prefix. -func (c *AnthropicClient) StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error { - req.Stream = true + // Capture raw HTTP response for headers + var httpResp *http.Response + reqOpts = append(reqOpts, anthropicoption.WithResponseInto(&httpResp)) - model := req.Model - if model == "" { - model = c.cfg.Model + message, err := c.client.Messages.New(ctx, params, reqOpts...) + if err != nil { + return nil, fmt.Errorf("API error: %w", err) } - return c.withRetry(func() error { - body := c.buildRequestBody(model, req) - body.Stream = true - - payload, err := mergeExtraBody(body, c.cfg.ExtraBody) - if err != nil { - return fmt.Errorf("marshal request body: %w", err) - } - - httpReq, err := http.NewRequest(http.MethodPost, c.cfg.URL, bytes.NewReader(payload)) - if err != nil { - return fmt.Errorf("create request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("x-api-key", c.cfg.APIKey) - httpReq.Header.Set("anthropic-version", anthropicVersion) - httpReq.Header.Set("User-Agent", userAgent("claude")) - - resp, err := c.client.Do(httpReq) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - if isRetryableStatus(resp.StatusCode) { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) - } - if resp.StatusCode >= 400 { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("API error %d: %s (non-retryable)", resp.StatusCode, string(bodyBytes)) - } - - scanner := bufio.NewScanner(resp.Body) - var eventType string - - for scanner.Scan() { - line := scanner.Text() - - // Capture event type line: "event: message_delta" - if strings.HasPrefix(line, "event: ") { - eventType = strings.TrimPrefix(line, "event: ") - continue - } - - // Skip empty lines and non-data lines - if !strings.HasPrefix(line, "data: ") { - continue - } - - data := strings.TrimPrefix(line, "data: ") - if data == "" { - continue - } - - // message_stop signals end of stream - if eventType == "message_stop" { - break - } - - if err := cb([]byte(data)); err != nil { - return err - } - } - return scanner.Err() - }) -} - -// anthropicRequest is the request body for Anthropic Messages API. -type anthropicRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []anthroMessage `json:"messages"` - Tools []anthroTool `json:"tools,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` -} - -type anthroMessage struct { - Role string `json:"role"` - Content any `json:"content"` // string or []interface{} -} - -// anthropicToolUseBlock represents a tool_use content block in Anthropic's Messages API. -type anthropicToolUseBlock struct { - Type string `json:"type"` // "tool_use" - ID string `json:"id"` // tool use ID - Name string `json:"name"` // function name - Input map[string]any `json:"input"` // function arguments (parsed as object) -} - -type anthroTool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]any `json:"input_schema"` + return c.convertResponse(message, httpResp), nil } -// doRequestCtx builds and sends an Anthropic Messages API request. -func (c *AnthropicClient) doRequestCtx(ctx context.Context, model string, req ChatRequest) (*ChatResponse, error) { +// StreamCompletion initiates a streaming chat completion using SSE. +func (c *AnthropicClient) StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error { + model := req.Model if model == "" { model = c.cfg.Model } - body := c.buildRequestBody(model, req) - payload, err := mergeExtraBody(body, c.cfg.ExtraBody) - if err != nil { - return nil, fmt.Errorf("marshal request body: %w", err) - } + params := c.buildParams(model, req) + reqOpts := c.buildExtraBodyOptions() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.URL, bytes.NewReader(payload)) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("x-api-key", c.cfg.APIKey) - httpReq.Header.Set("anthropic-version", anthropicVersion) - httpReq.Header.Set("User-Agent", userAgent("claude")) + stream := c.client.Messages.NewStreaming(ctx, params, reqOpts...) + defer stream.Close() - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response body: %w", err) - } - - if resp.StatusCode >= 400 { - detail := extractErrorMessage(bodyBytes) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, detail) - } - - chatResp, err := c.parseResponse(bodyBytes, resp.Header) - if err != nil { - return nil, fmt.Errorf("decode response: %w", err) + for stream.Next() { + event := stream.Current() + // Use RawJSON to preserve the original API response format for downstream consumers. + raw := event.RawJSON() + if raw == "" { + continue + } + if err := cb([]byte(raw)); err != nil { + return err + } } - return chatResp, nil + return stream.Err() } -// buildRequestBody converts the shared ChatRequest into Anthropic format. -func (c *AnthropicClient) buildRequestBody(model string, req ChatRequest) anthropicRequest { - messages := make([]anthroMessage, 0, len(req.Messages)) - var systemMsg string +// buildParams converts internal ChatRequest to Anthropic SDK params. +func (c *AnthropicClient) buildParams(model string, req ChatRequest) anthropic.MessageNewParams { + system, messages := convertMessagesToAnthropic(req.Messages) - var pendingToolResults []Message // collect consecutive tool messages + maxTokens := int64(req.MaxTokens) + if maxTokens <= 0 { + maxTokens = 8192 + } - flushToolResults := func() { - if len(pendingToolResults) == 0 { - return - } - // Merge all pending tool results into a single user message - var blocks []interface{} - for _, tr := range pendingToolResults { - blocks = append(blocks, ContentBlock{ - Type: "tool_result", - ToolUseID: tr.ToolCallID, - Content: []ContentBlock{{ - Type: "text", - Text: fmt.Sprintf("%v", tr.Content), - }}, - }) - } - messages = append(messages, anthroMessage{Role: "user", Content: blocks}) - pendingToolResults = nil + params := anthropic.MessageNewParams{ + Model: model, + MaxTokens: maxTokens, + Messages: messages, } - for _, msg := range req.Messages { - switch msg.Role { - case "system": - if s, ok := msg.Content.(string); ok { - systemMsg = s - } - flushToolResults() - case "tool": - pendingToolResults = append(pendingToolResults, msg) - case "assistant": - flushToolResults() - // Build Anthropic content blocks from text + tool calls - var blocks []interface{} - if s, ok := msg.Content.(string); ok && s != "" { - blocks = append(blocks, ContentBlock{Type: "text", Text: s}) - } - for _, tc := range msg.ToolCalls { - argsMap := map[string]any{} - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &argsMap); err != nil { - fmt.Fprintf(stdout.Writer(), "[llm] WARNING: failed to parse tool call arguments JSON for %q: %v\n", tc.ID, err) - } - } - blocks = append(blocks, anthropicToolUseBlock{ - Type: "tool_use", - ID: tc.ID, - Name: tc.Function.Name, - Input: argsMap, - }) - } - if len(blocks) > 0 { - messages = append(messages, anthroMessage{Role: "assistant", Content: blocks}) - } else { - s, _ := msg.Content.(string) - messages = append(messages, anthroMessage{Role: "assistant", Content: s}) - } - default: - // user or other roles: flush tool results first - flushToolResults() - content := msg.Content - if blkArr, ok := content.([]ContentBlock); ok { - converted := make([]ContentBlock, len(blkArr)) - for i, b := range blkArr { - converted[i] = ContentBlock{ - Type: b.Type, - Text: b.Text, - ToolUseID: b.ToolUseID, - Content: b.Content, - } - } - content = converted - } - messages = append(messages, anthroMessage{Role: msg.Role, Content: content}) + if system != "" { + params.System = []anthropic.TextBlockParam{ + {Text: system}, } } - flushToolResults() // flush any remaining tool results at the end - tools := make([]anthroTool, 0, len(req.Tools)) - for _, t := range req.Tools { - tools = append(tools, anthroTool{ - Name: t.Function.Name, - Description: t.Function.Description, - InputSchema: t.Function.Parameters, - }) + if len(req.Tools) > 0 { + params.Tools = convertToolsToAnthropic(req.Tools) } - maxTokens := req.MaxTokens - if maxTokens <= 0 { - maxTokens = 8192 // Anthropic default + if req.Temperature != nil { + params.Temperature = anthropic.Float(*req.Temperature) } - return anthropicRequest{ - Model: model, - MaxTokens: maxTokens, - System: systemMsg, - Messages: messages, - Tools: tools, - Stream: false, - Temperature: req.Temperature, - } + return params } -func mergeExtraBody(base any, extraBody map[string]any) ([]byte, error) { - if len(extraBody) == 0 { - return json.Marshal(base) - } - b, err := json.Marshal(base) - if err != nil { - return nil, err - } - var m map[string]any - if err := json.Unmarshal(b, &m); err != nil { - return nil, err +// buildExtraBodyOptions creates request options for ExtraBody fields. +func (c *AnthropicClient) buildExtraBodyOptions() []anthropicoption.RequestOption { + if len(c.cfg.ExtraBody) == 0 { + return nil } - for k, v := range extraBody { - m[k] = v + opts := make([]anthropicoption.RequestOption, 0, len(c.cfg.ExtraBody)) + for k, v := range c.cfg.ExtraBody { + opts = append(opts, anthropicoption.WithJSONSet(k, v)) } - return json.Marshal(m) + return opts } -// parseResponse converts Anthropic JSON response into ChatResponse. -func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*ChatResponse, error) { - type contentBlockResp struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` - } - - type anthropicUsageRaw struct { - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - CacheReadInputTokens int64 `json:"cache_read_input_tokens"` - CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` - } - - var resp struct { - ID string `json:"id"` - Model string `json:"model"` - Type string `json:"type"` - Role string `json:"role"` - Content []contentBlockResp `json:"content"` - Usage anthropicUsageRaw `json:"usage"` - StopReason string `json:"stop_reason,omitempty"` - } - - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - - // Build the response message from content blocks. +// convertResponse maps Anthropic SDK Message to our internal ChatResponse. +func (c *AnthropicClient) convertResponse(msg *anthropic.Message, httpResp *http.Response) *ChatResponse { var textParts []string var toolCalls []ToolCall - for _, block := range resp.Content { - switch block.Type { - case "text": - textParts = append(textParts, block.Text) - case "tool_use": - argsJSON, _ := json.Marshal(block.Input) + for _, block := range msg.Content { + switch variant := block.AsAny().(type) { + case anthropic.TextBlock: + textParts = append(textParts, variant.Text) + case anthropic.ToolUseBlock: + argsJSON, _ := json.Marshal(variant.Input) toolCalls = append(toolCalls, ToolCall{ - ID: block.ID, + ID: variant.ID, Type: "function", Function: FunctionCall{ - Name: block.Name, + Name: variant.Name, Arguments: string(argsJSON), }, }) @@ -862,28 +716,33 @@ func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*Chat contentStr = &s } - finishReason := resp.StopReason + finishReason := string(msg.StopReason) if finishReason == "" { finishReason = "stop" } var usage *UsageInfo - if u := resp.Usage; u.InputTokens > 0 || u.OutputTokens > 0 { + if msg.Usage.InputTokens > 0 || msg.Usage.OutputTokens > 0 || msg.Usage.CacheReadInputTokens > 0 || msg.Usage.CacheCreationInputTokens > 0 { usage = &UsageInfo{ - PromptTokens: u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens, - CompletionTokens: u.OutputTokens, - CacheReadTokens: u.CacheReadInputTokens, - CacheWriteTokens: u.CacheCreationInputTokens, + PromptTokens: msg.Usage.InputTokens, + CompletionTokens: msg.Usage.OutputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, + CacheWriteTokens: msg.Usage.CacheCreationInputTokens, } - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.CacheReadTokens + usage.CacheWriteTokens + } + + var headers http.Header + if httpResp != nil { + headers = httpResp.Header } return &ChatResponse{ - ID: resp.ID, - Model: resp.Model, + ID: msg.ID, + Model: string(msg.Model), Choices: []Choice{{ Message: ResponseMessage{ - Role: resp.Role, + Role: string(msg.Role), Content: contentStr, ToolCalls: toolCalls, }, @@ -891,104 +750,119 @@ func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*Chat }}, Headers: headers, Usage: usage, - }, nil + } } -// --- Retry logic --- - -func retryWithCtx(ctx context.Context, fn func() error) error { - var lastErr error - for attempt := 0; attempt <= maxRetries; attempt++ { - select { - case <-ctx.Done(): - return fmt.Errorf("context cancelled: %w", ctx.Err()) - default: - } +// convertMessagesToAnthropic separates system message and converts remaining messages. +func convertMessagesToAnthropic(messages []Message) (string, []anthropic.MessageParam) { + var systemMsg string + var result []anthropic.MessageParam + var pendingToolResults []Message - lastErr = fn() - if lastErr == nil { - return nil + flushToolResults := func() { + if len(pendingToolResults) == 0 { + return } - - if !isRetryable(lastErr) { - return lastErr + var blocks []anthropic.ContentBlockParamUnion + for _, tr := range pendingToolResults { + blocks = append(blocks, anthropic.NewToolResultBlock( + tr.ToolCallID, + tr.ExtractText(), + false, + )) } + result = append(result, anthropic.NewUserMessage(blocks...)) + pendingToolResults = nil + } - if attempt < maxRetries { - sleepWithBackoff(attempt) + for _, msg := range messages { + switch msg.Role { + case "system": + if s, ok := msg.Content.(string); ok { + systemMsg = s + } + flushToolResults() + case "tool": + pendingToolResults = append(pendingToolResults, msg) + case "assistant": + flushToolResults() + var blocks []anthropic.ContentBlockParamUnion + text := msg.ExtractText() + if text != "" { + blocks = append(blocks, anthropic.NewTextBlock(text)) + } + for _, tc := range msg.ToolCalls { + argsMap := map[string]any{} + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &argsMap); err != nil { + fmt.Fprintf(stdout.Writer(), "[llm] WARNING: failed to parse tool call arguments JSON for %q: %v\n", tc.ID, err) + } + } + inputJSON, _ := json.Marshal(argsMap) + blocks = append(blocks, anthropic.ContentBlockParamUnion{ + OfToolUse: &anthropic.ToolUseBlockParam{ + ID: tc.ID, + Name: tc.Function.Name, + Input: json.RawMessage(inputJSON), + }, + }) + } + if len(blocks) > 0 { + result = append(result, anthropic.NewAssistantMessage(blocks...)) + } + default: + // user or other roles + flushToolResults() + result = append(result, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.ExtractText()))) } } - return fmt.Errorf("request failed after %d retries: %w", maxRetries, lastErr) -} - -func (c *OpenAIClient) withRetry(fn func() error) error { - return retryWithCtx(context.Background(), fn) -} + flushToolResults() -func (c *OpenAIClient) withRetryCtx(ctx context.Context, fn func() error) error { - return retryWithCtx(ctx, fn) + return systemMsg, result } -func (c *AnthropicClient) withRetry(fn func() error) error { - return retryWithCtx(context.Background(), fn) -} - -func (c *AnthropicClient) withRetryCtx(ctx context.Context, fn func() error) error { - return retryWithCtx(ctx, fn) -} - -// isRetryable determines whether an error is transient and worth retrying. -func isRetryable(err error) bool { - msg := err.Error() - // 429 (rate limit) and 5xx server errors are retryable. - if strings.Contains(msg, "API error 429:") { - return true - } - for code := 500; code <= 599; code++ { - if strings.Contains(msg, fmt.Sprintf("API error %d:", code)) { - return true +// convertToolsToAnthropic maps internal ToolDef slice to Anthropic SDK tool params. +func convertToolsToAnthropic(tools []ToolDef) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + schema := anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: t.Function.Parameters["properties"], } + // Preserve required field constraints from the original schema. + if req, ok := t.Function.Parameters["required"]; ok { + if reqSlice, ok := req.([]any); ok { + required := make([]string, 0, len(reqSlice)) + for _, r := range reqSlice { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + schema.Required = required + } + } + result = append(result, anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: t.Function.Name, + Description: anthropic.String(t.Function.Description), + InputSchema: schema, + }, + }) } - // Network-level errors (timeout, connection refused, DNS failure, etc.) are retryable. - if strings.Contains(msg, "request failed:") || - strings.Contains(msg, "connection refused") || - strings.Contains(msg, "no such host") || - strings.Contains(msg, "i/o timeout") || - strings.Contains(msg, "EOF") { - return true - } - return false + return result } -// isRetryableStatus returns true for HTTP status codes that should trigger a retry. -func isRetryableStatus(status int) bool { - return status == 429 || (status >= 500 && status <= 599) +// normalizeAnthropicBaseURL strips the /v1/messages suffix since the SDK appends it. +func normalizeAnthropicBaseURL(rawURL string) string { + u := strings.TrimRight(rawURL, "/") + u = strings.TrimSuffix(u, "/v1/messages") + return u } -// sleepWithBackoff sleeps for baseDelay * 2^attempt + jitter, capped at 60s. -// Jitter spreads retries randomly within ±50% of the computed delay. -func sleepWithBackoff(attempt int) { - const ( - baseDelay = 1 * time.Second - maxDelay = 60 * time.Second - ) - - delay := baseDelay << uint(min(attempt, 6)) // 1s, 2s, 4s, 8s, 16s, 32s, 64s→capped - if delay > maxDelay { - delay = maxDelay - } - - // Add random jitter: [delay*0.5, delay*1.5] - jitter := time.Duration(rand.Int63n(int64(delay))) - delay/2 - delay += jitter - - fmt.Fprintf(stdout.Writer(), "[llm] Retrying in %v (attempt info)... \n", delay) - time.Sleep(delay) -} +// --- Utility functions --- // stripThinkTags removes reasoning wrapper tags from content. func stripThinkTags(s string) string { - // Construct tag strings from individual bytes. openBytes := []byte{0x3c, 't', 'h', 'i', 'n', 'k', 0x3e} closeBytes := []byte{0x3c, 0x2f, 't', 'h', 'i', 'n', 'k', 0x3e} s = strings.ReplaceAll(s, string(openBytes), "") @@ -1025,7 +899,6 @@ func extractErrorMessage(body []byte) string { return ae.Error.Message } - // Truncate raw body to avoid excessively noisy errors. bodyText := string(body) if len(bodyText) > 512 { bodyText = bodyText[:512] + "... (truncated)" diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go index e3adc81..7ea9363 100644 --- a/internal/llm/client_test.go +++ b/internal/llm/client_test.go @@ -1,10 +1,14 @@ package llm import ( + "encoding/json" + "strings" "testing" ) -func TestNewOpenAIClient_URLNormalization(t *testing.T) { +// TestNormalizeOpenAIBaseURL verifies URL normalization for the OpenAI SDK. +// The SDK appends /chat/completions automatically, so we strip that suffix. +func TestNormalizeOpenAIBaseURL(t *testing.T) { tests := []struct { name string inputURL string @@ -13,41 +17,43 @@ func TestNewOpenAIClient_URLNormalization(t *testing.T) { { name: "base URL without trailing slash", inputURL: "https://api.example.com/v1", - wantURL: "https://api.example.com/v1/chat/completions", + wantURL: "https://api.example.com/v1", }, { name: "base URL with trailing slash", inputURL: "https://api.example.com/v1/", - wantURL: "https://api.example.com/v1/chat/completions", + wantURL: "https://api.example.com/v1", }, { - name: "full URL already has chat/completions", + name: "full URL with chat/completions suffix stripped", inputURL: "https://api.example.com/v1/chat/completions", - wantURL: "https://api.example.com/v1/chat/completions", + wantURL: "https://api.example.com/v1", }, { - name: "full URL with trailing slash", + name: "full URL with trailing slash stripped", inputURL: "https://api.example.com/v1/chat/completions/", - wantURL: "https://api.example.com/v1/chat/completions/", + wantURL: "https://api.example.com/v1", }, { - name: "bare host", + name: "bare host unchanged", inputURL: "https://api.example.com", - wantURL: "https://api.example.com/chat/completions", + wantURL: "https://api.example.com", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := NewOpenAIClient(ClientConfig{URL: tt.inputURL}) - if client.cfg.URL != tt.wantURL { - t.Errorf("got URL %q, want %q", client.cfg.URL, tt.wantURL) + got := normalizeOpenAIBaseURL(tt.inputURL) + if got != tt.wantURL { + t.Errorf("normalizeOpenAIBaseURL(%q) = %q, want %q", tt.inputURL, got, tt.wantURL) } }) } } -func TestNewAnthropicClient_URLNormalization(t *testing.T) { +// TestNormalizeAnthropicBaseURL verifies URL normalization for the Anthropic SDK. +// The SDK appends /v1/messages automatically, so we strip that suffix. +func TestNormalizeAnthropicBaseURL(t *testing.T) { tests := []struct { name string inputURL string @@ -56,36 +62,508 @@ func TestNewAnthropicClient_URLNormalization(t *testing.T) { { name: "bare host", inputURL: "https://api.anthropic.com", - wantURL: "https://api.anthropic.com/v1/messages", + wantURL: "https://api.anthropic.com", }, { name: "bare host with trailing slash", inputURL: "https://api.anthropic.com/", - wantURL: "https://api.anthropic.com/v1/messages", + wantURL: "https://api.anthropic.com", }, { - name: "full URL already has /v1/messages", + name: "full URL with /v1/messages suffix stripped", inputURL: "https://api.anthropic.com/v1/messages", - wantURL: "https://api.anthropic.com/v1/messages", + wantURL: "https://api.anthropic.com", }, { - name: "full URL with trailing slash", + name: "full URL with trailing slash stripped", inputURL: "https://api.anthropic.com/v1/messages/", - wantURL: "https://api.anthropic.com/v1/messages/", + wantURL: "https://api.anthropic.com", }, { - name: "custom proxy base URL", + name: "custom proxy base URL unchanged", inputURL: "https://proxy.example.com/anthropic", - wantURL: "https://proxy.example.com/anthropic/v1/messages", + wantURL: "https://proxy.example.com/anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeAnthropicBaseURL(tt.inputURL) + if got != tt.wantURL { + t.Errorf("normalizeAnthropicBaseURL(%q) = %q, want %q", tt.inputURL, got, tt.wantURL) + } + }) + } +} + +// TestNewOpenAIClient_CreatesSuccessfully verifies that the client is created without panics. +func TestNewOpenAIClient_CreatesSuccessfully(t *testing.T) { + client := NewOpenAIClient(ClientConfig{ + URL: "https://api.example.com/v1", + APIKey: "test-key", + Model: "gpt-4", + }) + if client == nil { + t.Fatal("expected non-nil client") + } +} + +// TestNewAnthropicClient_CreatesSuccessfully verifies that the client is created without panics. +func TestNewAnthropicClient_CreatesSuccessfully(t *testing.T) { + client := NewAnthropicClient(ClientConfig{ + URL: "https://api.anthropic.com", + APIKey: "test-key", + Model: "claude-sonnet-4-20250514", + }) + if client == nil { + t.Fatal("expected non-nil client") + } +} + +// TestBuildParams_MaxTokensBranching verifies that buildParams uses MaxTokens or +// MaxCompletionTokens based on the UseMaxCompletionTokens config flag. +func TestBuildParams_MaxTokensBranching(t *testing.T) { + tests := []struct { + name string + useMaxCompletionTokens bool + maxTokens int + wantMaxTokensSet bool + wantMaxCompletionSet bool + }{ + { + name: "default uses max_tokens", + useMaxCompletionTokens: false, + maxTokens: 4096, + wantMaxTokensSet: true, + wantMaxCompletionSet: false, + }, + { + name: "enabled uses max_completion_tokens", + useMaxCompletionTokens: true, + maxTokens: 4096, + wantMaxTokensSet: false, + wantMaxCompletionSet: true, + }, + { + name: "zero max_tokens sets neither", + useMaxCompletionTokens: true, + maxTokens: 0, + wantMaxTokensSet: false, + wantMaxCompletionSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewOpenAIClient(ClientConfig{ + URL: "https://api.example.com/v1", + APIKey: "test-key", + Model: "gpt-4", + UseMaxCompletionTokens: tt.useMaxCompletionTokens, + }) + + params := client.buildParams("gpt-4", ChatRequest{ + MaxTokens: tt.maxTokens, + }) + + maxTokensSet := params.MaxTokens.Valid() + maxCompletionSet := params.MaxCompletionTokens.Valid() + + if maxTokensSet != tt.wantMaxTokensSet { + t.Errorf("MaxTokens present = %v, want %v", maxTokensSet, tt.wantMaxTokensSet) + } + if maxCompletionSet != tt.wantMaxCompletionSet { + t.Errorf("MaxCompletionTokens present = %v, want %v", maxCompletionSet, tt.wantMaxCompletionSet) + } + }) + } +} + +// --- Tests for convertMessagesToOpenAI --- + +func TestConvertMessagesToOpenAI_BasicRoles(t *testing.T) { + messages := []Message{ + NewTextMessage("system", "You are a helpful assistant"), + NewTextMessage("user", "Hello"), + NewTextMessage("assistant", "Hi there"), + } + + result := convertMessagesToOpenAI(messages) + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d", len(result)) + } + + // Verify roles by checking which union variant is set + if result[0].OfSystem == nil { + t.Error("expected system message at index 0") + } + if result[1].OfUser == nil { + t.Error("expected user message at index 1") + } + if result[2].OfAssistant == nil { + t.Error("expected assistant message at index 2") + } +} + +func TestConvertMessagesToOpenAI_ToolMessage(t *testing.T) { + messages := []Message{ + NewToolResultMessage("call-123", "tool result content"), + } + + result := convertMessagesToOpenAI(messages) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if result[0].OfTool == nil { + t.Error("expected tool message") + } + if result[0].OfTool.ToolCallID != "call-123" { + t.Errorf("expected tool_call_id %q, got %q", "call-123", result[0].OfTool.ToolCallID) + } +} + +func TestConvertMessagesToOpenAI_AssistantWithToolCalls(t *testing.T) { + messages := []Message{ + NewToolCallMessage("thinking...", []ToolCall{ + { + ID: "call-1", + Type: "function", + Function: FunctionCall{ + Name: "get_weather", + Arguments: `{"city":"Tokyo"}`, + }, + }, + }), + } + + result := convertMessagesToOpenAI(messages) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if result[0].OfAssistant == nil { + t.Fatal("expected assistant message") + } + if len(result[0].OfAssistant.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(result[0].OfAssistant.ToolCalls)) + } + tc := result[0].OfAssistant.ToolCalls[0] + if tc.OfFunction == nil { + t.Fatal("expected function tool call") + } + if tc.OfFunction.ID != "call-1" { + t.Errorf("expected ID %q, got %q", "call-1", tc.OfFunction.ID) + } + if tc.OfFunction.Function.Name != "get_weather" { + t.Errorf("expected function name %q, got %q", "get_weather", tc.OfFunction.Function.Name) + } +} + +func TestConvertMessagesToOpenAI_EmptyMessages(t *testing.T) { + result := convertMessagesToOpenAI(nil) + if len(result) != 0 { + t.Fatalf("expected 0 messages, got %d", len(result)) + } +} + +// --- Tests for convertMessagesToAnthropic --- + +func TestConvertMessagesToAnthropic_SystemExtraction(t *testing.T) { + messages := []Message{ + NewTextMessage("system", "Be concise"), + NewTextMessage("user", "Hello"), + } + + system, result := convertMessagesToAnthropic(messages) + if system != "Be concise" { + t.Errorf("expected system %q, got %q", "Be concise", system) + } + if len(result) != 1 { + t.Fatalf("expected 1 message (user only), got %d", len(result)) + } +} + +func TestConvertMessagesToAnthropic_NoSystem(t *testing.T) { + messages := []Message{ + NewTextMessage("user", "Hello"), + NewTextMessage("assistant", "Hi"), + } + + system, result := convertMessagesToAnthropic(messages) + if system != "" { + t.Errorf("expected empty system, got %q", system) + } + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result)) + } +} + +func TestConvertMessagesToAnthropic_ToolResultBatching(t *testing.T) { + // Multiple consecutive tool results should be batched into a single user message + messages := []Message{ + NewTextMessage("user", "Do tasks"), + NewToolCallMessage("", []ToolCall{ + {ID: "call-1", Type: "function", Function: FunctionCall{Name: "fn1", Arguments: "{}"}}, + {ID: "call-2", Type: "function", Function: FunctionCall{Name: "fn2", Arguments: "{}"}}, + }), + NewToolResultMessage("call-1", "result 1"), + NewToolResultMessage("call-2", "result 2"), + NewTextMessage("assistant", "Done"), + } + + system, result := convertMessagesToAnthropic(messages) + if system != "" { + t.Errorf("unexpected system: %q", system) + } + // Expected: user, assistant(tool_calls), user(batched tool results), assistant + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } +} + +func TestConvertMessagesToAnthropic_AssistantWithToolCalls(t *testing.T) { + messages := []Message{ + NewToolCallMessage("thinking", []ToolCall{ + {ID: "call-1", Type: "function", Function: FunctionCall{Name: "search", Arguments: `{"q":"test"}`}}, + }), + } + + _, result := convertMessagesToAnthropic(messages) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + // The assistant message should have both text block and tool_use block + msg := result[0] + if msg.Role != "assistant" { + t.Errorf("expected role assistant, got %q", msg.Role) + } +} + +// --- Tests for utility functions --- + +func TestStripThinkTags(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"no tags", "hello world", "hello world"}, + {"with think tags", "reasoninganswer", "reasoninganswer"}, + {"only open tag", "partial", "partial"}, + {"only close tag", "partial", "partial"}, + {"empty string", "", ""}, + {"nested content", "step1resultstep2final", "step1resultstep2final"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripThinkTags(tt.input) + if got != tt.want { + t.Errorf("stripThinkTags(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestExtractErrorMessage(t *testing.T) { + tests := []struct { + name string + body []byte + want string + }{ + { + name: "OpenAI error format", + body: []byte(`{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}`), + want: "Rate limit exceeded", + }, + { + name: "Anthropic error format", + body: []byte(`{"type":"error","error":{"message":"Invalid API key","type":"authentication_error"}}`), + want: "Invalid API key", + }, + { + name: "empty body", + body: []byte{}, + want: "(empty body)", + }, + { + name: "nil body", + body: nil, + want: "(empty body)", + }, + { + name: "unrecognized JSON", + body: []byte(`{"status":"error","detail":"something went wrong"}`), + want: `{"status":"error","detail":"something went wrong"}`, + }, + { + name: "truncates long body", + body: []byte(strings.Repeat("x", 600)), + want: strings.Repeat("x", 512) + "... (truncated)", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := NewAnthropicClient(ClientConfig{URL: tt.inputURL}) - if client.cfg.URL != tt.wantURL { - t.Errorf("got URL %q, want %q", client.cfg.URL, tt.wantURL) + got := extractErrorMessage(tt.body) + if got != tt.want { + t.Errorf("extractErrorMessage() = %q, want %q", got, tt.want) } }) } } + +func TestExtractText(t *testing.T) { + tests := []struct { + name string + message Message + want string + }{ + { + name: "string content", + message: Message{Role: "user", Content: "hello"}, + want: "hello", + }, + { + name: "content block array", + message: Message{Role: "user", Content: []ContentBlock{{Type: "text", Text: "part1"}, {Type: "text", Text: "part2"}}}, + want: "part1part2", + }, + { + name: "nil content", + message: Message{Role: "user", Content: nil}, + want: "", + }, + { + name: "empty string", + message: Message{Role: "user", Content: ""}, + want: "", + }, + { + name: "nested content blocks", + message: Message{Role: "user", Content: []ContentBlock{ + {Type: "tool_result", ToolUseID: "id1", Content: []ContentBlock{{Type: "text", Text: "nested"}}}, + }}, + want: "nested", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.message.ExtractText() + if got != tt.want { + t.Errorf("ExtractText() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestEncodingForModel(t *testing.T) { + tests := []struct { + model string + want string + }{ + {"o1-preview", "o200k_base"}, + {"o3-mini", "o200k_base"}, + {"o4-mini", "o200k_base"}, + {"gpt-4", "cl100k_base"}, + {"gpt-4o", "cl100k_base"}, + {"claude-opus-4-6", "cl100k_base"}, + {"", "cl100k_base"}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + got := encodingForModel(tt.model) + if got != tt.want { + t.Errorf("encodingForModel(%q) = %q, want %q", tt.model, got, tt.want) + } + }) + } +} + +func TestChatResponseContent(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + resp ChatResponse + want string + }{ + { + name: "normal content", + resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr("hello")}}}}, + want: "hello", + }, + { + name: "content with think tags", + resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr("reasoninganswer")}}}}, + want: "reasoninganswer", + }, + { + name: "fallback to reasoning content", + resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr(""), ReasoningContent: "reasoning"}}}}, + want: "reasoning", + }, + { + name: "nil content fallback to reasoning", + resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: nil, ReasoningContent: "fallback"}}}}, + want: "fallback", + }, + { + name: "empty choices", + resp: ChatResponse{Choices: []Choice{}}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.resp.Content() + if got != tt.want { + t.Errorf("Content() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestNewTextMessage(t *testing.T) { + msg := NewTextMessage("user", "hello") + if msg.Role != "user" { + t.Errorf("Role = %q, want %q", msg.Role, "user") + } + if msg.Content != "hello" { + t.Errorf("Content = %v, want %q", msg.Content, "hello") + } +} + +func TestNewToolCallMessage(t *testing.T) { + calls := []ToolCall{{ID: "1", Type: "function", Function: FunctionCall{Name: "fn", Arguments: "{}"}}} + msg := NewToolCallMessage("text", calls) + if msg.Role != "assistant" { + t.Errorf("Role = %q, want %q", msg.Role, "assistant") + } + if len(msg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(msg.ToolCalls)) + } + // Verify it's a copy, not the same slice + calls[0].ID = "modified" + if msg.ToolCalls[0].ID == "modified" { + t.Error("ToolCalls should be a copy, not a reference") + } +} + +func TestNewToolResultMessage(t *testing.T) { + msg := NewToolResultMessage("call-1", "result") + if msg.Role != "tool" { + t.Errorf("Role = %q, want %q", msg.Role, "tool") + } + if msg.ToolCallID != "call-1" { + t.Errorf("ToolCallID = %q, want %q", msg.ToolCallID, "call-1") + } + if msg.Content != "result" { + t.Errorf("Content = %v, want %q", msg.Content, "result") + } +} + +// Ensure json import is used (for potential future tests using json assertions) +var _ = json.Marshal diff --git a/internal/llm/resolver.go b/internal/llm/resolver.go index 442ff81..edbaf20 100644 --- a/internal/llm/resolver.go +++ b/internal/llm/resolver.go @@ -11,20 +11,22 @@ import ( // ResolvedEndpoint holds the resolved LLM endpoint configuration. type ResolvedEndpoint struct { - URL string - Token string - Model string - Protocol string // "anthropic" or "openai" - Source string // human-readable config source label - ExtraBody map[string]any // vendor-specific request body fields + URL string + Token string + Model string + Protocol string // "anthropic" or "openai" + Source string // human-readable config source label + ExtraBody map[string]any // vendor-specific request body fields + UseMaxCompletionTokens bool // use max_completion_tokens instead of max_tokens } // Environment variable names for OCR-specific configuration. const ( - envOCRLLMURL = "OCR_LLM_URL" - envOCRLLMToken = "OCR_LLM_TOKEN" - envOCRLLMModel = "OCR_LLM_MODEL" - envOCRUseAnthropic = "OCR_USE_ANTHROPIC" + envOCRLLMURL = "OCR_LLM_URL" + envOCRLLMToken = "OCR_LLM_TOKEN" + envOCRLLMModel = "OCR_LLM_MODEL" + envOCRUseAnthropic = "OCR_USE_ANTHROPIC" + envOCRUseMaxCompletionTokens = "OCR_USE_MAX_COMPLETION_TOKENS" ) // Environment variable names from Claude Code configuration. @@ -83,16 +85,23 @@ func tryOCREnv() (ResolvedEndpoint, bool, error) { protocol = "openai" } - return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, Source: "OCR environment"}, true, nil + useMaxCompletionTokens := false + if v := os.Getenv(envOCRUseMaxCompletionTokens); v != "" { + lower := strings.ToLower(v) + useMaxCompletionTokens = lower == "true" || lower == "1" || lower == "yes" + } + + return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, Source: "OCR environment", UseMaxCompletionTokens: useMaxCompletionTokens}, true, nil } // llmFileConfig represents the llm section in config.json. type llmFileConfig struct { - URL string `json:"url,omitempty"` - AuthToken string `json:"auth_token,omitempty"` - Model string `json:"model,omitempty"` - UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false - ExtraBody map[string]any `json:"extra_body,omitempty"` + URL string `json:"url,omitempty"` + AuthToken string `json:"auth_token,omitempty"` + Model string `json:"model,omitempty"` + UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false + UseMaxCompletionTokens *bool `json:"use_max_completion_tokens,omitempty"` // pointer to distinguish unset from false + ExtraBody map[string]any `json:"extra_body,omitempty"` } type configFile struct { @@ -128,7 +137,12 @@ func tryOCRConfig(path string) (ResolvedEndpoint, bool, error) { protocol = "openai" } - return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: cfg.Llm.Model, Protocol: protocol, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody}, true, nil + useMaxCompletionTokens := false + if cfg.Llm.UseMaxCompletionTokens != nil { + useMaxCompletionTokens = *cfg.Llm.UseMaxCompletionTokens + } + + return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: cfg.Llm.Model, Protocol: protocol, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, UseMaxCompletionTokens: useMaxCompletionTokens}, true, nil } // tryCCEnv reads Claude Code environment variables. diff --git a/internal/llm/resolver_test.go b/internal/llm/resolver_test.go index 61ca101..d99365a 100644 --- a/internal/llm/resolver_test.go +++ b/internal/llm/resolver_test.go @@ -115,3 +115,88 @@ func TestResolveEndpoint_ConfigFileStripsModelSuffix(t *testing.T) { t.Errorf("expected source %q, got %q", "OCR config file", ep.Source) } } + +func TestTryOCREnv_UseMaxCompletionTokens(t *testing.T) { + tests := []struct { + name string + envValue string + want bool + }{ + {"true", "true", true}, + {"True_uppercase", "True", true}, + {"1", "1", true}, + {"yes", "yes", true}, + {"false", "false", false}, + {"0", "0", false}, + {"empty_string", "", false}, + {"invalid", "invalid", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OCR_LLM_URL", "https://api.example.com/v1/chat/completions") + t.Setenv("OCR_LLM_TOKEN", "test-token") + t.Setenv("OCR_LLM_MODEL", "gpt-4") + t.Setenv("OCR_USE_ANTHROPIC", "false") + if tt.envValue != "" { + t.Setenv("OCR_USE_MAX_COMPLETION_TOKENS", tt.envValue) + } else { + t.Setenv("OCR_USE_MAX_COMPLETION_TOKENS", "") + } + + ep, ok, err := tryOCREnv() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatal("expected ok=true") + } + if ep.UseMaxCompletionTokens != tt.want { + t.Errorf("UseMaxCompletionTokens = %v, want %v", ep.UseMaxCompletionTokens, tt.want) + } + }) + } +} + +func TestTryOCRConfig_UseMaxCompletionTokens(t *testing.T) { + tests := []struct { + name string + val *bool + want bool + }{ + {"explicit_true", boolPtr(true), true}, + {"explicit_false", boolPtr(false), false}, + {"unset_nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := configFile{ + Llm: llmFileConfig{ + URL: "https://api.example.com/v1/chat/completions", + AuthToken: "test-token", + Model: "gpt-4", + UseMaxCompletionTokens: tt.val, + }, + } + data, _ := json.Marshal(cfg) + cfgPath := filepath.Join(t.TempDir(), "config.json") + os.WriteFile(cfgPath, data, 0644) + + ep, ok, err := tryOCRConfig(cfgPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatal("expected ok=true") + } + if ep.UseMaxCompletionTokens != tt.want { + t.Errorf("UseMaxCompletionTokens = %v, want %v", ep.UseMaxCompletionTokens, tt.want) + } + }) + } +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/internal/llm/usage_resolver.go b/internal/llm/usage_resolver.go index 1345420..fd033b1 100644 --- a/internal/llm/usage_resolver.go +++ b/internal/llm/usage_resolver.go @@ -1,10 +1,5 @@ package llm -import ( - "encoding/json" - "strings" -) - // UsageInfo holds token usage extracted from an LLM API response. type UsageInfo struct { TotalTokens int64 `json:"total_tokens"` @@ -13,101 +8,3 @@ type UsageInfo struct { CacheReadTokens int64 `json:"cache_read_tokens,omitempty"` CacheWriteTokens int64 `json:"cache_write_tokens,omitempty"` } - -var promptTokensPaths = []string{ - "usage.prompt_tokens", // OpenAI standard - "prompt_tokens", // flat at root - "data.usage.prompt_tokens", // wrapped in data layer -} - -var completionTokensPaths = []string{ - "usage.completion_tokens", // OpenAI standard - "completion_tokens", // flat at root - "data.usage.completion_tokens", // wrapped in data layer -} - -var cacheReadTokensPaths = []string{ - "usage.cache_read_input_tokens", // Anthropic - "cache_read_input_tokens", // flat at root - "usage.prompt_tokens_details.cache_tokens_hit", // some providers - "usage.prompt_tokens_details.cache_tokens", // some providers -} - -var cacheWriteTokensPaths = []string{ - "usage.cache_creation_input_tokens", // Anthropic / proxy - "cache_creation_input_tokens", // flat at root -} - -// totalTokensPaths is an ordered list of JSON paths to try when extracting -// total token count from a response body. Paths are dot-separated keys that -// navigate through nested map[string]any objects. The first match wins. -var totalTokensPaths = []string{ - "usage.total_tokens", // OpenAI standard - "total_tokens", // flat at root - "data.usage.total_tokens", // wrapped in data layer (some proxy APIs) -} - -// resolveUsage parses raw JSON bytes into a map and extracts token usage -// by probing configured paths sequentially. Returns nil if no total_tokens found. -func resolveUsage(raw []byte) *UsageInfo { - var rawBody map[string]any - if err := json.Unmarshal(raw, &rawBody); err != nil { - return nil - } - - total, hasAny := probePath(rawBody, totalTokensPaths) - prompt, _ := probePath(rawBody, promptTokensPaths) - completion, _ := probePath(rawBody, completionTokensPaths) - cacheRead, _ := probePath(rawBody, cacheReadTokensPaths) - cacheWrite, _ := probePath(rawBody, cacheWriteTokensPaths) - - if !hasAny && prompt == 0 && completion == 0 { - return nil - } - - ui := &UsageInfo{ - TotalTokens: total, - PromptTokens: prompt, - CompletionTokens: completion, - CacheReadTokens: cacheRead, - CacheWriteTokens: cacheWrite, - } - - // If TotalTokens wasn't explicitly available but we have prompt+completion, compute it. - if total == 0 && (prompt > 0 || completion > 0) { - ui.TotalTokens = prompt + completion + cacheRead + cacheWrite - } - - return ui -} - -// probePath walks through each candidate path in order, returning the first -// int64 value found along with true. Returns (0, false) if none match. -func probePath(root map[string]any, paths []string) (int64, bool) { - for _, p := range paths { - parts := strings.Split(p, ".") - - var current any = root - for _, part := range parts { - obj, ok := current.(map[string]any) - if !ok { - goto next - } - current, ok = obj[part] - if !ok { - goto next - } - } - - switch v := current.(type) { - case float64: - return int64(v), true - case int64: - return v, true - case int: - return int64(v), true - } - next: - } - return 0, false -}