diff --git a/README.md b/README.md
index 2dccfa7..94b7f00 100644
--- a/README.md
+++ b/README.md
@@ -329,6 +329,7 @@ Config file: `~/.opencodereview/config.json`
| `llm.auth_token` | string | `sk-xxxxxxx` |
| `llm.model` | string | `claude-opus-4-6` |
| `llm.use_anthropic` | boolean | `true` \| `false` |
+| `llm.use_max_completion_tokens` | boolean | `true` = use `max_completion_tokens`, `false` = use `max_tokens` (default) |
| `language` | string | `English` \| `Chinese` (default: Chinese) |
| `telemetry.enabled` | boolean | `true` \| `false` |
| `telemetry.exporter` | string | `console` \| `otlp` |
@@ -345,6 +346,7 @@ Environment variables take precedence over the config file.
| `OCR_LLM_TOKEN` | API key / auth token |
| `OCR_LLM_MODEL` | Model name |
| `OCR_USE_ANTHROPIC` | `true` = Anthropic, `false` = OpenAI |
+| `OCR_USE_MAX_COMPLETION_TOKENS` | `true` = use `max_completion_tokens` param instead of `max_tokens` |
## Telemetry
diff --git a/README.zh-CN.md b/README.zh-CN.md
index 04b2b3a..b4d5a42 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -319,6 +319,7 @@ OCR 通过四层优先级链解析评审规则。每层采用首次匹配原则
| `llm.auth_token` | string | `sk-xxxxxxx` |
| `llm.model` | string | `claude-opus-4-6` |
| `llm.use_anthropic` | boolean | `true` \| `false` |
+| `llm.use_max_completion_tokens` | boolean | `true` = 使用 `max_completion_tokens`,`false` = 使用 `max_tokens`(默认) |
| `language` | string | `English` \| `Chinese`(默认:Chinese) |
| `telemetry.enabled` | boolean | `true` \| `false` |
| `telemetry.exporter` | string | `console` \| `otlp` |
@@ -335,6 +336,7 @@ OCR 通过四层优先级链解析评审规则。每层采用首次匹配原则
| `OCR_LLM_TOKEN` | API 密钥 / 认证令牌 |
| `OCR_LLM_MODEL` | 模型名称 |
| `OCR_USE_ANTHROPIC` | `true` = Anthropic,`false` = OpenAI |
+| `OCR_USE_MAX_COMPLETION_TOKENS` | `true` = 使用 `max_completion_tokens` 参数代替 `max_tokens` |
## 遥测
diff --git a/cmd/opencodereview/config_cmd.go b/cmd/opencodereview/config_cmd.go
index 14bd4ac..ea96731 100644
--- a/cmd/opencodereview/config_cmd.go
+++ b/cmd/opencodereview/config_cmd.go
@@ -77,11 +77,12 @@ type Config struct {
}
type LlmConfig struct {
- URL string `json:"url,omitempty"`
- AuthToken string `json:"auth_token,omitempty"`
- Model string `json:"model,omitempty"`
- UseAnthropic *bool `json:"use_anthropic,omitempty"` // nil = default true; false = OpenAI protocol
- ExtraBody map[string]any `json:"extra_body,omitempty"`
+ URL string `json:"url,omitempty"`
+ AuthToken string `json:"auth_token,omitempty"`
+ Model string `json:"model,omitempty"`
+ UseAnthropic *bool `json:"use_anthropic,omitempty"` // nil = default true; false = OpenAI protocol
+ UseMaxCompletionTokens *bool `json:"use_max_completion_tokens,omitempty"` // nil = default false; true = use max_completion_tokens
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
}
// TelemetryConfig holds telemetry-specific settings.
@@ -137,6 +138,12 @@ func setConfigValue(cfg *Config, key, value string) error {
return fmt.Errorf("invalid boolean for llm.use_anthropic: %w", err)
}
cfg.Llm.UseAnthropic = &b
+ case "llm.use_max_completion_tokens", "llm.UseMaxCompletionTokens":
+ b, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid boolean for llm.use_max_completion_tokens: %w", err)
+ }
+ cfg.Llm.UseMaxCompletionTokens = &b
case "language", "Language":
cfg.Language = value
case "telemetry.enabled", "telemetry.Enabled":
@@ -166,7 +173,7 @@ func setConfigValue(cfg *Config, key, value string) error {
}
cfg.Llm.ExtraBody = m
default:
- return fmt.Errorf("unknown config key: %s\nSupported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging", key)
+ return fmt.Errorf("unknown config key: %s\nSupported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.use_max_completion_tokens, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging", key)
}
return nil
}
diff --git a/cmd/opencodereview/config_cmd_test.go b/cmd/opencodereview/config_cmd_test.go
new file mode 100644
index 0000000..3d15795
--- /dev/null
+++ b/cmd/opencodereview/config_cmd_test.go
@@ -0,0 +1,146 @@
+package main
+
+import (
+ "testing"
+)
+
+func boolVal(b bool) *bool { return &b }
+
+func TestSetConfigValue_StringKeys(t *testing.T) {
+ tests := []struct {
+ key string
+ value string
+ checkFn func(*Config) string
+ }{
+ {"llm.url", "https://api.example.com/v1", func(c *Config) string { return c.Llm.URL }},
+ {"llm.URL", "https://api.example.com/v2", func(c *Config) string { return c.Llm.URL }},
+ {"llm.auth_token", "sk-token-123", func(c *Config) string { return c.Llm.AuthToken }},
+ {"llm.AuthToken", "sk-token-456", func(c *Config) string { return c.Llm.AuthToken }},
+ {"llm.model", "gpt-4", func(c *Config) string { return c.Llm.Model }},
+ {"llm.Model", "claude-opus-4-6", func(c *Config) string { return c.Llm.Model }},
+ {"language", "English", func(c *Config) string { return c.Language }},
+ {"Language", "Chinese", func(c *Config) string { return c.Language }},
+ {"telemetry.exporter", "otlp", func(c *Config) string { return c.Telemetry.Exporter }},
+ {"telemetry.Exporter", "console", func(c *Config) string { return c.Telemetry.Exporter }},
+ {"telemetry.otlp_endpoint", "localhost:4317", func(c *Config) string { return c.Telemetry.OTLPEndpoint }},
+ {"telemetry.OTLPEndpoint", "collector:4317", func(c *Config) string { return c.Telemetry.OTLPEndpoint }},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.key+"="+tt.value, func(t *testing.T) {
+ cfg := &Config{}
+ if err := setConfigValue(cfg, tt.key, tt.value); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ got := tt.checkFn(cfg)
+ if got != tt.value {
+ t.Errorf("got %q, want %q", got, tt.value)
+ }
+ })
+ }
+}
+
+func TestSetConfigValue_BoolKeys(t *testing.T) {
+ tests := []struct {
+ key string
+ value string
+ want bool
+ checkFn func(*Config) bool
+ }{
+ {"llm.use_anthropic", "true", true, func(c *Config) bool { return *c.Llm.UseAnthropic }},
+ {"llm.UseAnthropic", "false", false, func(c *Config) bool { return *c.Llm.UseAnthropic }},
+ {"llm.use_max_completion_tokens", "true", true, func(c *Config) bool { return *c.Llm.UseMaxCompletionTokens }},
+ {"llm.UseMaxCompletionTokens", "false", false, func(c *Config) bool { return *c.Llm.UseMaxCompletionTokens }},
+ {"telemetry.enabled", "true", true, func(c *Config) bool { return c.Telemetry.Enabled }},
+ {"telemetry.Enabled", "false", false, func(c *Config) bool { return c.Telemetry.Enabled }},
+ {"telemetry.content_logging", "true", true, func(c *Config) bool { return c.Telemetry.ContentLog }},
+ {"telemetry.ContentLog", "false", false, func(c *Config) bool { return c.Telemetry.ContentLog }},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.key+"="+tt.value, func(t *testing.T) {
+ cfg := &Config{}
+ if err := setConfigValue(cfg, tt.key, tt.value); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ got := tt.checkFn(cfg)
+ if got != tt.want {
+ t.Errorf("got %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestSetConfigValue_BoolKeys_InvalidValue(t *testing.T) {
+ keys := []string{
+ "llm.use_anthropic",
+ "llm.use_max_completion_tokens",
+ "telemetry.enabled",
+ "telemetry.content_logging",
+ }
+
+ for _, key := range keys {
+ t.Run(key, func(t *testing.T) {
+ cfg := &Config{}
+ err := setConfigValue(cfg, key, "not-a-bool")
+ if err == nil {
+ t.Fatal("expected error for invalid boolean value")
+ }
+ })
+ }
+}
+
+func TestSetConfigValue_ExtraBody(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ wantErr bool
+ }{
+ {"valid JSON", "llm.extra_body", `{"thinking":{"type":"disabled"}}`, false},
+ {"alias key", "llm.ExtraBody", `{"key":"value"}`, false},
+ {"invalid JSON", "llm.extra_body", `not json`, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &Config{}
+ err := setConfigValue(cfg, tt.key, tt.value)
+ if tt.wantErr {
+ if err == nil {
+ t.Fatal("expected error for invalid JSON")
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if cfg.Llm.ExtraBody == nil {
+ t.Fatal("expected non-nil ExtraBody")
+ }
+ }
+ })
+ }
+}
+
+func TestSetConfigValue_UnknownKey(t *testing.T) {
+ cfg := &Config{}
+ err := setConfigValue(cfg, "unknown.key", "value")
+ if err == nil {
+ t.Fatal("expected error for unknown key")
+ }
+}
+
+func TestSetConfigValue_TelemetryInitialized(t *testing.T) {
+ cfg := &Config{}
+ if cfg.Telemetry != nil {
+ t.Fatal("telemetry should be nil initially")
+ }
+
+ if err := setConfigValue(cfg, "telemetry.enabled", "true"); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if cfg.Telemetry == nil {
+ t.Fatal("telemetry should be initialized after setting telemetry key")
+ }
+}
diff --git a/cmd/opencodereview/flags.go b/cmd/opencodereview/flags.go
index 8a73d6d..42b5a87 100644
--- a/cmd/opencodereview/flags.go
+++ b/cmd/opencodereview/flags.go
@@ -254,5 +254,5 @@ Examples:
ocr config set language English
ocr config set telemetry.enabled true
-Supported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging`)
+Supported keys: llm.url, llm.auth_token, llm.model, llm.use_anthropic, llm.use_max_completion_tokens, llm.extra_body, language, telemetry.enabled, telemetry.exporter, telemetry.otlp_endpoint, telemetry.content_logging`)
}
diff --git a/go.mod b/go.mod
index 16e5cf7..77bec62 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,9 @@ module github.com/open-code-review/open-code-review
go 1.25.0
require (
+ github.com/anthropics/anthropic-sdk-go v1.46.0
github.com/bmatcuk/doublestar/v4 v4.10.0
+ github.com/openai/openai-go/v3 v3.38.0
github.com/pkoukk/tiktoken-go v0.1.8
go.opentelemetry.io/otel v1.43.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0
@@ -17,6 +19,8 @@ require (
)
require (
+ github.com/bahlo/generic-list-go v0.2.0 // indirect
+ github.com/buger/jsonparser v1.1.2 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
@@ -24,14 +28,24 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
+ github.com/invopop/jsonschema v0.13.0 // indirect
+ github.com/mailru/easyjson v0.7.7 // indirect
+ github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 // indirect
+ github.com/tidwall/gjson v1.18.0 // indirect
+ github.com/tidwall/match v1.1.1 // indirect
+ github.com/tidwall/pretty v1.2.1 // indirect
+ github.com/tidwall/sjson v1.2.5 // indirect
+ github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
golang.org/x/net v0.52.0 // indirect
+ golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 1f8f50e..ead072a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,11 @@
+github.com/anthropics/anthropic-sdk-go v1.46.0 h1:yl3n+el5ZfNgiCtQ7zQ7s/NXxB11YbrKXdc3uLPNWlU=
+github.com/anthropics/anthropic-sdk-go v1.46.0/go.mod h1:bx5vWuHFuGPkELH8Z4KUiNSohFnUwScdpTyr+50myPo=
+github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
+github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs=
github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
+github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
+github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@@ -8,6 +14,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
+github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -21,12 +29,39 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
+github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
+github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
+github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
+github.com/openai/openai-go/v3 v3.38.0 h1:Kre0Fz9mPUxtWjRB/CoNBHflp9W7FkztOm/XMDr/A3E=
+github.com/openai/openai-go/v3 v3.38.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
+github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
+github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI=
+github.com/standard-webhooks/standard-webhooks/libraries v0.0.1/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
+github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
+github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
@@ -55,6 +90,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
@@ -69,5 +106,10 @@ google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/agent/agent.go b/internal/agent/agent.go
index ca8d676..f0b45b6 100644
--- a/internal/agent/agent.go
+++ b/internal/agent/agent.go
@@ -661,9 +661,11 @@ func (a *Agent) executePlanPhase(ctx context.Context, newPath, rawDiff, changeFi
}
rec.SetResponse(resp, time.Since(startTime))
if resp.Usage != nil {
- atomic.AddInt64(&a.totalTokensUsed, int64(resp.Usage.TotalTokens))
- atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens))
- atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens))
+ input := int64(resp.Usage.PromptTokens + resp.Usage.CacheReadTokens)
+ output := int64(resp.Usage.CompletionTokens + resp.Usage.CacheWriteTokens)
+ atomic.AddInt64(&a.totalTokensUsed, input+output)
+ atomic.AddInt64(&a.totalInputTokens, input)
+ atomic.AddInt64(&a.totalOutputTokens, output)
}
fmt.Fprintf(stdout.Writer(), "[ocr] Plan completed for %s\n", newPath)
return resp.Content(), nil
@@ -740,11 +742,13 @@ func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message
}
rec.SetResponse(resp, duration)
// Record LLM metrics with token info from API response usage field.
- totalTokens := int64(0)
+ var totalTokens int64
if resp.Usage != nil {
- totalTokens = resp.Usage.TotalTokens
- atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens))
- atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens))
+ input := int64(resp.Usage.PromptTokens + resp.Usage.CacheReadTokens)
+ output := int64(resp.Usage.CompletionTokens + resp.Usage.CacheWriteTokens)
+ totalTokens = input + output
+ atomic.AddInt64(&a.totalInputTokens, input)
+ atomic.AddInt64(&a.totalOutputTokens, output)
}
telemetry.RecordLLMRequest(ctx, a.args.Model, duration, totalTokens, "ok")
atomic.AddInt64(&a.totalTokensUsed, totalTokens)
diff --git a/internal/diff/relocation_test.go b/internal/diff/relocation_test.go
index 8b3a81b..8a2d706 100644
--- a/internal/diff/relocation_test.go
+++ b/internal/diff/relocation_test.go
@@ -24,7 +24,7 @@ func (m *mockLLMClient) CompletionsWithCtx(_ context.Context, req llm.ChatReques
return m.response, m.err
}
-func (m *mockLLMClient) StreamCompletion(req llm.ChatRequest, cb func(chunk []byte) error) error {
+func (m *mockLLMClient) StreamCompletion(_ context.Context, req llm.ChatRequest, cb func(chunk []byte) error) error {
return m.err
}
diff --git a/internal/llm/client.go b/internal/llm/client.go
index 4948ffa..f8fc34b 100644
--- a/internal/llm/client.go
+++ b/internal/llm/client.go
@@ -1,20 +1,21 @@
// Package llm provides LLM client interfaces supporting multiple protocols.
// Supported protocols: Anthropic Messages API, OpenAI Chat Completions API.
+// Implementations use the official SDKs: github.com/openai/openai-go and github.com/anthropics/anthropic-sdk-go.
package llm
import (
- "bufio"
- "bytes"
"context"
"encoding/json"
"fmt"
- "io"
- "math/rand"
"net/http"
"strings"
"sync"
"time"
+ "github.com/anthropics/anthropic-sdk-go"
+ anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
+ "github.com/openai/openai-go/v3"
+ openaioption "github.com/openai/openai-go/v3/option"
tiktoken "github.com/pkoukk/tiktoken-go"
"github.com/open-code-review/open-code-review/internal/stdout"
@@ -36,7 +37,7 @@ func userAgent(provider string) string {
type LLMClient interface {
Completions(req ChatRequest) (*ChatResponse, error)
CompletionsWithCtx(ctx context.Context, req ChatRequest) (*ChatResponse, error)
- StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error
+ StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error
}
// --- Shared data types ---
@@ -186,11 +187,12 @@ type FunctionDef struct {
// ClientConfig holds configuration for connecting to an LLM service.
type ClientConfig struct {
- URL string // Full API endpoint URL
- APIKey string // Bearer token / API key
- Model string // Default model override
- Timeout time.Duration // Request timeout
- ExtraBody map[string]any // Vendor-specific fields merged into every request body
+ URL string // Full API endpoint URL
+ APIKey string // Bearer token / API key
+ Model string // Default model override
+ Timeout time.Duration // Request timeout
+ ExtraBody map[string]any // Vendor-specific fields merged into every request body
+ UseMaxCompletionTokens bool // use max_completion_tokens instead of max_tokens
}
// --- Factory ---
@@ -199,10 +201,11 @@ type ClientConfig struct {
// protocol: "anthropic" -> AnthropicClient, anything else -> OpenAIClient.
func NewLLMClient(ep ResolvedEndpoint) LLMClient {
cfg := ClientConfig{
- URL: ep.URL,
- APIKey: ep.Token,
- Model: ep.Model,
- ExtraBody: ep.ExtraBody,
+ URL: ep.URL,
+ APIKey: ep.Token,
+ Model: ep.Model,
+ ExtraBody: ep.ExtraBody,
+ UseMaxCompletionTokens: ep.UseMaxCompletionTokens,
}
if ep.Protocol == "anthropic" {
return NewAnthropicClient(cfg)
@@ -276,29 +279,43 @@ func encodingForModel(modelName string) string {
}
}
+// ChatRequest represents the payload for a chat completion call.
+type ChatRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ Tools []ToolDef `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+}
+
// --- OpenAIClient ---
-// OpenAIClient sends requests to an OpenAI-compatible chat completion API.
+// OpenAIClient sends requests to an OpenAI-compatible chat completion API via the official SDK.
type OpenAIClient struct {
cfg ClientConfig
- client *http.Client
+ client openai.Client
}
-// NewOpenAIClient creates a new OpenAI-compatible LLM client.
+// NewOpenAIClient creates a new OpenAI-compatible LLM client using the official SDK.
func NewOpenAIClient(cfg ClientConfig) *OpenAIClient {
if cfg.Timeout <= 0 {
cfg.Timeout = 5 * time.Minute
}
- baseURL := strings.TrimRight(cfg.URL, "/")
- if !strings.HasSuffix(baseURL, "/chat/completions") {
- cfg.URL = baseURL + "/chat/completions"
- }
- return &OpenAIClient{
- cfg: cfg,
- client: &http.Client{
- Timeout: cfg.Timeout,
- },
+
+ // Normalize URL: strip /chat/completions suffix since SDK appends it automatically.
+ baseURL := normalizeOpenAIBaseURL(cfg.URL)
+
+ opts := []openaioption.RequestOption{
+ openaioption.WithAPIKey(cfg.APIKey),
+ openaioption.WithBaseURL(baseURL),
+ openaioption.WithHeader("User-Agent", userAgent("")),
+ openaioption.WithMaxRetries(maxRetries),
+ openaioption.WithRequestTimeout(cfg.Timeout),
}
+
+ client := openai.NewClient(opts...)
+ return &OpenAIClient{cfg: cfg, client: client}
}
// NewClient is kept as an alias for backward compatibility during transition.
@@ -306,16 +323,6 @@ func NewClient(cfg ClientConfig) *OpenAIClient {
return NewOpenAIClient(cfg)
}
-// ChatRequest represents the payload for a chat completion call.
-type ChatRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Tools []ToolDef `json:"tools,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
-}
-
// Completions sends a chat completion request and returns the parsed response.
func (c *OpenAIClient) Completions(req ChatRequest) (*ChatResponse, error) {
return c.CompletionsWithCtx(context.Background(), req)
@@ -328,16 +335,19 @@ func (c *OpenAIClient) CompletionsWithCtx(ctx context.Context, req ChatRequest)
model = c.cfg.Model
}
- var result *ChatResponse
- err := c.withRetryCtx(ctx, func() error {
- resp, err := c.doRequestCtx(ctx, model, req)
- if err != nil {
- return err
- }
- result = resp
- return nil
- })
- return result, err
+ params := c.buildParams(model, req)
+ reqOpts := c.buildExtraBodyOptions()
+
+ // Capture raw HTTP response for headers
+ var httpResp *http.Response
+ reqOpts = append(reqOpts, openaioption.WithResponseInto(&httpResp))
+
+ completion, err := c.client.Chat.Completions.New(ctx, params, reqOpts...)
+ if err != nil {
+ return nil, fmt.Errorf("API error: %w", err)
+ }
+
+ return c.convertResponse(completion, httpResp), nil
}
// GeneralRequest sends a simple chat request without or with optional tool calls.
@@ -355,150 +365,229 @@ func (c *OpenAIClient) GeneralRequestWithCtx(ctx context.Context, messages []Mes
}
// StreamCompletion initiates a streaming chat completion. The callback is invoked per chunk.
-func (c *OpenAIClient) StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error {
- req.Stream = true
-
+func (c *OpenAIClient) StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error {
model := req.Model
if model == "" {
model = c.cfg.Model
}
- return c.withRetry(func() error {
- body := make(map[string]any)
- b, _ := json.Marshal(req)
- json.Unmarshal(b, &body)
- body["model"] = model
- for k, v := range c.cfg.ExtraBody {
- body[k] = v
- }
+ params := c.buildParams(model, req)
+ // Enable usage reporting in stream so the final chunk contains token stats.
+ params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+ IncludeUsage: openai.Bool(true),
+ }
+ reqOpts := c.buildExtraBodyOptions()
- payload, _ := json.Marshal(body)
- httpReq, err := http.NewRequest(http.MethodPost, c.cfg.URL, bytes.NewReader(payload))
- if err != nil {
- return fmt.Errorf("create request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
- httpReq.Header.Set("Accept", "text/event-stream")
- httpReq.Header.Set("User-Agent", userAgent(""))
-
- resp, err := c.client.Do(httpReq)
- if err != nil {
- return fmt.Errorf("request failed: %w", err)
- }
- defer resp.Body.Close()
+ stream := c.client.Chat.Completions.NewStreaming(ctx, params, reqOpts...)
+ defer stream.Close()
- if isRetryableStatus(resp.StatusCode) {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
+ for stream.Next() {
+ chunk := stream.Current()
+ // Use RawJSON to preserve the original API response format for downstream consumers.
+ raw := chunk.RawJSON()
+ if raw == "" {
+ continue
}
- if resp.StatusCode >= 400 {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("API error %d: %s (non-retryable)", resp.StatusCode, string(bodyBytes))
+ if err := cb([]byte(raw)); err != nil {
+ return err
}
+ }
+ return stream.Err()
+}
- scanner := bufio.NewScanner(resp.Body)
- for scanner.Scan() {
- line := scanner.Text()
- if !strings.HasPrefix(line, "data: ") {
- continue
- }
- data := strings.TrimPrefix(line, "data: ")
- if data == "[DONE]" {
- break
- }
- if err := cb([]byte(data)); err != nil {
- return err
- }
+// buildParams converts internal ChatRequest to OpenAI SDK params.
+func (c *OpenAIClient) buildParams(model string, req ChatRequest) openai.ChatCompletionNewParams {
+ params := openai.ChatCompletionNewParams{
+ Model: model,
+ Messages: convertMessagesToOpenAI(req.Messages),
+ }
+
+ if len(req.Tools) > 0 {
+ params.Tools = convertToolsToOpenAI(req.Tools)
+ }
+
+ if req.MaxTokens > 0 {
+ if c.cfg.UseMaxCompletionTokens {
+ params.MaxCompletionTokens = openai.Int(int64(req.MaxTokens))
+ } else {
+ params.MaxTokens = openai.Int(int64(req.MaxTokens))
}
- return scanner.Err()
- })
-}
+ }
-// doRequest builds and sends a non-streaming completion request, returning the parsed response.
-func (c *OpenAIClient) doRequest(model string, req ChatRequest) (*ChatResponse, error) {
- return c.doRequestCtx(context.Background(), model, req)
+ if req.Temperature != nil {
+ params.Temperature = openai.Float(*req.Temperature)
+ }
+
+ return params
}
-// doRequestCtx builds and sends a non-streaming completion request with context support.
-func (c *OpenAIClient) doRequestCtx(ctx context.Context, model string, req ChatRequest) (*ChatResponse, error) {
- if model == "" {
- model = c.cfg.Model
+// buildExtraBodyOptions creates request options for ExtraBody fields.
+func (c *OpenAIClient) buildExtraBodyOptions() []openaioption.RequestOption {
+ if len(c.cfg.ExtraBody) == 0 {
+ return nil
}
- req.Model = model
- payload, err := mergeExtraBody(req, c.cfg.ExtraBody)
- if err != nil {
- return nil, fmt.Errorf("marshal request body: %w", err)
+ opts := make([]openaioption.RequestOption, 0, len(c.cfg.ExtraBody))
+ for k, v := range c.cfg.ExtraBody {
+ opts = append(opts, openaioption.WithJSONSet(k, v))
}
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.URL, bytes.NewReader(payload))
- if err != nil {
- return nil, fmt.Errorf("create request: %w", err)
+ return opts
+}
+
+// convertResponse maps the SDK ChatCompletion to our internal ChatResponse.
+func (c *OpenAIClient) convertResponse(comp *openai.ChatCompletion, httpResp *http.Response) *ChatResponse {
+ if comp == nil {
+ return &ChatResponse{}
}
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
- httpReq.Header.Set("User-Agent", userAgent(""))
+ choices := make([]Choice, 0, len(comp.Choices))
+ for _, ch := range comp.Choices {
+ var content *string
+ if ch.Message.Content != "" {
+ s := ch.Message.Content
+ content = &s
+ }
- resp, err := c.client.Do(httpReq)
- if err != nil {
- return nil, fmt.Errorf("request failed: %w", err)
+ toolCalls := make([]ToolCall, 0, len(ch.Message.ToolCalls))
+ for _, tc := range ch.Message.ToolCalls {
+ toolCalls = append(toolCalls, ToolCall{
+ ID: tc.ID,
+ Type: "function",
+ Function: FunctionCall{
+ Name: tc.Function.Name,
+ Arguments: tc.Function.Arguments,
+ },
+ })
+ }
+
+ choices = append(choices, Choice{
+ Message: ResponseMessage{
+ Role: string(ch.Message.Role),
+ Content: content,
+ ToolCalls: toolCalls,
+ },
+ FinishReason: string(ch.FinishReason),
+ })
}
- defer resp.Body.Close()
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("read response body: %w", err)
+ var usage *UsageInfo
+ if comp.Usage.TotalTokens > 0 || comp.Usage.PromptTokens > 0 || comp.Usage.CompletionTokens > 0 {
+ cachedTokens := comp.Usage.PromptTokensDetails.CachedTokens
+ usage = &UsageInfo{
+ TotalTokens: comp.Usage.TotalTokens,
+ PromptTokens: comp.Usage.PromptTokens - cachedTokens,
+ CompletionTokens: comp.Usage.CompletionTokens,
+ CacheReadTokens: cachedTokens,
+ }
+ if usage.TotalTokens == 0 {
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.CacheReadTokens
+ }
}
- if resp.StatusCode >= 400 {
- detail := extractErrorMessage(bodyBytes)
- return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, detail)
+ var headers http.Header
+ if httpResp != nil {
+ headers = httpResp.Header
}
- var apiResp struct {
- ID string `json:"id"`
- Model string `json:"model"`
- Choices []Choice `json:"choices"`
+ return &ChatResponse{
+ ID: comp.ID,
+ Model: comp.Model,
+ Choices: choices,
+ Headers: headers,
+ Usage: usage,
}
- if err := json.Unmarshal(bodyBytes, &apiResp); err != nil {
- return nil, fmt.Errorf("decode response: %w", err)
+}
+
+// convertMessagesToOpenAI maps internal Message slice to OpenAI SDK message params.
+func convertMessagesToOpenAI(messages []Message) []openai.ChatCompletionMessageParamUnion {
+ result := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages))
+ for _, msg := range messages {
+ switch msg.Role {
+ case "system":
+ result = append(result, openai.SystemMessage(msg.ExtractText()))
+ case "user":
+ result = append(result, openai.UserMessage(msg.ExtractText()))
+ case "assistant":
+ if len(msg.ToolCalls) > 0 {
+ toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, 0, len(msg.ToolCalls))
+ for _, tc := range msg.ToolCalls {
+ toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{
+ OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
+ ID: tc.ID,
+ Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
+ Name: tc.Function.Name,
+ Arguments: tc.Function.Arguments,
+ },
+ },
+ })
+ }
+ text := msg.ExtractText()
+ param := openai.ChatCompletionAssistantMessageParam{
+ ToolCalls: toolCalls,
+ }
+ if text != "" {
+ param.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+ OfString: openai.String(text),
+ }
+ }
+ result = append(result, openai.ChatCompletionMessageParamUnion{
+ OfAssistant: ¶m,
+ })
+ } else {
+ result = append(result, openai.AssistantMessage(msg.ExtractText()))
+ }
+ case "tool":
+ result = append(result, openai.ToolMessage(msg.ExtractText(), msg.ToolCallID))
+ }
}
+ return result
+}
- return &ChatResponse{
- ID: apiResp.ID,
- Model: apiResp.Model,
- Choices: apiResp.Choices,
- Headers: resp.Header,
- Usage: resolveUsage(bodyBytes),
- }, nil
+// convertToolsToOpenAI maps internal ToolDef slice to OpenAI SDK tool params.
+func convertToolsToOpenAI(tools []ToolDef) []openai.ChatCompletionToolUnionParam {
+ result := make([]openai.ChatCompletionToolUnionParam, 0, len(tools))
+ for _, t := range tools {
+ result = append(result, openai.ChatCompletionFunctionTool(openai.FunctionDefinitionParam{
+ Name: t.Function.Name,
+ Description: openai.String(t.Function.Description),
+ Parameters: openai.FunctionParameters(t.Function.Parameters),
+ }))
+ }
+ return result
}
-// --- AnthropicClient ---
+// normalizeOpenAIBaseURL strips the /chat/completions suffix since the SDK appends it.
+func normalizeOpenAIBaseURL(rawURL string) string {
+ u := strings.TrimRight(rawURL, "/")
+ u = strings.TrimSuffix(u, "/chat/completions")
+ return u
+}
-const anthropicVersion = "2023-06-01"
+// --- AnthropicClient ---
-// AnthropicClient implements the Anthropic Messages API.
+// AnthropicClient implements the Anthropic Messages API via the official SDK.
type AnthropicClient struct {
cfg ClientConfig
- client *http.Client
+ client anthropic.Client
}
-// NewAnthropicClient creates a new Anthropic Messages API client.
+// NewAnthropicClient creates a new Anthropic Messages API client using the official SDK.
func NewAnthropicClient(cfg ClientConfig) *AnthropicClient {
if cfg.Timeout <= 0 {
cfg.Timeout = 5 * time.Minute
}
- if !strings.HasSuffix(cfg.URL, "/v1/messages") && !strings.HasSuffix(cfg.URL, "/v1/messages/") {
- baseURL := strings.TrimRight(cfg.URL, "/")
- if !strings.HasSuffix(baseURL, "/v1/messages") {
- cfg.URL = baseURL + "/v1/messages"
- }
- }
- return &AnthropicClient{
- cfg: cfg,
- client: &http.Client{
- Timeout: cfg.Timeout,
- },
+
+ // Normalize URL: strip /v1/messages suffix since SDK appends it automatically.
+ baseURL := normalizeAnthropicBaseURL(cfg.URL)
+
+ opts := []anthropicoption.RequestOption{
+ anthropicoption.WithAPIKey(cfg.APIKey),
+ anthropicoption.WithBaseURL(baseURL),
+ anthropicoption.WithHeader("User-Agent", userAgent("claude")),
+ anthropicoption.WithMaxRetries(maxRetries),
+ anthropicoption.WithRequestTimeout(cfg.Timeout),
}
+
+ client := anthropic.NewClient(opts...)
+ return &AnthropicClient{cfg: cfg, client: client}
}
// Completions sends a chat completion request and returns the parsed response.
@@ -513,343 +602,108 @@ func (c *AnthropicClient) CompletionsWithCtx(ctx context.Context, req ChatReques
model = c.cfg.Model
}
- var result *ChatResponse
- err := c.withRetryCtx(ctx, func() error {
- resp, err := c.doRequestCtx(ctx, model, req)
- if err != nil {
- return err
- }
- result = resp
- return nil
- })
- return result, err
-}
+ params := c.buildParams(model, req)
+ reqOpts := c.buildExtraBodyOptions()
-// StreamCompletion initiates a streaming chat completion using SSE. The callback
-// is invoked per chunk with raw JSON data stripped of the "data: " prefix.
-func (c *AnthropicClient) StreamCompletion(req ChatRequest, cb func(chunk []byte) error) error {
- req.Stream = true
+ // Capture raw HTTP response for headers
+ var httpResp *http.Response
+ reqOpts = append(reqOpts, anthropicoption.WithResponseInto(&httpResp))
- model := req.Model
- if model == "" {
- model = c.cfg.Model
+ message, err := c.client.Messages.New(ctx, params, reqOpts...)
+ if err != nil {
+ return nil, fmt.Errorf("API error: %w", err)
}
- return c.withRetry(func() error {
- body := c.buildRequestBody(model, req)
- body.Stream = true
-
- payload, err := mergeExtraBody(body, c.cfg.ExtraBody)
- if err != nil {
- return fmt.Errorf("marshal request body: %w", err)
- }
-
- httpReq, err := http.NewRequest(http.MethodPost, c.cfg.URL, bytes.NewReader(payload))
- if err != nil {
- return fmt.Errorf("create request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("x-api-key", c.cfg.APIKey)
- httpReq.Header.Set("anthropic-version", anthropicVersion)
- httpReq.Header.Set("User-Agent", userAgent("claude"))
-
- resp, err := c.client.Do(httpReq)
- if err != nil {
- return fmt.Errorf("request failed: %w", err)
- }
- defer resp.Body.Close()
-
- if isRetryableStatus(resp.StatusCode) {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes))
- }
- if resp.StatusCode >= 400 {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("API error %d: %s (non-retryable)", resp.StatusCode, string(bodyBytes))
- }
-
- scanner := bufio.NewScanner(resp.Body)
- var eventType string
-
- for scanner.Scan() {
- line := scanner.Text()
-
- // Capture event type line: "event: message_delta"
- if strings.HasPrefix(line, "event: ") {
- eventType = strings.TrimPrefix(line, "event: ")
- continue
- }
-
- // Skip empty lines and non-data lines
- if !strings.HasPrefix(line, "data: ") {
- continue
- }
-
- data := strings.TrimPrefix(line, "data: ")
- if data == "" {
- continue
- }
-
- // message_stop signals end of stream
- if eventType == "message_stop" {
- break
- }
-
- if err := cb([]byte(data)); err != nil {
- return err
- }
- }
- return scanner.Err()
- })
-}
-
-// anthropicRequest is the request body for Anthropic Messages API.
-type anthropicRequest struct {
- Model string `json:"model"`
- MaxTokens int `json:"max_tokens"`
- System string `json:"system,omitempty"`
- Messages []anthroMessage `json:"messages"`
- Tools []anthroTool `json:"tools,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
-}
-
-type anthroMessage struct {
- Role string `json:"role"`
- Content any `json:"content"` // string or []interface{}
-}
-
-// anthropicToolUseBlock represents a tool_use content block in Anthropic's Messages API.
-type anthropicToolUseBlock struct {
- Type string `json:"type"` // "tool_use"
- ID string `json:"id"` // tool use ID
- Name string `json:"name"` // function name
- Input map[string]any `json:"input"` // function arguments (parsed as object)
-}
-
-type anthroTool struct {
- Name string `json:"name"`
- Description string `json:"description"`
- InputSchema map[string]any `json:"input_schema"`
+ return c.convertResponse(message, httpResp), nil
}
-// doRequestCtx builds and sends an Anthropic Messages API request.
-func (c *AnthropicClient) doRequestCtx(ctx context.Context, model string, req ChatRequest) (*ChatResponse, error) {
+// StreamCompletion initiates a streaming chat completion using SSE.
+func (c *AnthropicClient) StreamCompletion(ctx context.Context, req ChatRequest, cb func(chunk []byte) error) error {
+ model := req.Model
if model == "" {
model = c.cfg.Model
}
- body := c.buildRequestBody(model, req)
- payload, err := mergeExtraBody(body, c.cfg.ExtraBody)
- if err != nil {
- return nil, fmt.Errorf("marshal request body: %w", err)
- }
+ params := c.buildParams(model, req)
+ reqOpts := c.buildExtraBodyOptions()
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.URL, bytes.NewReader(payload))
- if err != nil {
- return nil, fmt.Errorf("create request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("x-api-key", c.cfg.APIKey)
- httpReq.Header.Set("anthropic-version", anthropicVersion)
- httpReq.Header.Set("User-Agent", userAgent("claude"))
+ stream := c.client.Messages.NewStreaming(ctx, params, reqOpts...)
+ defer stream.Close()
- resp, err := c.client.Do(httpReq)
- if err != nil {
- return nil, fmt.Errorf("request failed: %w", err)
- }
- defer resp.Body.Close()
-
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("read response body: %w", err)
- }
-
- if resp.StatusCode >= 400 {
- detail := extractErrorMessage(bodyBytes)
- return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, detail)
- }
-
- chatResp, err := c.parseResponse(bodyBytes, resp.Header)
- if err != nil {
- return nil, fmt.Errorf("decode response: %w", err)
+ for stream.Next() {
+ event := stream.Current()
+ // Use RawJSON to preserve the original API response format for downstream consumers.
+ raw := event.RawJSON()
+ if raw == "" {
+ continue
+ }
+ if err := cb([]byte(raw)); err != nil {
+ return err
+ }
}
- return chatResp, nil
+ return stream.Err()
}
-// buildRequestBody converts the shared ChatRequest into Anthropic format.
-func (c *AnthropicClient) buildRequestBody(model string, req ChatRequest) anthropicRequest {
- messages := make([]anthroMessage, 0, len(req.Messages))
- var systemMsg string
+// buildParams converts internal ChatRequest to Anthropic SDK params.
+func (c *AnthropicClient) buildParams(model string, req ChatRequest) anthropic.MessageNewParams {
+ system, messages := convertMessagesToAnthropic(req.Messages)
- var pendingToolResults []Message // collect consecutive tool messages
+ maxTokens := int64(req.MaxTokens)
+ if maxTokens <= 0 {
+ maxTokens = 8192
+ }
- flushToolResults := func() {
- if len(pendingToolResults) == 0 {
- return
- }
- // Merge all pending tool results into a single user message
- var blocks []interface{}
- for _, tr := range pendingToolResults {
- blocks = append(blocks, ContentBlock{
- Type: "tool_result",
- ToolUseID: tr.ToolCallID,
- Content: []ContentBlock{{
- Type: "text",
- Text: fmt.Sprintf("%v", tr.Content),
- }},
- })
- }
- messages = append(messages, anthroMessage{Role: "user", Content: blocks})
- pendingToolResults = nil
+ params := anthropic.MessageNewParams{
+ Model: model,
+ MaxTokens: maxTokens,
+ Messages: messages,
}
- for _, msg := range req.Messages {
- switch msg.Role {
- case "system":
- if s, ok := msg.Content.(string); ok {
- systemMsg = s
- }
- flushToolResults()
- case "tool":
- pendingToolResults = append(pendingToolResults, msg)
- case "assistant":
- flushToolResults()
- // Build Anthropic content blocks from text + tool calls
- var blocks []interface{}
- if s, ok := msg.Content.(string); ok && s != "" {
- blocks = append(blocks, ContentBlock{Type: "text", Text: s})
- }
- for _, tc := range msg.ToolCalls {
- argsMap := map[string]any{}
- if tc.Function.Arguments != "" {
- if err := json.Unmarshal([]byte(tc.Function.Arguments), &argsMap); err != nil {
- fmt.Fprintf(stdout.Writer(), "[llm] WARNING: failed to parse tool call arguments JSON for %q: %v\n", tc.ID, err)
- }
- }
- blocks = append(blocks, anthropicToolUseBlock{
- Type: "tool_use",
- ID: tc.ID,
- Name: tc.Function.Name,
- Input: argsMap,
- })
- }
- if len(blocks) > 0 {
- messages = append(messages, anthroMessage{Role: "assistant", Content: blocks})
- } else {
- s, _ := msg.Content.(string)
- messages = append(messages, anthroMessage{Role: "assistant", Content: s})
- }
- default:
- // user or other roles: flush tool results first
- flushToolResults()
- content := msg.Content
- if blkArr, ok := content.([]ContentBlock); ok {
- converted := make([]ContentBlock, len(blkArr))
- for i, b := range blkArr {
- converted[i] = ContentBlock{
- Type: b.Type,
- Text: b.Text,
- ToolUseID: b.ToolUseID,
- Content: b.Content,
- }
- }
- content = converted
- }
- messages = append(messages, anthroMessage{Role: msg.Role, Content: content})
+ if system != "" {
+ params.System = []anthropic.TextBlockParam{
+ {Text: system},
}
}
- flushToolResults() // flush any remaining tool results at the end
- tools := make([]anthroTool, 0, len(req.Tools))
- for _, t := range req.Tools {
- tools = append(tools, anthroTool{
- Name: t.Function.Name,
- Description: t.Function.Description,
- InputSchema: t.Function.Parameters,
- })
+ if len(req.Tools) > 0 {
+ params.Tools = convertToolsToAnthropic(req.Tools)
}
- maxTokens := req.MaxTokens
- if maxTokens <= 0 {
- maxTokens = 8192 // Anthropic default
+ if req.Temperature != nil {
+ params.Temperature = anthropic.Float(*req.Temperature)
}
- return anthropicRequest{
- Model: model,
- MaxTokens: maxTokens,
- System: systemMsg,
- Messages: messages,
- Tools: tools,
- Stream: false,
- Temperature: req.Temperature,
- }
+ return params
}
-func mergeExtraBody(base any, extraBody map[string]any) ([]byte, error) {
- if len(extraBody) == 0 {
- return json.Marshal(base)
- }
- b, err := json.Marshal(base)
- if err != nil {
- return nil, err
- }
- var m map[string]any
- if err := json.Unmarshal(b, &m); err != nil {
- return nil, err
+// buildExtraBodyOptions creates request options for ExtraBody fields.
+func (c *AnthropicClient) buildExtraBodyOptions() []anthropicoption.RequestOption {
+ if len(c.cfg.ExtraBody) == 0 {
+ return nil
}
- for k, v := range extraBody {
- m[k] = v
+ opts := make([]anthropicoption.RequestOption, 0, len(c.cfg.ExtraBody))
+ for k, v := range c.cfg.ExtraBody {
+ opts = append(opts, anthropicoption.WithJSONSet(k, v))
}
- return json.Marshal(m)
+ return opts
}
-// parseResponse converts Anthropic JSON response into ChatResponse.
-func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*ChatResponse, error) {
- type contentBlockResp struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- ID string `json:"id,omitempty"`
- Name string `json:"name,omitempty"`
- Input any `json:"input,omitempty"`
- }
-
- type anthropicUsageRaw struct {
- InputTokens int64 `json:"input_tokens"`
- OutputTokens int64 `json:"output_tokens"`
- CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
- CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
- }
-
- var resp struct {
- ID string `json:"id"`
- Model string `json:"model"`
- Type string `json:"type"`
- Role string `json:"role"`
- Content []contentBlockResp `json:"content"`
- Usage anthropicUsageRaw `json:"usage"`
- StopReason string `json:"stop_reason,omitempty"`
- }
-
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil, err
- }
-
- // Build the response message from content blocks.
+// convertResponse maps Anthropic SDK Message to our internal ChatResponse.
+func (c *AnthropicClient) convertResponse(msg *anthropic.Message, httpResp *http.Response) *ChatResponse {
var textParts []string
var toolCalls []ToolCall
- for _, block := range resp.Content {
- switch block.Type {
- case "text":
- textParts = append(textParts, block.Text)
- case "tool_use":
- argsJSON, _ := json.Marshal(block.Input)
+ for _, block := range msg.Content {
+ switch variant := block.AsAny().(type) {
+ case anthropic.TextBlock:
+ textParts = append(textParts, variant.Text)
+ case anthropic.ToolUseBlock:
+ argsJSON, _ := json.Marshal(variant.Input)
toolCalls = append(toolCalls, ToolCall{
- ID: block.ID,
+ ID: variant.ID,
Type: "function",
Function: FunctionCall{
- Name: block.Name,
+ Name: variant.Name,
Arguments: string(argsJSON),
},
})
@@ -862,28 +716,33 @@ func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*Chat
contentStr = &s
}
- finishReason := resp.StopReason
+ finishReason := string(msg.StopReason)
if finishReason == "" {
finishReason = "stop"
}
var usage *UsageInfo
- if u := resp.Usage; u.InputTokens > 0 || u.OutputTokens > 0 {
+ if msg.Usage.InputTokens > 0 || msg.Usage.OutputTokens > 0 || msg.Usage.CacheReadInputTokens > 0 || msg.Usage.CacheCreationInputTokens > 0 {
usage = &UsageInfo{
- PromptTokens: u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens,
- CompletionTokens: u.OutputTokens,
- CacheReadTokens: u.CacheReadInputTokens,
- CacheWriteTokens: u.CacheCreationInputTokens,
+ PromptTokens: msg.Usage.InputTokens,
+ CompletionTokens: msg.Usage.OutputTokens,
+ CacheReadTokens: msg.Usage.CacheReadInputTokens,
+ CacheWriteTokens: msg.Usage.CacheCreationInputTokens,
}
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.CacheReadTokens + usage.CacheWriteTokens
+ }
+
+ var headers http.Header
+ if httpResp != nil {
+ headers = httpResp.Header
}
return &ChatResponse{
- ID: resp.ID,
- Model: resp.Model,
+ ID: msg.ID,
+ Model: string(msg.Model),
Choices: []Choice{{
Message: ResponseMessage{
- Role: resp.Role,
+ Role: string(msg.Role),
Content: contentStr,
ToolCalls: toolCalls,
},
@@ -891,104 +750,119 @@ func (c *AnthropicClient) parseResponse(body []byte, headers http.Header) (*Chat
}},
Headers: headers,
Usage: usage,
- }, nil
+ }
}
-// --- Retry logic ---
-
-func retryWithCtx(ctx context.Context, fn func() error) error {
- var lastErr error
- for attempt := 0; attempt <= maxRetries; attempt++ {
- select {
- case <-ctx.Done():
- return fmt.Errorf("context cancelled: %w", ctx.Err())
- default:
- }
+// convertMessagesToAnthropic separates system message and converts remaining messages.
+func convertMessagesToAnthropic(messages []Message) (string, []anthropic.MessageParam) {
+ var systemMsg string
+ var result []anthropic.MessageParam
+ var pendingToolResults []Message
- lastErr = fn()
- if lastErr == nil {
- return nil
+ flushToolResults := func() {
+ if len(pendingToolResults) == 0 {
+ return
}
-
- if !isRetryable(lastErr) {
- return lastErr
+ var blocks []anthropic.ContentBlockParamUnion
+ for _, tr := range pendingToolResults {
+ blocks = append(blocks, anthropic.NewToolResultBlock(
+ tr.ToolCallID,
+ tr.ExtractText(),
+ false,
+ ))
}
+ result = append(result, anthropic.NewUserMessage(blocks...))
+ pendingToolResults = nil
+ }
- if attempt < maxRetries {
- sleepWithBackoff(attempt)
+ for _, msg := range messages {
+ switch msg.Role {
+ case "system":
+ if s, ok := msg.Content.(string); ok {
+ systemMsg = s
+ }
+ flushToolResults()
+ case "tool":
+ pendingToolResults = append(pendingToolResults, msg)
+ case "assistant":
+ flushToolResults()
+ var blocks []anthropic.ContentBlockParamUnion
+ text := msg.ExtractText()
+ if text != "" {
+ blocks = append(blocks, anthropic.NewTextBlock(text))
+ }
+ for _, tc := range msg.ToolCalls {
+ argsMap := map[string]any{}
+ if tc.Function.Arguments != "" {
+ if err := json.Unmarshal([]byte(tc.Function.Arguments), &argsMap); err != nil {
+ fmt.Fprintf(stdout.Writer(), "[llm] WARNING: failed to parse tool call arguments JSON for %q: %v\n", tc.ID, err)
+ }
+ }
+ inputJSON, _ := json.Marshal(argsMap)
+ blocks = append(blocks, anthropic.ContentBlockParamUnion{
+ OfToolUse: &anthropic.ToolUseBlockParam{
+ ID: tc.ID,
+ Name: tc.Function.Name,
+ Input: json.RawMessage(inputJSON),
+ },
+ })
+ }
+ if len(blocks) > 0 {
+ result = append(result, anthropic.NewAssistantMessage(blocks...))
+ }
+ default:
+ // user or other roles
+ flushToolResults()
+ result = append(result, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.ExtractText())))
}
}
- return fmt.Errorf("request failed after %d retries: %w", maxRetries, lastErr)
-}
-
-func (c *OpenAIClient) withRetry(fn func() error) error {
- return retryWithCtx(context.Background(), fn)
-}
+ flushToolResults()
-func (c *OpenAIClient) withRetryCtx(ctx context.Context, fn func() error) error {
- return retryWithCtx(ctx, fn)
+ return systemMsg, result
}
-func (c *AnthropicClient) withRetry(fn func() error) error {
- return retryWithCtx(context.Background(), fn)
-}
-
-func (c *AnthropicClient) withRetryCtx(ctx context.Context, fn func() error) error {
- return retryWithCtx(ctx, fn)
-}
-
-// isRetryable determines whether an error is transient and worth retrying.
-func isRetryable(err error) bool {
- msg := err.Error()
- // 429 (rate limit) and 5xx server errors are retryable.
- if strings.Contains(msg, "API error 429:") {
- return true
- }
- for code := 500; code <= 599; code++ {
- if strings.Contains(msg, fmt.Sprintf("API error %d:", code)) {
- return true
+// convertToolsToAnthropic maps internal ToolDef slice to Anthropic SDK tool params.
+func convertToolsToAnthropic(tools []ToolDef) []anthropic.ToolUnionParam {
+ result := make([]anthropic.ToolUnionParam, 0, len(tools))
+ for _, t := range tools {
+ schema := anthropic.ToolInputSchemaParam{
+ Type: "object",
+ Properties: t.Function.Parameters["properties"],
}
+ // Preserve required field constraints from the original schema.
+ if req, ok := t.Function.Parameters["required"]; ok {
+ if reqSlice, ok := req.([]any); ok {
+ required := make([]string, 0, len(reqSlice))
+ for _, r := range reqSlice {
+ if s, ok := r.(string); ok {
+ required = append(required, s)
+ }
+ }
+ schema.Required = required
+ }
+ }
+ result = append(result, anthropic.ToolUnionParam{
+ OfTool: &anthropic.ToolParam{
+ Name: t.Function.Name,
+ Description: anthropic.String(t.Function.Description),
+ InputSchema: schema,
+ },
+ })
}
- // Network-level errors (timeout, connection refused, DNS failure, etc.) are retryable.
- if strings.Contains(msg, "request failed:") ||
- strings.Contains(msg, "connection refused") ||
- strings.Contains(msg, "no such host") ||
- strings.Contains(msg, "i/o timeout") ||
- strings.Contains(msg, "EOF") {
- return true
- }
- return false
+ return result
}
-// isRetryableStatus returns true for HTTP status codes that should trigger a retry.
-func isRetryableStatus(status int) bool {
- return status == 429 || (status >= 500 && status <= 599)
+// normalizeAnthropicBaseURL strips the /v1/messages suffix since the SDK appends it.
+func normalizeAnthropicBaseURL(rawURL string) string {
+ u := strings.TrimRight(rawURL, "/")
+ u = strings.TrimSuffix(u, "/v1/messages")
+ return u
}
-// sleepWithBackoff sleeps for baseDelay * 2^attempt + jitter, capped at 60s.
-// Jitter spreads retries randomly within ±50% of the computed delay.
-func sleepWithBackoff(attempt int) {
- const (
- baseDelay = 1 * time.Second
- maxDelay = 60 * time.Second
- )
-
- delay := baseDelay << uint(min(attempt, 6)) // 1s, 2s, 4s, 8s, 16s, 32s, 64s→capped
- if delay > maxDelay {
- delay = maxDelay
- }
-
- // Add random jitter: [delay*0.5, delay*1.5]
- jitter := time.Duration(rand.Int63n(int64(delay))) - delay/2
- delay += jitter
-
- fmt.Fprintf(stdout.Writer(), "[llm] Retrying in %v (attempt info)... \n", delay)
- time.Sleep(delay)
-}
+// --- Utility functions ---
// stripThinkTags removes reasoning wrapper tags from content.
func stripThinkTags(s string) string {
- // Construct tag strings from individual bytes.
openBytes := []byte{0x3c, 't', 'h', 'i', 'n', 'k', 0x3e}
closeBytes := []byte{0x3c, 0x2f, 't', 'h', 'i', 'n', 'k', 0x3e}
s = strings.ReplaceAll(s, string(openBytes), "")
@@ -1025,7 +899,6 @@ func extractErrorMessage(body []byte) string {
return ae.Error.Message
}
- // Truncate raw body to avoid excessively noisy errors.
bodyText := string(body)
if len(bodyText) > 512 {
bodyText = bodyText[:512] + "... (truncated)"
diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go
index e3adc81..7ea9363 100644
--- a/internal/llm/client_test.go
+++ b/internal/llm/client_test.go
@@ -1,10 +1,14 @@
package llm
import (
+ "encoding/json"
+ "strings"
"testing"
)
-func TestNewOpenAIClient_URLNormalization(t *testing.T) {
+// TestNormalizeOpenAIBaseURL verifies URL normalization for the OpenAI SDK.
+// The SDK appends /chat/completions automatically, so we strip that suffix.
+func TestNormalizeOpenAIBaseURL(t *testing.T) {
tests := []struct {
name string
inputURL string
@@ -13,41 +17,43 @@ func TestNewOpenAIClient_URLNormalization(t *testing.T) {
{
name: "base URL without trailing slash",
inputURL: "https://api.example.com/v1",
- wantURL: "https://api.example.com/v1/chat/completions",
+ wantURL: "https://api.example.com/v1",
},
{
name: "base URL with trailing slash",
inputURL: "https://api.example.com/v1/",
- wantURL: "https://api.example.com/v1/chat/completions",
+ wantURL: "https://api.example.com/v1",
},
{
- name: "full URL already has chat/completions",
+ name: "full URL with chat/completions suffix stripped",
inputURL: "https://api.example.com/v1/chat/completions",
- wantURL: "https://api.example.com/v1/chat/completions",
+ wantURL: "https://api.example.com/v1",
},
{
- name: "full URL with trailing slash",
+ name: "full URL with trailing slash stripped",
inputURL: "https://api.example.com/v1/chat/completions/",
- wantURL: "https://api.example.com/v1/chat/completions/",
+ wantURL: "https://api.example.com/v1",
},
{
- name: "bare host",
+ name: "bare host unchanged",
inputURL: "https://api.example.com",
- wantURL: "https://api.example.com/chat/completions",
+ wantURL: "https://api.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- client := NewOpenAIClient(ClientConfig{URL: tt.inputURL})
- if client.cfg.URL != tt.wantURL {
- t.Errorf("got URL %q, want %q", client.cfg.URL, tt.wantURL)
+ got := normalizeOpenAIBaseURL(tt.inputURL)
+ if got != tt.wantURL {
+ t.Errorf("normalizeOpenAIBaseURL(%q) = %q, want %q", tt.inputURL, got, tt.wantURL)
}
})
}
}
-func TestNewAnthropicClient_URLNormalization(t *testing.T) {
+// TestNormalizeAnthropicBaseURL verifies URL normalization for the Anthropic SDK.
+// The SDK appends /v1/messages automatically, so we strip that suffix.
+func TestNormalizeAnthropicBaseURL(t *testing.T) {
tests := []struct {
name string
inputURL string
@@ -56,36 +62,508 @@ func TestNewAnthropicClient_URLNormalization(t *testing.T) {
{
name: "bare host",
inputURL: "https://api.anthropic.com",
- wantURL: "https://api.anthropic.com/v1/messages",
+ wantURL: "https://api.anthropic.com",
},
{
name: "bare host with trailing slash",
inputURL: "https://api.anthropic.com/",
- wantURL: "https://api.anthropic.com/v1/messages",
+ wantURL: "https://api.anthropic.com",
},
{
- name: "full URL already has /v1/messages",
+ name: "full URL with /v1/messages suffix stripped",
inputURL: "https://api.anthropic.com/v1/messages",
- wantURL: "https://api.anthropic.com/v1/messages",
+ wantURL: "https://api.anthropic.com",
},
{
- name: "full URL with trailing slash",
+ name: "full URL with trailing slash stripped",
inputURL: "https://api.anthropic.com/v1/messages/",
- wantURL: "https://api.anthropic.com/v1/messages/",
+ wantURL: "https://api.anthropic.com",
},
{
- name: "custom proxy base URL",
+ name: "custom proxy base URL unchanged",
inputURL: "https://proxy.example.com/anthropic",
- wantURL: "https://proxy.example.com/anthropic/v1/messages",
+ wantURL: "https://proxy.example.com/anthropic",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := normalizeAnthropicBaseURL(tt.inputURL)
+ if got != tt.wantURL {
+ t.Errorf("normalizeAnthropicBaseURL(%q) = %q, want %q", tt.inputURL, got, tt.wantURL)
+ }
+ })
+ }
+}
+
+// TestNewOpenAIClient_CreatesSuccessfully verifies that the client is created without panics.
+func TestNewOpenAIClient_CreatesSuccessfully(t *testing.T) {
+ client := NewOpenAIClient(ClientConfig{
+ URL: "https://api.example.com/v1",
+ APIKey: "test-key",
+ Model: "gpt-4",
+ })
+ if client == nil {
+ t.Fatal("expected non-nil client")
+ }
+}
+
+// TestNewAnthropicClient_CreatesSuccessfully verifies that the client is created without panics.
+func TestNewAnthropicClient_CreatesSuccessfully(t *testing.T) {
+ client := NewAnthropicClient(ClientConfig{
+ URL: "https://api.anthropic.com",
+ APIKey: "test-key",
+ Model: "claude-sonnet-4-20250514",
+ })
+ if client == nil {
+ t.Fatal("expected non-nil client")
+ }
+}
+
+// TestBuildParams_MaxTokensBranching verifies that buildParams uses MaxTokens or
+// MaxCompletionTokens based on the UseMaxCompletionTokens config flag.
+func TestBuildParams_MaxTokensBranching(t *testing.T) {
+ tests := []struct {
+ name string
+ useMaxCompletionTokens bool
+ maxTokens int
+ wantMaxTokensSet bool
+ wantMaxCompletionSet bool
+ }{
+ {
+ name: "default uses max_tokens",
+ useMaxCompletionTokens: false,
+ maxTokens: 4096,
+ wantMaxTokensSet: true,
+ wantMaxCompletionSet: false,
+ },
+ {
+ name: "enabled uses max_completion_tokens",
+ useMaxCompletionTokens: true,
+ maxTokens: 4096,
+ wantMaxTokensSet: false,
+ wantMaxCompletionSet: true,
+ },
+ {
+ name: "zero max_tokens sets neither",
+ useMaxCompletionTokens: true,
+ maxTokens: 0,
+ wantMaxTokensSet: false,
+ wantMaxCompletionSet: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client := NewOpenAIClient(ClientConfig{
+ URL: "https://api.example.com/v1",
+ APIKey: "test-key",
+ Model: "gpt-4",
+ UseMaxCompletionTokens: tt.useMaxCompletionTokens,
+ })
+
+ params := client.buildParams("gpt-4", ChatRequest{
+ MaxTokens: tt.maxTokens,
+ })
+
+ maxTokensSet := params.MaxTokens.Valid()
+ maxCompletionSet := params.MaxCompletionTokens.Valid()
+
+ if maxTokensSet != tt.wantMaxTokensSet {
+ t.Errorf("MaxTokens present = %v, want %v", maxTokensSet, tt.wantMaxTokensSet)
+ }
+ if maxCompletionSet != tt.wantMaxCompletionSet {
+ t.Errorf("MaxCompletionTokens present = %v, want %v", maxCompletionSet, tt.wantMaxCompletionSet)
+ }
+ })
+ }
+}
+
+// --- Tests for convertMessagesToOpenAI ---
+
+func TestConvertMessagesToOpenAI_BasicRoles(t *testing.T) {
+ messages := []Message{
+ NewTextMessage("system", "You are a helpful assistant"),
+ NewTextMessage("user", "Hello"),
+ NewTextMessage("assistant", "Hi there"),
+ }
+
+ result := convertMessagesToOpenAI(messages)
+ if len(result) != 3 {
+ t.Fatalf("expected 3 messages, got %d", len(result))
+ }
+
+ // Verify roles by checking which union variant is set
+ if result[0].OfSystem == nil {
+ t.Error("expected system message at index 0")
+ }
+ if result[1].OfUser == nil {
+ t.Error("expected user message at index 1")
+ }
+ if result[2].OfAssistant == nil {
+ t.Error("expected assistant message at index 2")
+ }
+}
+
+func TestConvertMessagesToOpenAI_ToolMessage(t *testing.T) {
+ messages := []Message{
+ NewToolResultMessage("call-123", "tool result content"),
+ }
+
+ result := convertMessagesToOpenAI(messages)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+ if result[0].OfTool == nil {
+ t.Error("expected tool message")
+ }
+ if result[0].OfTool.ToolCallID != "call-123" {
+ t.Errorf("expected tool_call_id %q, got %q", "call-123", result[0].OfTool.ToolCallID)
+ }
+}
+
+func TestConvertMessagesToOpenAI_AssistantWithToolCalls(t *testing.T) {
+ messages := []Message{
+ NewToolCallMessage("thinking...", []ToolCall{
+ {
+ ID: "call-1",
+ Type: "function",
+ Function: FunctionCall{
+ Name: "get_weather",
+ Arguments: `{"city":"Tokyo"}`,
+ },
+ },
+ }),
+ }
+
+ result := convertMessagesToOpenAI(messages)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+ if result[0].OfAssistant == nil {
+ t.Fatal("expected assistant message")
+ }
+ if len(result[0].OfAssistant.ToolCalls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(result[0].OfAssistant.ToolCalls))
+ }
+ tc := result[0].OfAssistant.ToolCalls[0]
+ if tc.OfFunction == nil {
+ t.Fatal("expected function tool call")
+ }
+ if tc.OfFunction.ID != "call-1" {
+ t.Errorf("expected ID %q, got %q", "call-1", tc.OfFunction.ID)
+ }
+ if tc.OfFunction.Function.Name != "get_weather" {
+ t.Errorf("expected function name %q, got %q", "get_weather", tc.OfFunction.Function.Name)
+ }
+}
+
+func TestConvertMessagesToOpenAI_EmptyMessages(t *testing.T) {
+ result := convertMessagesToOpenAI(nil)
+ if len(result) != 0 {
+ t.Fatalf("expected 0 messages, got %d", len(result))
+ }
+}
+
+// --- Tests for convertMessagesToAnthropic ---
+
+func TestConvertMessagesToAnthropic_SystemExtraction(t *testing.T) {
+ messages := []Message{
+ NewTextMessage("system", "Be concise"),
+ NewTextMessage("user", "Hello"),
+ }
+
+ system, result := convertMessagesToAnthropic(messages)
+ if system != "Be concise" {
+ t.Errorf("expected system %q, got %q", "Be concise", system)
+ }
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message (user only), got %d", len(result))
+ }
+}
+
+func TestConvertMessagesToAnthropic_NoSystem(t *testing.T) {
+ messages := []Message{
+ NewTextMessage("user", "Hello"),
+ NewTextMessage("assistant", "Hi"),
+ }
+
+ system, result := convertMessagesToAnthropic(messages)
+ if system != "" {
+ t.Errorf("expected empty system, got %q", system)
+ }
+ if len(result) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(result))
+ }
+}
+
+func TestConvertMessagesToAnthropic_ToolResultBatching(t *testing.T) {
+ // Multiple consecutive tool results should be batched into a single user message
+ messages := []Message{
+ NewTextMessage("user", "Do tasks"),
+ NewToolCallMessage("", []ToolCall{
+ {ID: "call-1", Type: "function", Function: FunctionCall{Name: "fn1", Arguments: "{}"}},
+ {ID: "call-2", Type: "function", Function: FunctionCall{Name: "fn2", Arguments: "{}"}},
+ }),
+ NewToolResultMessage("call-1", "result 1"),
+ NewToolResultMessage("call-2", "result 2"),
+ NewTextMessage("assistant", "Done"),
+ }
+
+ system, result := convertMessagesToAnthropic(messages)
+ if system != "" {
+ t.Errorf("unexpected system: %q", system)
+ }
+ // Expected: user, assistant(tool_calls), user(batched tool results), assistant
+ if len(result) != 4 {
+ t.Fatalf("expected 4 messages, got %d", len(result))
+ }
+}
+
+func TestConvertMessagesToAnthropic_AssistantWithToolCalls(t *testing.T) {
+ messages := []Message{
+ NewToolCallMessage("thinking", []ToolCall{
+ {ID: "call-1", Type: "function", Function: FunctionCall{Name: "search", Arguments: `{"q":"test"}`}},
+ }),
+ }
+
+ _, result := convertMessagesToAnthropic(messages)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+ // The assistant message should have both text block and tool_use block
+ msg := result[0]
+ if msg.Role != "assistant" {
+ t.Errorf("expected role assistant, got %q", msg.Role)
+ }
+}
+
+// --- Tests for utility functions ---
+
+func TestStripThinkTags(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"no tags", "hello world", "hello world"},
+ {"with think tags", "reasoninganswer", "reasoninganswer"},
+ {"only open tag", "partial", "partial"},
+ {"only close tag", "partial", "partial"},
+ {"empty string", "", ""},
+ {"nested content", "step1resultstep2final", "step1resultstep2final"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := stripThinkTags(tt.input)
+ if got != tt.want {
+ t.Errorf("stripThinkTags(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractErrorMessage(t *testing.T) {
+ tests := []struct {
+ name string
+ body []byte
+ want string
+ }{
+ {
+ name: "OpenAI error format",
+ body: []byte(`{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}`),
+ want: "Rate limit exceeded",
+ },
+ {
+ name: "Anthropic error format",
+ body: []byte(`{"type":"error","error":{"message":"Invalid API key","type":"authentication_error"}}`),
+ want: "Invalid API key",
+ },
+ {
+ name: "empty body",
+ body: []byte{},
+ want: "(empty body)",
+ },
+ {
+ name: "nil body",
+ body: nil,
+ want: "(empty body)",
+ },
+ {
+ name: "unrecognized JSON",
+ body: []byte(`{"status":"error","detail":"something went wrong"}`),
+ want: `{"status":"error","detail":"something went wrong"}`,
+ },
+ {
+ name: "truncates long body",
+ body: []byte(strings.Repeat("x", 600)),
+ want: strings.Repeat("x", 512) + "... (truncated)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- client := NewAnthropicClient(ClientConfig{URL: tt.inputURL})
- if client.cfg.URL != tt.wantURL {
- t.Errorf("got URL %q, want %q", client.cfg.URL, tt.wantURL)
+ got := extractErrorMessage(tt.body)
+ if got != tt.want {
+ t.Errorf("extractErrorMessage() = %q, want %q", got, tt.want)
}
})
}
}
+
+func TestExtractText(t *testing.T) {
+ tests := []struct {
+ name string
+ message Message
+ want string
+ }{
+ {
+ name: "string content",
+ message: Message{Role: "user", Content: "hello"},
+ want: "hello",
+ },
+ {
+ name: "content block array",
+ message: Message{Role: "user", Content: []ContentBlock{{Type: "text", Text: "part1"}, {Type: "text", Text: "part2"}}},
+ want: "part1part2",
+ },
+ {
+ name: "nil content",
+ message: Message{Role: "user", Content: nil},
+ want: "",
+ },
+ {
+ name: "empty string",
+ message: Message{Role: "user", Content: ""},
+ want: "",
+ },
+ {
+ name: "nested content blocks",
+ message: Message{Role: "user", Content: []ContentBlock{
+ {Type: "tool_result", ToolUseID: "id1", Content: []ContentBlock{{Type: "text", Text: "nested"}}},
+ }},
+ want: "nested",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.message.ExtractText()
+ if got != tt.want {
+ t.Errorf("ExtractText() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEncodingForModel(t *testing.T) {
+ tests := []struct {
+ model string
+ want string
+ }{
+ {"o1-preview", "o200k_base"},
+ {"o3-mini", "o200k_base"},
+ {"o4-mini", "o200k_base"},
+ {"gpt-4", "cl100k_base"},
+ {"gpt-4o", "cl100k_base"},
+ {"claude-opus-4-6", "cl100k_base"},
+ {"", "cl100k_base"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.model, func(t *testing.T) {
+ got := encodingForModel(tt.model)
+ if got != tt.want {
+ t.Errorf("encodingForModel(%q) = %q, want %q", tt.model, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestChatResponseContent(t *testing.T) {
+ strPtr := func(s string) *string { return &s }
+
+ tests := []struct {
+ name string
+ resp ChatResponse
+ want string
+ }{
+ {
+ name: "normal content",
+ resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr("hello")}}}},
+ want: "hello",
+ },
+ {
+ name: "content with think tags",
+ resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr("reasoninganswer")}}}},
+ want: "reasoninganswer",
+ },
+ {
+ name: "fallback to reasoning content",
+ resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: strPtr(""), ReasoningContent: "reasoning"}}}},
+ want: "reasoning",
+ },
+ {
+ name: "nil content fallback to reasoning",
+ resp: ChatResponse{Choices: []Choice{{Message: ResponseMessage{Content: nil, ReasoningContent: "fallback"}}}},
+ want: "fallback",
+ },
+ {
+ name: "empty choices",
+ resp: ChatResponse{Choices: []Choice{}},
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.resp.Content()
+ if got != tt.want {
+ t.Errorf("Content() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestNewTextMessage(t *testing.T) {
+ msg := NewTextMessage("user", "hello")
+ if msg.Role != "user" {
+ t.Errorf("Role = %q, want %q", msg.Role, "user")
+ }
+ if msg.Content != "hello" {
+ t.Errorf("Content = %v, want %q", msg.Content, "hello")
+ }
+}
+
+func TestNewToolCallMessage(t *testing.T) {
+ calls := []ToolCall{{ID: "1", Type: "function", Function: FunctionCall{Name: "fn", Arguments: "{}"}}}
+ msg := NewToolCallMessage("text", calls)
+ if msg.Role != "assistant" {
+ t.Errorf("Role = %q, want %q", msg.Role, "assistant")
+ }
+ if len(msg.ToolCalls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(msg.ToolCalls))
+ }
+ // Verify it's a copy, not the same slice
+ calls[0].ID = "modified"
+ if msg.ToolCalls[0].ID == "modified" {
+ t.Error("ToolCalls should be a copy, not a reference")
+ }
+}
+
+func TestNewToolResultMessage(t *testing.T) {
+ msg := NewToolResultMessage("call-1", "result")
+ if msg.Role != "tool" {
+ t.Errorf("Role = %q, want %q", msg.Role, "tool")
+ }
+ if msg.ToolCallID != "call-1" {
+ t.Errorf("ToolCallID = %q, want %q", msg.ToolCallID, "call-1")
+ }
+ if msg.Content != "result" {
+ t.Errorf("Content = %v, want %q", msg.Content, "result")
+ }
+}
+
+// Ensure json import is used (for potential future tests using json assertions)
+var _ = json.Marshal
diff --git a/internal/llm/resolver.go b/internal/llm/resolver.go
index 442ff81..edbaf20 100644
--- a/internal/llm/resolver.go
+++ b/internal/llm/resolver.go
@@ -11,20 +11,22 @@ import (
// ResolvedEndpoint holds the resolved LLM endpoint configuration.
type ResolvedEndpoint struct {
- URL string
- Token string
- Model string
- Protocol string // "anthropic" or "openai"
- Source string // human-readable config source label
- ExtraBody map[string]any // vendor-specific request body fields
+ URL string
+ Token string
+ Model string
+ Protocol string // "anthropic" or "openai"
+ Source string // human-readable config source label
+ ExtraBody map[string]any // vendor-specific request body fields
+ UseMaxCompletionTokens bool // use max_completion_tokens instead of max_tokens
}
// Environment variable names for OCR-specific configuration.
const (
- envOCRLLMURL = "OCR_LLM_URL"
- envOCRLLMToken = "OCR_LLM_TOKEN"
- envOCRLLMModel = "OCR_LLM_MODEL"
- envOCRUseAnthropic = "OCR_USE_ANTHROPIC"
+ envOCRLLMURL = "OCR_LLM_URL"
+ envOCRLLMToken = "OCR_LLM_TOKEN"
+ envOCRLLMModel = "OCR_LLM_MODEL"
+ envOCRUseAnthropic = "OCR_USE_ANTHROPIC"
+ envOCRUseMaxCompletionTokens = "OCR_USE_MAX_COMPLETION_TOKENS"
)
// Environment variable names from Claude Code configuration.
@@ -83,16 +85,23 @@ func tryOCREnv() (ResolvedEndpoint, bool, error) {
protocol = "openai"
}
- return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, Source: "OCR environment"}, true, nil
+ useMaxCompletionTokens := false
+ if v := os.Getenv(envOCRUseMaxCompletionTokens); v != "" {
+ lower := strings.ToLower(v)
+ useMaxCompletionTokens = lower == "true" || lower == "1" || lower == "yes"
+ }
+
+ return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, Source: "OCR environment", UseMaxCompletionTokens: useMaxCompletionTokens}, true, nil
}
// llmFileConfig represents the llm section in config.json.
type llmFileConfig struct {
- URL string `json:"url,omitempty"`
- AuthToken string `json:"auth_token,omitempty"`
- Model string `json:"model,omitempty"`
- UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false
- ExtraBody map[string]any `json:"extra_body,omitempty"`
+ URL string `json:"url,omitempty"`
+ AuthToken string `json:"auth_token,omitempty"`
+ Model string `json:"model,omitempty"`
+ UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false
+ UseMaxCompletionTokens *bool `json:"use_max_completion_tokens,omitempty"` // pointer to distinguish unset from false
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
}
type configFile struct {
@@ -128,7 +137,12 @@ func tryOCRConfig(path string) (ResolvedEndpoint, bool, error) {
protocol = "openai"
}
- return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: cfg.Llm.Model, Protocol: protocol, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody}, true, nil
+ useMaxCompletionTokens := false
+ if cfg.Llm.UseMaxCompletionTokens != nil {
+ useMaxCompletionTokens = *cfg.Llm.UseMaxCompletionTokens
+ }
+
+ return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: cfg.Llm.Model, Protocol: protocol, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, UseMaxCompletionTokens: useMaxCompletionTokens}, true, nil
}
// tryCCEnv reads Claude Code environment variables.
diff --git a/internal/llm/resolver_test.go b/internal/llm/resolver_test.go
index 61ca101..d99365a 100644
--- a/internal/llm/resolver_test.go
+++ b/internal/llm/resolver_test.go
@@ -115,3 +115,88 @@ func TestResolveEndpoint_ConfigFileStripsModelSuffix(t *testing.T) {
t.Errorf("expected source %q, got %q", "OCR config file", ep.Source)
}
}
+
+func TestTryOCREnv_UseMaxCompletionTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ want bool
+ }{
+ {"true", "true", true},
+ {"True_uppercase", "True", true},
+ {"1", "1", true},
+ {"yes", "yes", true},
+ {"false", "false", false},
+ {"0", "0", false},
+ {"empty_string", "", false},
+ {"invalid", "invalid", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Setenv("OCR_LLM_URL", "https://api.example.com/v1/chat/completions")
+ t.Setenv("OCR_LLM_TOKEN", "test-token")
+ t.Setenv("OCR_LLM_MODEL", "gpt-4")
+ t.Setenv("OCR_USE_ANTHROPIC", "false")
+ if tt.envValue != "" {
+ t.Setenv("OCR_USE_MAX_COMPLETION_TOKENS", tt.envValue)
+ } else {
+ t.Setenv("OCR_USE_MAX_COMPLETION_TOKENS", "")
+ }
+
+ ep, ok, err := tryOCREnv()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !ok {
+ t.Fatal("expected ok=true")
+ }
+ if ep.UseMaxCompletionTokens != tt.want {
+ t.Errorf("UseMaxCompletionTokens = %v, want %v", ep.UseMaxCompletionTokens, tt.want)
+ }
+ })
+ }
+}
+
+func TestTryOCRConfig_UseMaxCompletionTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ val *bool
+ want bool
+ }{
+ {"explicit_true", boolPtr(true), true},
+ {"explicit_false", boolPtr(false), false},
+ {"unset_nil", nil, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := configFile{
+ Llm: llmFileConfig{
+ URL: "https://api.example.com/v1/chat/completions",
+ AuthToken: "test-token",
+ Model: "gpt-4",
+ UseMaxCompletionTokens: tt.val,
+ },
+ }
+ data, _ := json.Marshal(cfg)
+ cfgPath := filepath.Join(t.TempDir(), "config.json")
+ os.WriteFile(cfgPath, data, 0644)
+
+ ep, ok, err := tryOCRConfig(cfgPath)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !ok {
+ t.Fatal("expected ok=true")
+ }
+ if ep.UseMaxCompletionTokens != tt.want {
+ t.Errorf("UseMaxCompletionTokens = %v, want %v", ep.UseMaxCompletionTokens, tt.want)
+ }
+ })
+ }
+}
+
+func boolPtr(b bool) *bool {
+ return &b
+}
diff --git a/internal/llm/usage_resolver.go b/internal/llm/usage_resolver.go
index 1345420..fd033b1 100644
--- a/internal/llm/usage_resolver.go
+++ b/internal/llm/usage_resolver.go
@@ -1,10 +1,5 @@
package llm
-import (
- "encoding/json"
- "strings"
-)
-
// UsageInfo holds token usage extracted from an LLM API response.
type UsageInfo struct {
TotalTokens int64 `json:"total_tokens"`
@@ -13,101 +8,3 @@ type UsageInfo struct {
CacheReadTokens int64 `json:"cache_read_tokens,omitempty"`
CacheWriteTokens int64 `json:"cache_write_tokens,omitempty"`
}
-
-var promptTokensPaths = []string{
- "usage.prompt_tokens", // OpenAI standard
- "prompt_tokens", // flat at root
- "data.usage.prompt_tokens", // wrapped in data layer
-}
-
-var completionTokensPaths = []string{
- "usage.completion_tokens", // OpenAI standard
- "completion_tokens", // flat at root
- "data.usage.completion_tokens", // wrapped in data layer
-}
-
-var cacheReadTokensPaths = []string{
- "usage.cache_read_input_tokens", // Anthropic
- "cache_read_input_tokens", // flat at root
- "usage.prompt_tokens_details.cache_tokens_hit", // some providers
- "usage.prompt_tokens_details.cache_tokens", // some providers
-}
-
-var cacheWriteTokensPaths = []string{
- "usage.cache_creation_input_tokens", // Anthropic / proxy
- "cache_creation_input_tokens", // flat at root
-}
-
-// totalTokensPaths is an ordered list of JSON paths to try when extracting
-// total token count from a response body. Paths are dot-separated keys that
-// navigate through nested map[string]any objects. The first match wins.
-var totalTokensPaths = []string{
- "usage.total_tokens", // OpenAI standard
- "total_tokens", // flat at root
- "data.usage.total_tokens", // wrapped in data layer (some proxy APIs)
-}
-
-// resolveUsage parses raw JSON bytes into a map and extracts token usage
-// by probing configured paths sequentially. Returns nil if no total_tokens found.
-func resolveUsage(raw []byte) *UsageInfo {
- var rawBody map[string]any
- if err := json.Unmarshal(raw, &rawBody); err != nil {
- return nil
- }
-
- total, hasAny := probePath(rawBody, totalTokensPaths)
- prompt, _ := probePath(rawBody, promptTokensPaths)
- completion, _ := probePath(rawBody, completionTokensPaths)
- cacheRead, _ := probePath(rawBody, cacheReadTokensPaths)
- cacheWrite, _ := probePath(rawBody, cacheWriteTokensPaths)
-
- if !hasAny && prompt == 0 && completion == 0 {
- return nil
- }
-
- ui := &UsageInfo{
- TotalTokens: total,
- PromptTokens: prompt,
- CompletionTokens: completion,
- CacheReadTokens: cacheRead,
- CacheWriteTokens: cacheWrite,
- }
-
- // If TotalTokens wasn't explicitly available but we have prompt+completion, compute it.
- if total == 0 && (prompt > 0 || completion > 0) {
- ui.TotalTokens = prompt + completion + cacheRead + cacheWrite
- }
-
- return ui
-}
-
-// probePath walks through each candidate path in order, returning the first
-// int64 value found along with true. Returns (0, false) if none match.
-func probePath(root map[string]any, paths []string) (int64, bool) {
- for _, p := range paths {
- parts := strings.Split(p, ".")
-
- var current any = root
- for _, part := range parts {
- obj, ok := current.(map[string]any)
- if !ok {
- goto next
- }
- current, ok = obj[part]
- if !ok {
- goto next
- }
- }
-
- switch v := current.(type) {
- case float64:
- return int64(v), true
- case int64:
- return v, true
- case int:
- return int64(v), true
- }
- next:
- }
- return 0, false
-}