From a03ff6285e9705f145d6bb472d32cc74394112c0 Mon Sep 17 00:00:00 2001 From: Louis <8515500@gmail.com> Date: Wed, 13 May 2026 18:45:15 +0800 Subject: [PATCH 1/2] feat: add tool calling settings switch --- API.en.md | 2 + API.md | 2 + README.en.md | 1 + README.md | 1 + config.example.json | 4 + docs/README.md | 2 + docs/feature-switch-roadmap.md | 40 +++++ docs/prompt-compatibility.md | 10 ++ docs/toolcall-semantics.md | 6 + internal/assistantturn/turn.go | 11 +- internal/completionruntime/nonstream.go | 1 + internal/config/codec.go | 11 ++ internal/config/config.go | 8 + internal/config/store_accessors.go | 15 ++ internal/config/validation.go | 27 ++++ internal/config/validation_test.go | 5 + .../httpapi/admin/handler_settings_test.go | 73 +++++++++ .../admin/settings/handler_settings_parse.go | 50 +++--- .../admin/settings/handler_settings_read.go | 4 + .../admin/settings/handler_settings_write.go | 12 +- internal/httpapi/admin/shared/deps.go | 2 + internal/httpapi/claude/handler_messages.go | 3 + internal/httpapi/claude/standard_request.go | 38 +++-- .../httpapi/claude/stream_runtime_core.go | 13 +- .../httpapi/claude/stream_runtime_finalize.go | 1 + internal/httpapi/gemini/convert_request.go | 34 +++-- .../httpapi/gemini/handler_stream_runtime.go | 8 +- .../openai/chat/chat_stream_runtime.go | 20 ++- .../openai/chat/chat_stream_runtime_test.go | 2 + .../openai/chat/empty_retry_runtime.go | 6 +- .../openai/chat/empty_retry_runtime_test.go | 1 + internal/httpapi/openai/chat/handler_chat.go | 1 + .../openai/responses/empty_retry_runtime.go | 6 +- .../responses/empty_retry_runtime_test.go | 1 + .../openai/responses/responses_handler.go | 1 + .../responses_stream_runtime_core.go | 22 +-- internal/promptcompat/request_normalize.go | 143 +++++++++++++----- internal/promptcompat/standard_request.go | 2 + .../tool_calling_settings_test.go | 94 ++++++++++++ .../src/features/settings/BehaviorSection.jsx | 2 +- .../features/settings/FeatureFlagsSection.jsx | 40 +++++ .../src/features/settings/useSettingsForm.js | 9 ++ webui/src/locales/en.json | 6 + webui/src/locales/zh.json | 6 + 44 files changed, 628 insertions(+), 118 deletions(-) create mode 100644 docs/feature-switch-roadmap.md create mode 100644 internal/promptcompat/tool_calling_settings_test.go diff --git a/API.en.md b/API.en.md index 78802a7bf..d49660f81 100644 --- a/API.en.md +++ b/API.en.md @@ -772,6 +772,7 @@ Reads runtime settings and status, including: - `auto_delete` (`mode`: `none` / `single` / `all`; legacy `sessions=true` is still treated as `all`) - `current_input_file` (`enabled` defaults to `true`, plus `min_chars`; `inline_max_tokens` defaults to `30000`; `filename_policy` defaults to `neutral_random`) - `thinking_injection` (`enabled` defaults to `false`, `prompt`, and `default_prompt`) +- `tool_calling` (`enabled` defaults to `true`; `disabled_behavior` supports `reject` / `ignore_tools`, defaulting to `reject`) - `model_aliases` - `env_backed`, `needs_vercel_sync` - `context_engine` (`mode`: `off` / `shadow` / `enforce`, defaults to `enforce`; `strategy`: `raw_transcript` / `natural_context` / `context_capsule` / `hybrid_recent` / `auto`, defaults to `hybrid_recent`; `env_override`: whether an env var is active) @@ -789,6 +790,7 @@ Hot-updates runtime settings. Supported fields: - `auto_delete.mode` - `current_input_file.enabled` / `current_input_file.min_chars` / `current_input_file.inline_max_tokens` / `current_input_file.filename_policy` - `thinking_injection.enabled` / `thinking_injection.prompt` +- `tool_calling.enabled` / `tool_calling.disabled_behavior` - `context_engine.mode` / `context_engine.strategy` - `model_aliases` - `toolcall` policy is fixed and is no longer writable through settings diff --git a/API.md b/API.md index f05a01d5a..ce125d382 100644 --- a/API.md +++ b/API.md @@ -778,6 +778,7 @@ data: {"type":"message_stop"} - `auto_delete`(`mode`:`none` / `single` / `all`;旧配置 `sessions=true` 仍按 `all` 处理) - `current_input_file`(`enabled` 默认返回 `true`、`min_chars`、`inline_max_tokens` 默认返回 `30000`、`filename_policy` 默认返回 `neutral_random`) - `thinking_injection`(`enabled` 默认返回 `false`、`prompt`、`default_prompt`) +- `tool_calling`(`enabled` 默认返回 `true`;`disabled_behavior` 支持 `reject` / `ignore_tools`,默认 `reject`) - `model_aliases` - `env_backed`、`needs_vercel_sync` - `context_engine`(`mode`:`off` / `shadow` / `enforce`,默认 `enforce`;`strategy`:`raw_transcript` / `natural_context` / `context_capsule` / `hybrid_recent` / `auto`,默认 `hybrid_recent`;`env_override`:是否被环境变量覆盖) @@ -795,6 +796,7 @@ data: {"type":"message_stop"} - `auto_delete.mode` - `current_input_file.enabled` / `current_input_file.min_chars` / `current_input_file.inline_max_tokens` / `current_input_file.filename_policy` - `thinking_injection.enabled` / `thinking_injection.prompt` +- `tool_calling.enabled` / `tool_calling.disabled_behavior` - `context_engine.mode` / `context_engine.strategy` - `model_aliases` - `toolcall` 策略已固定,不再作为可写入字段 diff --git a/README.en.md b/README.en.md index 54bb5be1b..695511638 100644 --- a/README.en.md +++ b/README.en.md @@ -336,6 +336,7 @@ Common fields: - When the full context exceeds the inline threshold and the latest user turn satisfies `min_chars`, DS2API generates conversation context / tool reference files. The default `filename_policy` is `neutral_random`; legacy implementation filenames are used only when `filename_policy=legacy`. - If you turn off `current_input_file`, requests pass through directly without uploading any split context file. - `thinking_injection`: disabled by default. It appends the enhanced reasoning prompt to the latest user message only when `thinking_injection.enabled=true`. +- `tool_calling`: enabled by default. When disabled, requests with `tools` or forced `tool_choice` are rejected by default; `disabled_behavior=ignore_tools` strips tools and continues as plain chat. - `parser_v2.mode`: Tool Parser v2 gradual switch, supporting `off` / `shadow` / `enforce`; it is safely off by default and can be overridden with `DS2API_PARSER_V2`. - `context_engine.mode`: Context Engine gradual switch, supporting `off` / `shadow` / `enforce`; it defaults to `enforce`, with `context_engine.strategy=hybrid_recent` by default. diff --git a/README.md b/README.md index 65863c2dd..ad03ddb02 100644 --- a/README.md +++ b/README.md @@ -360,6 +360,7 @@ go run ./cmd/ds2api - 当整体上下文超过 inline 阈值且最新 user turn 满足 `min_chars` 时,才会生成 conversation context / tool reference 文件;默认 `filename_policy=neutral_random`,只在显式设为 `legacy` 时使用旧文件名。 - 如果关闭 `current_input_file`,请求会直接透传,不上传拆分上下文文件。 - `thinking_injection`:默认关闭;只有显式设置 `thinking_injection.enabled=true` 时,才会在最新 user 消息末尾追加增强提示词;`prompt` 留空时使用内置默认提示词。 +- `tool_calling`:默认开启。关闭后默认拒绝携带 `tools` / 强制 `tool_choice` 的请求;如设置 `disabled_behavior=ignore_tools`,会剥离 tools 并按普通对话继续。 - `parser_v2.mode`:Tool Parser v2 渐进开关,支持 `off` / `shadow` / `enforce`,默认安全关闭;可用环境变量 `DS2API_PARSER_V2` 覆盖。 - `context_engine.mode`:Context Engine 渐进开关,支持 `off` / `shadow` / `enforce`,默认 `enforce`;`context_engine.strategy` 默认 `hybrid_recent`,可用环境变量 `DS2API_CONTEXT_ENGINE` / `DS2API_CONTEXT_ENGINE_STRATEGY` 覆盖。 - `log`:日志文件输出与轮转配置;`file` 为空时使用默认 `logs/ds2api.log`,`max_size_mb` 有上限校验。 diff --git a/config.example.json b/config.example.json index fb703138f..3ed45ebe5 100644 --- a/config.example.json +++ b/config.example.json @@ -56,6 +56,10 @@ "enabled": false, "prompt": "" }, + "tool_calling": { + "enabled": true, + "disabled_behavior": "reject" + }, "embeddings": { "provider": "deterministic" }, diff --git a/docs/README.md b/docs/README.md index 378c7a463..099003c9c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,6 +20,7 @@ - [API -> 网页对话纯文本兼容主链路说明](./prompt-compatibility.md) - [Context Engine 策略验收说明](./context-engine-strategies.md) - [Tool Calling 统一语义](./toolcall-semantics.md) +- [Feature Switch Roadmap](./feature-switch-roadmap.md) - [DeepSeek SSE 行为结构说明(逆向观察)](./DeepSeekSSE行为结构说明-2026-04-05.md) ### v2 开发规划 @@ -67,6 +68,7 @@ Recommended reading order: - [API -> pure-text web-chat compatibility pipeline](./prompt-compatibility.md) - [Context Engine strategy acceptance guide](./context-engine-strategies.md) - [Tool-calling unified semantics](./toolcall-semantics.md) +- [Feature Switch Roadmap](./feature-switch-roadmap.md) - [DeepSeek SSE behavior notes (reverse-engineered)](./DeepSeekSSE行为结构说明-2026-04-05.md) ### v2 development planning (Chinese only) diff --git a/docs/feature-switch-roadmap.md b/docs/feature-switch-roadmap.md new file mode 100644 index 000000000..874eb56f5 --- /dev/null +++ b/docs/feature-switch-roadmap.md @@ -0,0 +1,40 @@ +# Feature Switch Roadmap + +文档导航:[文档索引](./README.md) / [M4/M5 执行规划](./m4-development-plan.md) / [Tool Calling 统一语义](./toolcall-semantics.md) + +本文把近期讨论的“后续需要补充开关”拆成可堆叠 PR 的开发清单。原则是先补高风险、可回滚、影响面清晰的开关,再推进能力矩阵和诊断产品化。 + +## 当前 PR:Tool Calling 全局开关 + +目标:管理员可在配置和 WebUI 中关闭 tools 主链路。 + +范围: + +- 新增 `tool_calling.enabled`,默认 `true`。 +- 新增 `tool_calling.disabled_behavior`,支持 `reject` / `ignore_tools`,默认 `reject`。 +- 覆盖 OpenAI Chat、OpenAI Responses、Claude、Gemini 的请求标准化。 +- 关闭后同步禁止非流式和流式收尾的工具调用解析。 +- WebUI Settings 增加可编辑配置。 + +验收: + +- 普通对话在关闭 tools 后仍可执行。 +- 携带 tools 的请求默认被拒绝。 +- `ignore_tools` 模式会剥离 tools,不向模型暴露工具说明,也不解析输出工具调用。 +- API / README / prompt compatibility / toolcall 语义文档同步。 + +## 后续 PR 队列 + +| 优先级 | 开关组 | 建议字段 | 默认值 | 目的 | +|---|---|---|---|---| +| P1 | 能力开关 | `capabilities.search.enabled`、`capabilities.thinking.enabled` | `true` | 允许部署方按账号/环境关闭高风险或高成本能力 | +| P1 | 输入能力 | `capabilities.file_upload.enabled`、`capabilities.vision.enabled` | `true` | 在无真实账号或文件链路不稳定时一键回退 | +| P2 | Runtime 安全 | `runtime.empty_output_retry.enabled`、`runtime.account_switch_retry.enabled` | `true` | 将已存在的重试行为显式化,便于排障 | +| P2 | 观测采样 | `observability.raw_sample_capture.enabled`、`observability.response_history.enabled` | 当前行为 | 让隐私/存储敏感部署可集中控制数据留存 | +| P3 | M4/M5 新能力 | `auto_continue.mode`、`capability_router.mode`、`agent_memory.mode` | `off` | 新主链路能力先 shadow,再 enforce | + +执行约束: + +- 每个开关组单独分支和 PR,不与无关重构混合。 +- 默认值必须在 config、API、README、WebUI、专题文档中保持一致。 +- 涉及主请求链路的开关必须有单元测试覆盖开启、关闭、非法配置和局部更新。 diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index 2a552ad19..8a3119322 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -80,6 +80,16 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` - Go completion runtime: [internal/completionruntime/nonstream.go](../internal/completionruntime/nonstream.go) +### Tool Calling 全局开关 + +`tool_calling.enabled` 默认开启,保持现有 API tools 兼容行为:接收上游 `tools` / `tool_choice`,把工具声明注入 prompt,并在输出侧解析模型生成的工具调用。 + +当 `tool_calling.enabled=false` 时: + +- 默认 `disabled_behavior=reject`,携带 `tools` 或强制 `tool_choice` 的请求会在标准化阶段返回错误。 +- 如果配置 `disabled_behavior=ignore_tools`,标准化阶段会移除工具声明,把 `tool_choice` 收敛为 `none`,后续 prompt 不注入工具说明,输出侧也不解析工具调用。 +- 没有请求工具能力的普通对话仍可继续执行。 + ## 4. 下游真正收到的东西 在“完成标准化后”,下游 completion payload 的核心形态是: diff --git a/docs/toolcall-semantics.md b/docs/toolcall-semantics.md index 23e7430b5..e120f2fec 100644 --- a/docs/toolcall-semantics.md +++ b/docs/toolcall-semantics.md @@ -4,6 +4,12 @@ 文档导航:[总览](../README.md) / [架构说明](./ARCHITECTURE.md) / [测试指南](./TESTING.md) +## 0) 全局启停开关 + +`tool_calling.enabled` 是 tools 主链路的全局开关,默认 `true`。开启时,协议层会接受上游工具声明、注入 prompt-visible tool 指令,并在 assistant 输出侧解析工具调用。 + +当它关闭时,默认 `tool_calling.disabled_behavior=reject`:携带 `tools`、`tool_choice=required`、forced tool choice 或 allowed-tools 约束的请求会被拒绝。若设置为 `ignore_tools`,请求中的工具声明会被剥离并按普通对话继续;此时非流式和流式收尾都不会把 assistant 输出解析为工具调用。 + ## 1) 当前可执行格式 当前版本推荐模型输出半角管道符 DSML 外壳: diff --git a/internal/assistantturn/turn.go b/internal/assistantturn/turn.go index 194b54342..4eefab0da 100644 --- a/internal/assistantturn/turn.go +++ b/internal/assistantturn/turn.go @@ -76,6 +76,7 @@ type BuildOptions struct { ToolNames []string ToolsRaw any ToolChoice promptcompat.ToolChoicePolicy + DisableToolCalling bool ParserV2Mode string // Ctx is the request context used for observability. May be nil. Ctx context.Context @@ -102,7 +103,10 @@ func BuildTurnFromCollected(result sse.CollectResult, opts BuildOptions) Turn { text = shared.ReplaceCitationMarkersWithLinks(text, result.CitationLinks) } - parsed := shared.DetectAssistantToolCalls(result.Text, text, result.Thinking, result.ToolDetectionThinking, opts.ToolNames) + parsed := toolcall.ToolCallParseResult{} + if !opts.DisableToolCalling { + parsed = shared.DetectAssistantToolCalls(result.Text, text, result.Thinking, result.ToolDetectionThinking, opts.ToolNames) + } parsedBeforeNorm := parsed calls := toolcall.NormalizeParsedToolCallsForSchemas(parsed.Calls, opts.ToolsRaw) parsed.Calls = calls @@ -154,7 +158,10 @@ func BuildTurnFromStreamSnapshot(snapshot StreamSnapshot, opts BuildOptions) Tur text = shared.ReplaceCitationMarkersWithLinks(text, snapshot.CitationLinks) } - parsed := shared.DetectAssistantToolCalls(snapshot.RawText, text, snapshot.RawThinking, snapshot.DetectionThinking, opts.ToolNames) + parsed := toolcall.ToolCallParseResult{} + if !opts.DisableToolCalling { + parsed = shared.DetectAssistantToolCalls(snapshot.RawText, text, snapshot.RawThinking, snapshot.DetectionThinking, opts.ToolNames) + } parsedBeforeNorm := parsed calls := parsed.Calls if len(calls) == 0 && len(snapshot.AdditionalToolCalls) > 0 { diff --git a/internal/completionruntime/nonstream.go b/internal/completionruntime/nonstream.go index 1051c22fe..7892129ea 100644 --- a/internal/completionruntime/nonstream.go +++ b/internal/completionruntime/nonstream.go @@ -280,6 +280,7 @@ func buildOptions(ctx context.Context, stdReq promptcompat.StandardRequest, prom ToolNames: stdReq.ToolNames, ToolsRaw: stdReq.ToolsRaw, ToolChoice: stdReq.ToolChoice, + DisableToolCalling: stdReq.ToolCallingDisabled, ParserV2Mode: opts.ParserV2Mode, Ctx: ctx, } diff --git a/internal/config/codec.go b/internal/config/codec.go index ccb5c7147..ea7b4e44e 100644 --- a/internal/config/codec.go +++ b/internal/config/codec.go @@ -48,6 +48,9 @@ func (c Config) MarshalJSON() ([]byte, error) { if c.ThinkingInjection.Enabled != nil || strings.TrimSpace(c.ThinkingInjection.Prompt) != "" { m["thinking_injection"] = c.ThinkingInjection } + if c.ToolCalling.Enabled != nil || strings.TrimSpace(c.ToolCalling.DisabledBehavior) != "" { + m["tool_calling"] = c.ToolCalling + } if strings.TrimSpace(c.Vercel.Token) != "" || strings.TrimSpace(c.Vercel.ProjectID) != "" || strings.TrimSpace(c.Vercel.TeamID) != "" { m["vercel"] = NormalizeVercelConfig(c.Vercel) } @@ -143,6 +146,10 @@ func (c *Config) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(v, &c.ThinkingInjection); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) } + case "tool_calling": + if err := json.Unmarshal(v, &c.ToolCalling); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } case "cors": if err := json.Unmarshal(v, &c.CORS); err != nil { return fmt.Errorf("invalid field %q: %w", k, err) @@ -213,6 +220,10 @@ func (c Config) Clone() Config { Enabled: cloneBoolPtr(c.ThinkingInjection.Enabled), Prompt: c.ThinkingInjection.Prompt, }, + ToolCalling: ToolCallingConfig{ + Enabled: cloneBoolPtr(c.ToolCalling.Enabled), + DisabledBehavior: c.ToolCalling.DisabledBehavior, + }, Vercel: c.Vercel, CORS: CORSConfig{AllowOrigins: slices.Clone(c.CORS.AllowOrigins)}, Auth: AuthConfig{AllowGeminiQueryKey: cloneBoolPtr(c.Auth.AllowGeminiQueryKey)}, diff --git a/internal/config/config.go b/internal/config/config.go index a5223cc44..523da6d97 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ type Config struct { AutoDelete AutoDeleteConfig `json:"auto_delete"` CurrentInputFile CurrentInputFileConfig `json:"current_input_file,omitempty"` ThinkingInjection ThinkingInjectionConfig `json:"thinking_injection,omitempty"` + ToolCalling ToolCallingConfig `json:"tool_calling,omitempty"` Vercel VercelConfig `json:"vercel,omitempty"` CORS CORSConfig `json:"cors,omitempty"` Auth AuthConfig `json:"auth,omitempty"` @@ -190,6 +191,13 @@ type ThinkingInjectionConfig struct { Prompt string `json:"prompt,omitempty"` } +// ToolCallingConfig controls whether prompt-visible tool schemas and parsed +// tool-call outputs are enabled. +type ToolCallingConfig struct { + Enabled *bool `json:"enabled,omitempty"` + DisabledBehavior string `json:"disabled_behavior,omitempty"` +} + type VercelConfig struct { Token string `json:"token,omitempty"` ProjectID string `json:"project_id,omitempty"` diff --git a/internal/config/store_accessors.go b/internal/config/store_accessors.go index a2339778a..e551b76b6 100644 --- a/internal/config/store_accessors.go +++ b/internal/config/store_accessors.go @@ -190,6 +190,21 @@ func (s *Store) ThinkingInjectionPrompt() string { return strings.TrimSpace(s.cfg.ThinkingInjection.Prompt) } +func (s *Store) ToolCallingEnabled() bool { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.ToolCalling.Enabled == nil { + return true + } + return *s.cfg.ToolCalling.Enabled +} + +func (s *Store) ToolCallingDisabledBehavior() string { + s.mu.RLock() + defer s.mu.RUnlock() + return NormalizeToolCallingDisabledBehavior(s.cfg.ToolCalling.DisabledBehavior) +} + // ContextEngineMode returns the context engine feature flag value. // Valid values: "off" | "shadow" | "enforce" (default). // The DS2API_CONTEXT_ENGINE environment variable takes precedence over the diff --git a/internal/config/validation.go b/internal/config/validation.go index 511795f0a..ff108def0 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -27,6 +27,9 @@ func ValidateConfig(c Config) error { if err := ValidateCurrentInputFileConfig(c.CurrentInputFile); err != nil { return err } + if err := ValidateToolCallingConfig(c.ToolCalling); err != nil { + return err + } if err := ValidateContextEngineConfig(c.ContextEngine); err != nil { return err } @@ -134,6 +137,30 @@ func ValidateCurrentInputFileConfig(currentInputFile CurrentInputFileConfig) err return nil } +func ValidateToolCallingConfig(toolCalling ToolCallingConfig) error { + return ValidateToolCallingDisabledBehavior(toolCalling.DisabledBehavior) +} + +func ValidateToolCallingDisabledBehavior(behavior string) error { + switch NormalizeToolCallingDisabledBehavior(behavior) { + case "reject", "ignore_tools": + return nil + default: + return fmt.Errorf("tool_calling.disabled_behavior must be one of reject, ignore_tools") + } +} + +func NormalizeToolCallingDisabledBehavior(behavior string) string { + switch strings.ToLower(strings.TrimSpace(behavior)) { + case "", "reject": + return "reject" + case "ignore", "ignore_tools": + return "ignore_tools" + default: + return strings.ToLower(strings.TrimSpace(behavior)) + } +} + func NormalizeCurrentInputFileFilenamePolicy(policy string) string { switch strings.ToLower(strings.TrimSpace(policy)) { case "legacy", "neutral", "neutral_random": diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index 0d9580fe0..37ef0f273 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -54,6 +54,11 @@ func TestValidateConfigRejectsInvalidValues(t *testing.T) { cfg: Config{CurrentInputFile: CurrentInputFileConfig{FilenamePolicy: "random-ish"}}, want: "current_input_file.filename_policy", }, + { + name: "tool calling disabled behavior", + cfg: Config{ToolCalling: ToolCallingConfig{DisabledBehavior: "drop"}}, + want: "tool_calling.disabled_behavior", + }, { name: "context engine mode", cfg: Config{ContextEngine: ContextEngineConfig{Mode: "observe"}}, diff --git a/internal/httpapi/admin/handler_settings_test.go b/internal/httpapi/admin/handler_settings_test.go index 1e31c8684..150b73a64 100644 --- a/internal/httpapi/admin/handler_settings_test.go +++ b/internal/httpapi/admin/handler_settings_test.go @@ -80,6 +80,13 @@ func TestGetSettingsIncludesCurrentInputFileDefaults(t *testing.T) { if got, _ := thinkingInjection["default_prompt"].(string); got == "" { t.Fatalf("expected default thinking prompt, body=%v", body) } + toolCalling, _ := body["tool_calling"].(map[string]any) + if got := boolFrom(toolCalling["enabled"]); !got { + t.Fatalf("expected tool_calling.enabled=true, body=%v", body) + } + if got, _ := toolCalling["disabled_behavior"].(string); got != "reject" { + t.Fatalf("expected tool_calling.disabled_behavior=reject, got %q body=%v", got, body) + } contextEngine, _ := body["context_engine"].(map[string]any) if got, _ := contextEngine["mode"].(string); got != "enforce" { t.Fatalf("expected context_engine.mode=enforce, got %q body=%v", got, body) @@ -430,6 +437,72 @@ func TestUpdateSettingsThinkingInjectionPartialEnabledPreservesPrompt(t *testing } } +func TestUpdateSettingsToolCalling(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "tool_calling": map[string]any{ + "enabled": false, + "disabled_behavior": "ignore_tools", + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + snap := h.Store.Snapshot() + if snap.ToolCalling.Enabled == nil || *snap.ToolCalling.Enabled { + t.Fatalf("expected tool_calling.enabled=false, got %#v", snap.ToolCalling.Enabled) + } + if got := h.Store.ToolCallingDisabledBehavior(); got != "ignore_tools" { + t.Fatalf("expected disabled behavior ignore_tools, got %q", got) + } +} + +func TestUpdateSettingsToolCallingPartialUpdatePreservesEnabled(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"],"tool_calling":{"enabled":false,"disabled_behavior":"reject"}}`) + payload := map[string]any{ + "tool_calling": map[string]any{ + "disabled_behavior": "ignore_tools", + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + snap := h.Store.Snapshot() + if snap.ToolCalling.Enabled == nil || *snap.ToolCalling.Enabled { + t.Fatalf("expected tool_calling.enabled to remain false, got %#v", snap.ToolCalling.Enabled) + } + if got := snap.ToolCalling.DisabledBehavior; got != "ignore_tools" { + t.Fatalf("expected disabled behavior to update, got %#v", snap.ToolCalling) + } +} + +func TestUpdateSettingsToolCallingRejectsInvalidBehavior(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "tool_calling": map[string]any{ + "disabled_behavior": "drop", + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("tool_calling.disabled_behavior")) { + t.Fatalf("expected tool_calling validation detail, got %s", rec.Body.String()) + } +} + func TestUpdateSettingsContextEngine(t *testing.T) { h := newAdminTestHandler(t, `{"keys":["k1"]}`) payload := map[string]any{ diff --git a/internal/httpapi/admin/settings/handler_settings_parse.go b/internal/httpapi/admin/settings/handler_settings_parse.go index 63457be45..781d5322f 100644 --- a/internal/httpapi/admin/settings/handler_settings_parse.go +++ b/internal/httpapi/admin/settings/handler_settings_parse.go @@ -33,7 +33,7 @@ func stringFrom(v any) string { } } -func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, *config.CurrentInputFileConfig, *config.ThinkingInjectionConfig, *config.ContextEngineConfig, map[string]string, *config.LogConfig, error) { +func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, *config.CurrentInputFileConfig, *config.ThinkingInjectionConfig, *config.ToolCallingConfig, *config.ContextEngineConfig, map[string]string, *config.LogConfig, error) { var ( adminCfg *config.AdminConfig runtimeCfg *config.RuntimeConfig @@ -42,6 +42,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi autoDeleteCfg *config.AutoDeleteConfig currentInputCfg *config.CurrentInputFileConfig thinkingInjCfg *config.ThinkingInjectionConfig + toolCallingCfg *config.ToolCallingConfig contextEngineCfg *config.ContextEngineConfig aliasMap map[string]string logCfg *config.LogConfig @@ -52,7 +53,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["jwt_expire_hours"]; exists { n := intFrom(v) if err := config.ValidateIntRange("admin.jwt_expire_hours", n, 1, 720, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.JWTExpireHours = n } @@ -64,33 +65,33 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["account_max_inflight"]; exists { n := intFrom(v) if err := config.ValidateIntRange("runtime.account_max_inflight", n, 1, 256, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.AccountMaxInflight = n } if v, exists := raw["account_max_queue"]; exists { n := intFrom(v) if err := config.ValidateIntRange("runtime.account_max_queue", n, 1, 200000, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.AccountMaxQueue = n } if v, exists := raw["global_max_inflight"]; exists { n := intFrom(v) if err := config.ValidateIntRange("runtime.global_max_inflight", n, 1, 200000, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.GlobalMaxInflight = n } if v, exists := raw["token_refresh_interval_hours"]; exists { n := intFrom(v) if err := config.ValidateIntRange("runtime.token_refresh_interval_hours", n, 1, 720, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.TokenRefreshIntervalHours = n } if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") } runtimeCfg = cfg } @@ -100,7 +101,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["store_ttl_seconds"]; exists { n := intFrom(v) if err := config.ValidateIntRange("responses.store_ttl_seconds", n, 30, 86400, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.StoreTTLSeconds = n } @@ -112,7 +113,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["provider"]; exists { p := strings.TrimSpace(fmt.Sprintf("%v", v)) if err := config.ValidateTrimmedString("embeddings.provider", p, false); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.Provider = p } @@ -138,7 +139,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["mode"]; exists { mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) if err := config.ValidateAutoDeleteMode(mode); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } if mode == "" { mode = "none" @@ -160,14 +161,14 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["min_chars"]; exists { n := intFrom(v) if err := config.ValidateIntRange("current_input_file.min_chars", n, 0, 100000000, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.MinChars = n } if v, exists := raw["inline_max_tokens"]; exists { n := intFrom(v) if err := config.ValidateIntRange("current_input_file.inline_max_tokens", n, 0, 100000000, true); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.InlineMaxTokens = n } @@ -175,7 +176,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg.FilenamePolicy = strings.TrimSpace(stringFrom(v)) } if err := config.ValidateCurrentInputFileConfig(*cfg); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } currentInputCfg = cfg } @@ -192,6 +193,21 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi thinkingInjCfg = cfg } + if raw, ok := req["tool_calling"].(map[string]any); ok { + cfg := &config.ToolCallingConfig{} + if v, exists := raw["enabled"]; exists { + b := boolFrom(v) + cfg.Enabled = &b + } + if v, exists := raw["disabled_behavior"]; exists { + cfg.DisabledBehavior = config.NormalizeToolCallingDisabledBehavior(stringFrom(v)) + } + if err := config.ValidateToolCallingConfig(*cfg); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + } + toolCallingCfg = cfg + } + if raw, ok := req["context_engine"].(map[string]any); ok { cfg := &config.ContextEngineConfig{} if v, exists := raw["mode"]; exists { @@ -201,7 +217,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg.Strategy = config.NormalizeContextEngineStrategy(stringFrom(v)) } if err := config.ValidateContextEngineConfig(*cfg); err != nil { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err } contextEngineCfg = cfg } @@ -216,7 +232,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi case "debug", "info", "warn", "error": cfg.Level = level default: - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("log.level must be one of: debug, info, warn, error") + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("log.level must be one of: debug, info, warn, error") } } if v, exists := raw["file"]; exists { @@ -228,7 +244,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi if v, exists := raw["max_size_mb"]; exists { n := intFrom(v) if n > 1024 { - return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("log.max_size_mb must be <= 1024") + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("log.max_size_mb must be <= 1024") } if n > 0 { cfg.MaxSizeMB = n @@ -243,5 +259,5 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi logCfg = cfg } - return adminCfg, runtimeCfg, respCfg, embCfg, autoDeleteCfg, currentInputCfg, thinkingInjCfg, contextEngineCfg, aliasMap, logCfg, nil + return adminCfg, runtimeCfg, respCfg, embCfg, autoDeleteCfg, currentInputCfg, thinkingInjCfg, toolCallingCfg, contextEngineCfg, aliasMap, logCfg, nil } diff --git a/internal/httpapi/admin/settings/handler_settings_read.go b/internal/httpapi/admin/settings/handler_settings_read.go index 258bcc966..d2ed03a59 100644 --- a/internal/httpapi/admin/settings/handler_settings_read.go +++ b/internal/httpapi/admin/settings/handler_settings_read.go @@ -42,6 +42,10 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { "prompt": h.Store.ThinkingInjectionPrompt(), "default_prompt": promptcompat.DefaultThinkingInjectionPrompt, }, + "tool_calling": map[string]any{ + "enabled": h.Store.ToolCallingEnabled(), + "disabled_behavior": h.Store.ToolCallingDisabledBehavior(), + }, "model_aliases": snap.ModelAliases, "env_backed": h.Store.IsEnvBacked(), "needs_vercel_sync": needsSync, diff --git a/internal/httpapi/admin/settings/handler_settings_write.go b/internal/httpapi/admin/settings/handler_settings_write.go index 499d49a0c..bcaa9ffcd 100644 --- a/internal/httpapi/admin/settings/handler_settings_write.go +++ b/internal/httpapi/admin/settings/handler_settings_write.go @@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { return } - adminCfg, runtimeCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, currentInputCfg, thinkingInjCfg, contextEngineCfg, aliasMap, logCfg, err := parseSettingsUpdateRequest(req) + adminCfg, runtimeCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, currentInputCfg, thinkingInjCfg, toolCallingCfg, contextEngineCfg, aliasMap, logCfg, err := parseSettingsUpdateRequest(req) if err != nil { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return @@ -34,6 +34,8 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { currentInputFilenamePolicySet := hasNestedSettingsKey(req, "current_input_file", "filename_policy") thinkingInjectionEnabledSet := hasNestedSettingsKey(req, "thinking_injection", "enabled") thinkingInjectionPromptSet := hasNestedSettingsKey(req, "thinking_injection", "prompt") + toolCallingEnabledSet := hasNestedSettingsKey(req, "tool_calling", "enabled") + toolCallingDisabledBehaviorSet := hasNestedSettingsKey(req, "tool_calling", "disabled_behavior") contextEngineModeSet := hasNestedSettingsKey(req, "context_engine", "mode") contextEngineStrategySet := hasNestedSettingsKey(req, "context_engine", "strategy") logLevelSet := hasNestedSettingsKey(req, "log", "level") @@ -95,6 +97,14 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { c.ThinkingInjection.Prompt = thinkingInjCfg.Prompt } } + if toolCallingCfg != nil { + if toolCallingEnabledSet { + c.ToolCalling.Enabled = toolCallingCfg.Enabled + } + if toolCallingDisabledBehaviorSet { + c.ToolCalling.DisabledBehavior = toolCallingCfg.DisabledBehavior + } + } if contextEngineCfg != nil { if contextEngineModeSet { c.ContextEngine.Mode = contextEngineCfg.Mode diff --git a/internal/httpapi/admin/shared/deps.go b/internal/httpapi/admin/shared/deps.go index 6deb34fd0..cbf806148 100644 --- a/internal/httpapi/admin/shared/deps.go +++ b/internal/httpapi/admin/shared/deps.go @@ -39,6 +39,8 @@ type ConfigStore interface { CurrentInputFileFilenamePolicy() string ThinkingInjectionEnabled() bool ThinkingInjectionPrompt() string + ToolCallingEnabled() bool + ToolCallingDisabledBehavior() string ContextEngineMode() string ContextEngineStrategy() string AutoDeleteSessions() bool diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index 66d6177ce..59e66b438 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -342,6 +342,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ h.parserV2Mode(), toolNames, toolsRaw, + false, buildClaudePromptTokenText(messages, thinkingEnabled), historySession, ) @@ -396,6 +397,7 @@ func (h *Handler) handleClaudeStreamRealtimeWithRetry(w http.ResponseWriter, r * searchEnabled := stdReq.Search toolNames := stdReq.ToolNames toolsRaw := stdReq.ToolsRaw + disableToolCalling := stdReq.ToolCallingDisabled promptTokenText := stdReq.PromptTokenText streamRuntime := newClaudeStreamRuntime( w, @@ -410,6 +412,7 @@ func (h *Handler) handleClaudeStreamRealtimeWithRetry(w http.ResponseWriter, r * h.parserV2Mode(), toolNames, toolsRaw, + disableToolCalling, promptTokenText, historySession, ) diff --git a/internal/httpapi/claude/standard_request.go b/internal/httpapi/claude/standard_request.go index 65fb3b50f..57aaccdb6 100644 --- a/internal/httpapi/claude/standard_request.go +++ b/internal/httpapi/claude/standard_request.go @@ -31,8 +31,12 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma payload := cloneMap(req) payload["messages"] = normalizedMessages - toolsRequested, _ := req["tools"].([]any) - payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested) + toolsRequested, _, toolCallingEnabled, err := promptcompat.ResolveToolCallingRequest(store, req["tools"], req["tool_choice"]) + if err != nil { + return claudeNormalizedRequest{}, err + } + tools, _ := toolsRequested.([]any) + payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, tools) dsPayload := convertClaudeToDeepSeek(payload, store) dsModel, _ := dsPayload["model"].(string) @@ -46,25 +50,27 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma } promptMessages, _ := dsPayload["messages"].([]any) finalPrompt := promptcompat.BuildOpenAIMessagesOnlyPrompt(promptMessages, "", thinkingEnabled, store.ContextEngineMode()) - toolNames := extractClaudeToolNames(toolsRequested) - if len(toolNames) == 0 && len(toolsRequested) > 0 { + toolNames := extractClaudeToolNames(tools) + if len(toolNames) == 0 && len(tools) > 0 { toolNames = []string{"__any_tool__"} } return claudeNormalizedRequest{ Standard: promptcompat.StandardRequest{ - Surface: "anthropic_messages", - RequestedModel: strings.TrimSpace(model), - ResolvedModel: dsModel, - ResponseModel: strings.TrimSpace(model), - Messages: contextMessages, - PromptTokenText: finalPrompt, - ToolsRaw: toolsRequested, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: contextMessages, + PromptTokenText: finalPrompt, + ToolsRaw: tools, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolCallingEnabled: toolCallingEnabled, + ToolCallingDisabled: !toolCallingEnabled, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, }, NormalizedMessages: normalizedMessages, }, nil diff --git a/internal/httpapi/claude/stream_runtime_core.go b/internal/httpapi/claude/stream_runtime_core.go index 2341e177f..951d1b04f 100644 --- a/internal/httpapi/claude/stream_runtime_core.go +++ b/internal/httpapi/claude/stream_runtime_core.go @@ -19,11 +19,12 @@ type claudeStreamRuntime struct { rc *http.ResponseController canFlush bool - model string - toolNames []string - messages []any - toolsRaw any - promptTokenText string + model string + toolNames []string + messages []any + toolsRaw any + disableToolCalling bool + promptTokenText string thinkingEnabled bool searchEnabled bool @@ -68,6 +69,7 @@ func newClaudeStreamRuntime( parserV2Mode string, toolNames []string, toolsRaw any, + disableToolCalling bool, promptTokenText string, history *responsehistory.Session, ) *claudeStreamRuntime { @@ -85,6 +87,7 @@ func newClaudeStreamRuntime( parserV2Mode: parserV2Mode, toolNames: toolNames, toolsRaw: toolsRaw, + disableToolCalling: disableToolCalling, promptTokenText: promptTokenText, history: history, messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), diff --git a/internal/httpapi/claude/stream_runtime_finalize.go b/internal/httpapi/claude/stream_runtime_finalize.go index 86239d4ae..6b0874677 100644 --- a/internal/httpapi/claude/stream_runtime_finalize.go +++ b/internal/httpapi/claude/stream_runtime_finalize.go @@ -131,6 +131,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string, deferEmptyOutput bool) StripReferenceMarkers: s.stripReferenceMarkers, ToolNames: s.toolNames, ToolsRaw: s.toolsRaw, + DisableToolCalling: s.disableToolCalling, ParserV2Mode: s.parserV2Mode, Ctx: s.ctx, }) diff --git a/internal/httpapi/gemini/convert_request.go b/internal/httpapi/gemini/convert_request.go index acaa6c887..60b783cb3 100644 --- a/internal/httpapi/gemini/convert_request.go +++ b/internal/httpapi/gemini/convert_request.go @@ -31,7 +31,11 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin return promptcompat.StandardRequest{}, fmt.Errorf("request must include non-empty contents") } - toolsRaw := convertGeminiTools(req["tools"]) + geminiToolsRaw, _, toolCallingEnabled, err := promptcompat.ResolveToolCallingRequest(store, req["tools"], nil) + if err != nil { + return promptcompat.StandardRequest{}, err + } + toolsRaw := convertGeminiTools(geminiToolsRaw) finalPrompt, toolNames := promptcompat.BuildOpenAIPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled, store.ContextEngineMode()) if len(toolNames) == 0 && len(toolsRaw) > 0 { toolNames = []string{"__any_tool__"} @@ -39,18 +43,20 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin passThrough := collectGeminiPassThrough(req) return promptcompat.StandardRequest{ - Surface: "google_gemini", - RequestedModel: requestedModel, - ResolvedModel: resolvedModel, - ResponseModel: requestedModel, - Messages: messagesRaw, - PromptTokenText: finalPrompt, - ToolsRaw: toolsRaw, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: stream, - Thinking: thinkingEnabled, - Search: searchEnabled, - PassThrough: passThrough, + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + ToolsRaw: toolsRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolCallingEnabled: toolCallingEnabled, + ToolCallingDisabled: !toolCallingEnabled, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, }, nil } diff --git a/internal/httpapi/gemini/handler_stream_runtime.go b/internal/httpapi/gemini/handler_stream_runtime.go index ec404f64a..3f7a9cce4 100644 --- a/internal/httpapi/gemini/handler_stream_runtime.go +++ b/internal/httpapi/gemini/handler_stream_runtime.go @@ -42,7 +42,7 @@ func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Req rc := http.NewResponseController(w) _, canFlush := w.(http.Flusher) - runtime := newGeminiStreamRuntime(w, rc, canFlush, r.Context(), model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), h.parserV2Mode(), toolNames, toolsRaw, historySession) + runtime := newGeminiStreamRuntime(w, rc, canFlush, r.Context(), model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), h.parserV2Mode(), toolNames, toolsRaw, false, historySession) initialType := "text" if thinkingEnabled { @@ -80,6 +80,7 @@ type geminiStreamRuntime struct { parserV2Mode string toolNames []string toolsRaw any + disableToolCalling bool ctx context.Context accumulator *assistantturn.Accumulator @@ -116,7 +117,7 @@ func (h *Handler) handleStreamGenerateContentWithRetry(w http.ResponseWriter, r searchEnabled := stdReq.Search toolNames := stdReq.ToolNames toolsRaw := stdReq.ToolsRaw - runtime := newGeminiStreamRuntime(w, rc, canFlush, r.Context(), model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), h.parserV2Mode(), toolNames, toolsRaw, historySession) + runtime := newGeminiStreamRuntime(w, rc, canFlush, r.Context(), model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), h.parserV2Mode(), toolNames, toolsRaw, stdReq.ToolCallingDisabled, historySession) runtime.onFirstByte = func() { observe.SetFirstByteAt(r.Context(), time.Now()) } completionruntime.ExecuteStreamWithRetry(r.Context(), h.DS, a, resp, payload, pow, completionruntime.StreamRetryOptions{ @@ -186,6 +187,7 @@ func newGeminiStreamRuntime( parserV2Mode string, toolNames []string, toolsRaw any, + disableToolCalling bool, history *responsehistory.Session, ) *geminiStreamRuntime { return &geminiStreamRuntime{ @@ -202,6 +204,7 @@ func newGeminiStreamRuntime( parserV2Mode: parserV2Mode, toolNames: toolNames, toolsRaw: toolsRaw, + disableToolCalling: disableToolCalling, history: history, accumulator: assistantturn.NewAccumulator(assistantturn.AccumulatorOptions{ ThinkingEnabled: thinkingEnabled, @@ -337,6 +340,7 @@ func (s *geminiStreamRuntime) finalize(deferEmptyOutput bool) bool { StripReferenceMarkers: s.stripReferenceMarkers, ToolNames: s.toolNames, ToolsRaw: s.toolsRaw, + DisableToolCalling: s.disableToolCalling, ParserV2Mode: s.parserV2Mode, Ctx: s.ctx, }) diff --git a/internal/httpapi/openai/chat/chat_stream_runtime.go b/internal/httpapi/openai/chat/chat_stream_runtime.go index a3a8b09a4..410d65cb2 100644 --- a/internal/httpapi/openai/chat/chat_stream_runtime.go +++ b/internal/httpapi/openai/chat/chat_stream_runtime.go @@ -22,14 +22,15 @@ type chatStreamRuntime struct { canFlush bool ctx context.Context - completionID string - created int64 - model string - finalPrompt string - refFileTokens int - toolNames []string - toolsRaw any - toolChoice promptcompat.ToolChoicePolicy + completionID string + created int64 + model string + finalPrompt string + refFileTokens int + toolNames []string + toolsRaw any + toolChoice promptcompat.ToolChoicePolicy + disableToolCalling bool thinkingEnabled bool searchEnabled bool @@ -100,6 +101,7 @@ func newChatStreamRuntime( toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, + disableToolCalling bool, bufferToolContent bool, emitEarlyToolDeltas bool, ) *chatStreamRuntime { @@ -116,6 +118,7 @@ func newChatStreamRuntime( toolNames: toolNames, toolsRaw: toolsRaw, toolChoice: toolChoice, + disableToolCalling: disableToolCalling, thinkingEnabled: thinkingEnabled, searchEnabled: searchEnabled, stripReferenceMarkers: stripReferenceMarkers, @@ -251,6 +254,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) ToolNames: s.toolNames, ToolsRaw: s.toolsRaw, ToolChoice: s.toolChoice, + DisableToolCalling: s.disableToolCalling, ParserV2Mode: s.parserV2Mode, Ctx: s.ctx, }) diff --git a/internal/httpapi/openai/chat/chat_stream_runtime_test.go b/internal/httpapi/openai/chat/chat_stream_runtime_test.go index 23e9559ff..527db8eef 100644 --- a/internal/httpapi/openai/chat/chat_stream_runtime_test.go +++ b/internal/httpapi/openai/chat/chat_stream_runtime_test.go @@ -30,6 +30,7 @@ func TestChatStreamKeepAliveUsesCommentOnly(t *testing.T) { promptcompat.DefaultToolChoicePolicy(), false, false, + false, ) runtime.sendKeepAlive() @@ -65,6 +66,7 @@ func TestChatStreamFinalizeEnforcesRequiredToolChoice(t *testing.T) { []string{"Write"}, nil, promptcompat.ToolChoicePolicy{Mode: promptcompat.ToolChoiceRequired}, + false, true, false, ) diff --git a/internal/httpapi/openai/chat/empty_retry_runtime.go b/internal/httpapi/openai/chat/empty_retry_runtime.go index fbcbcde2b..d00c274d6 100644 --- a/internal/httpapi/openai/chat/empty_retry_runtime.go +++ b/internal/httpapi/openai/chat/empty_retry_runtime.go @@ -75,7 +75,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, toolNames := stdReq.ToolNames toolsRaw := stdReq.ToolsRaw toolChoice := stdReq.ToolChoice - streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, r.Context(), resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, historySession) + streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, r.Context(), resp, completionID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, stdReq.ToolCallingDisabled, historySession) if !ok { return } @@ -118,7 +118,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, }) } -func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) { +func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, disableToolCalling bool, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) { if resp.StatusCode != http.StatusOK { defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) @@ -144,7 +144,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, ctx context.Co streamRuntime := newChatStreamRuntime( w, rc, canFlush, ctx, h.parserV2Mode(), completionID, time.Now().Unix(), model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), toolNames, toolsRaw, - toolChoice, + toolChoice, disableToolCalling, len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(), ) streamRuntime.refFileTokens = refFileTokens diff --git a/internal/httpapi/openai/chat/empty_retry_runtime_test.go b/internal/httpapi/openai/chat/empty_retry_runtime_test.go index 0db92baa7..aedd80e59 100644 --- a/internal/httpapi/openai/chat/empty_retry_runtime_test.go +++ b/internal/httpapi/openai/chat/empty_retry_runtime_test.go @@ -54,6 +54,7 @@ func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) { promptcompat.DefaultToolChoicePolicy(), false, false, + false, ) resp := makeOpenAISSEHTTPResponse( `data: {"p":"response/content","v":"hello"}`, diff --git a/internal/httpapi/openai/chat/handler_chat.go b/internal/httpapi/openai/chat/handler_chat.go index 7752cdf84..7bfb02f23 100644 --- a/internal/httpapi/openai/chat/handler_chat.go +++ b/internal/httpapi/openai/chat/handler_chat.go @@ -244,6 +244,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt toolNames, toolsRaw, promptcompat.DefaultToolChoicePolicy(), + false, bufferToolContent, emitEarlyToolDeltas, ) diff --git a/internal/httpapi/openai/responses/empty_retry_runtime.go b/internal/httpapi/openai/responses/empty_retry_runtime.go index 07e7062a8..5e13a1c87 100644 --- a/internal/httpapi/openai/responses/empty_retry_runtime.go +++ b/internal/httpapi/openai/responses/empty_retry_runtime.go @@ -26,7 +26,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http. toolsRaw := stdReq.ToolsRaw toolChoice := stdReq.ToolChoice onFirstByte := func() { observe.SetFirstByteAt(r.Context(), time.Now()) } - streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, r.Context(), resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, traceID, historySession, onFirstByte) + streamRuntime, initialType, ok := h.prepareResponsesStreamRuntime(w, r.Context(), resp, owner, responseID, model, finalPrompt, refFileTokens, thinkingEnabled, searchEnabled, toolNames, toolsRaw, toolChoice, stdReq.ToolCallingDisabled, traceID, historySession, onFirstByte) if !ok { return } @@ -62,7 +62,7 @@ func (h *Handler) handleResponsesStreamWithRetry(w http.ResponseWriter, r *http. }) } -func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, ctx context.Context, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string, historySession *responsehistory.Session, onFirstByte func()) (*responsesStreamRuntime, string, bool) { +func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, ctx context.Context, resp *http.Response, owner, responseID, model, finalPrompt string, refFileTokens int, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, disableToolCalling bool, traceID string, historySession *responsehistory.Session, onFirstByte func()) (*responsesStreamRuntime, string, bool) { if resp.StatusCode != http.StatusOK { defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) @@ -86,7 +86,7 @@ func (h *Handler) prepareResponsesStreamRuntime(w http.ResponseWriter, ctx conte w, rc, canFlush, ctx, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, stripReferenceMarkersEnabled(), h.parserV2Mode(), toolNames, toolsRaw, len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(), - toolChoice, traceID, func(obj map[string]any) { + toolChoice, disableToolCalling, traceID, func(obj map[string]any) { h.getResponseStore().put(owner, responseID, obj) }, historySession, ) diff --git a/internal/httpapi/openai/responses/empty_retry_runtime_test.go b/internal/httpapi/openai/responses/empty_retry_runtime_test.go index aa7a1b1db..959c873ff 100644 --- a/internal/httpapi/openai/responses/empty_retry_runtime_test.go +++ b/internal/httpapi/openai/responses/empty_retry_runtime_test.go @@ -47,6 +47,7 @@ func TestConsumeResponsesStreamAttemptMarksContextCancelledState(t *testing.T) { false, false, promptcompat.DefaultToolChoicePolicy(), + false, "", nil, nil, diff --git a/internal/httpapi/openai/responses/responses_handler.go b/internal/httpapi/openai/responses/responses_handler.go index 5ebeed077..16b4d0b9c 100644 --- a/internal/httpapi/openai/responses/responses_handler.go +++ b/internal/httpapi/openai/responses/responses_handler.go @@ -217,6 +217,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, bufferToolContent, emitEarlyToolDeltas, toolChoice, + false, traceID, func(obj map[string]any) { h.getResponseStore().put(owner, responseID, obj) diff --git a/internal/httpapi/openai/responses/responses_stream_runtime_core.go b/internal/httpapi/openai/responses/responses_stream_runtime_core.go index 1cbb7ebc0..9cee4efd0 100644 --- a/internal/httpapi/openai/responses/responses_stream_runtime_core.go +++ b/internal/httpapi/openai/responses/responses_stream_runtime_core.go @@ -24,14 +24,15 @@ type responsesStreamRuntime struct { rc *http.ResponseController canFlush bool - responseID string - model string - finalPrompt string - refFileTokens int - toolNames []string - toolsRaw any - traceID string - toolChoice promptcompat.ToolChoicePolicy + responseID string + model string + finalPrompt string + refFileTokens int + toolNames []string + toolsRaw any + traceID string + toolChoice promptcompat.ToolChoicePolicy + disableToolCalling bool thinkingEnabled bool searchEnabled bool @@ -88,6 +89,7 @@ func newResponsesStreamRuntime( bufferToolContent bool, emitEarlyToolDeltas bool, toolChoice promptcompat.ToolChoicePolicy, + disableToolCalling bool, traceID string, persistResponse func(obj map[string]any), history *responsehistory.Session, @@ -108,6 +110,8 @@ func newResponsesStreamRuntime( toolsRaw: toolsRaw, bufferToolContent: bufferToolContent, emitEarlyToolDeltas: emitEarlyToolDeltas, + toolChoice: toolChoice, + disableToolCalling: disableToolCalling, streamToolCallIDs: map[int]string{}, functionItemIDs: map[int]string{}, functionOutputIDs: map[int]int{}, @@ -116,7 +120,6 @@ func newResponsesStreamRuntime( functionAdded: map[int]bool{}, functionNames: map[int]string{}, messageOutputID: -1, - toolChoice: toolChoice, traceID: traceID, persistResponse: persistResponse, history: history, @@ -197,6 +200,7 @@ func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput ToolNames: s.toolNames, ToolsRaw: s.toolsRaw, ToolChoice: s.toolChoice, + DisableToolCalling: s.disableToolCalling, ParserV2Mode: s.parserV2Mode, Ctx: s.ctx, }) diff --git a/internal/promptcompat/request_normalize.go b/internal/promptcompat/request_normalize.go index 88bfb66b8..d83ebccb4 100644 --- a/internal/promptcompat/request_normalize.go +++ b/internal/promptcompat/request_normalize.go @@ -32,29 +32,36 @@ func NormalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID if responseModel == "" { responseModel = resolvedModel } + toolsRaw, toolChoiceRaw, toolCallingEnabled, err := ResolveToolCallingRequest(store, req["tools"], req["tool_choice"]) + if err != nil { + return StandardRequest{}, err + } toolPolicy := DefaultToolChoicePolicy() - finalPrompt, toolNames := BuildOpenAIPrompt(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled, store.ContextEngineMode()) - toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) + finalPrompt, toolNames := BuildOpenAIPrompt(messagesRaw, toolsRaw, traceID, toolPolicy, thinkingEnabled, store.ContextEngineMode()) + toolNames = ensureToolDetectionEnabled(toolNames, toolsRaw) passThrough := collectOpenAIChatPassThrough(req) refFileIDs := CollectOpenAIRefFileIDs(req) + _ = toolChoiceRaw return StandardRequest{ - Surface: "openai_chat", - RequestedModel: strings.TrimSpace(model), - ResolvedModel: resolvedModel, - ResponseModel: responseModel, - Messages: messagesRaw, - PromptTokenText: finalPrompt, - ToolsRaw: req["tools"], - FinalPrompt: finalPrompt, - ToolNames: toolNames, - ToolChoice: toolPolicy, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, - RefFileIDs: refFileIDs, - RefFileTokens: estimateInlineFileTokens(req), - PassThrough: passThrough, + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: responseModel, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + ToolsRaw: toolsRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + ToolCallingEnabled: toolCallingEnabled, + ToolCallingDisabled: !toolCallingEnabled, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + RefFileIDs: refFileIDs, + RefFileTokens: estimateInlineFileTokens(req), + PassThrough: passThrough, }, nil } @@ -78,12 +85,16 @@ func NormalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra if len(messagesRaw) == 0 { return StandardRequest{}, fmt.Errorf("request must include 'input' or 'messages'") } - toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"]) + toolsRaw, toolChoiceRaw, toolCallingEnabled, err := ResolveToolCallingRequest(store, req["tools"], req["tool_choice"]) if err != nil { return StandardRequest{}, err } - finalPrompt, toolNames := BuildOpenAIPrompt(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled, store.ContextEngineMode()) - toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) + toolPolicy, err := parseToolChoicePolicy(toolChoiceRaw, toolsRaw) + if err != nil { + return StandardRequest{}, err + } + finalPrompt, toolNames := BuildOpenAIPrompt(messagesRaw, toolsRaw, traceID, toolPolicy, thinkingEnabled, store.ContextEngineMode()) + toolNames = ensureToolDetectionEnabled(toolNames, toolsRaw) if !toolPolicy.IsNone() { toolPolicy.Allowed = namesToSet(toolNames) } @@ -91,25 +102,83 @@ func NormalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra refFileIDs := CollectOpenAIRefFileIDs(req) return StandardRequest{ - Surface: "openai_responses", - RequestedModel: model, - ResolvedModel: resolvedModel, - ResponseModel: model, - Messages: messagesRaw, - PromptTokenText: finalPrompt, - ToolsRaw: req["tools"], - FinalPrompt: finalPrompt, - ToolNames: toolNames, - ToolChoice: toolPolicy, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, - RefFileIDs: refFileIDs, - RefFileTokens: estimateInlineFileTokens(req), - PassThrough: passThrough, + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + ToolsRaw: toolsRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + ToolCallingEnabled: toolCallingEnabled, + ToolCallingDisabled: !toolCallingEnabled, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + RefFileIDs: refFileIDs, + RefFileTokens: estimateInlineFileTokens(req), + PassThrough: passThrough, }, nil } +type toolCallingReader interface { + ToolCallingEnabled() bool + ToolCallingDisabledBehavior() string +} + +func ResolveToolCallingRequest(store any, toolsRaw any, toolChoiceRaw any) (any, any, bool, error) { + if toolCallingEnabled(store) { + return toolsRaw, toolChoiceRaw, true, nil + } + if !requestUsesToolCalling(toolsRaw, toolChoiceRaw) { + return nil, toolChoiceRaw, false, nil + } + if toolCallingDisabledBehavior(store) == "ignore_tools" { + return nil, "none", false, nil + } + return nil, nil, false, fmt.Errorf("tool calling is disabled") +} + +func toolCallingEnabled(store any) bool { + if r, ok := store.(toolCallingReader); ok { + return r.ToolCallingEnabled() + } + return true +} + +func toolCallingDisabledBehavior(store any) string { + if r, ok := store.(toolCallingReader); ok { + behavior := strings.ToLower(strings.TrimSpace(r.ToolCallingDisabledBehavior())) + if behavior == "ignore_tools" { + return behavior + } + } + return "reject" +} + +func requestUsesToolCalling(toolsRaw any, toolChoiceRaw any) bool { + if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { + return true + } + switch v := toolChoiceRaw.(type) { + case nil: + return false + case string: + mode := strings.ToLower(strings.TrimSpace(v)) + return mode != "" && mode != "none" && mode != "auto" + case map[string]any: + typ := strings.ToLower(strings.TrimSpace(asString(v["type"]))) + if typ == "" || typ == "none" { + return hasFunctionSelector(v) || v["allowed_tools"] != nil + } + return true + default: + return true + } +} + func ensureToolDetectionEnabled(toolNames []string, toolsRaw any) []string { if len(toolNames) > 0 { return toolNames diff --git a/internal/promptcompat/standard_request.go b/internal/promptcompat/standard_request.go index 7a69b3af2..3760421e3 100644 --- a/internal/promptcompat/standard_request.go +++ b/internal/promptcompat/standard_request.go @@ -21,6 +21,8 @@ type StandardRequest struct { FinalPrompt string ToolNames []string ToolChoice ToolChoicePolicy + ToolCallingEnabled bool + ToolCallingDisabled bool Stream bool Thinking bool Search bool diff --git a/internal/promptcompat/tool_calling_settings_test.go b/internal/promptcompat/tool_calling_settings_test.go new file mode 100644 index 000000000..efd219f34 --- /dev/null +++ b/internal/promptcompat/tool_calling_settings_test.go @@ -0,0 +1,94 @@ +package promptcompat + +import ( + "strings" + "testing" +) + +type toolCallingTestConfig struct { + enabled bool + behavior string +} + +func (toolCallingTestConfig) ModelAliases() map[string]string { return nil } +func (toolCallingTestConfig) ContextEngineMode() string { return "off" } +func (c toolCallingTestConfig) ToolCallingEnabled() bool { return c.enabled } +func (c toolCallingTestConfig) ToolCallingDisabledBehavior() string { + return c.behavior +} + +func TestNormalizeOpenAIChatRejectsToolsWhenToolCallingDisabled(t *testing.T) { + req := map[string]any{ + "model": "deepseek-v4-flash", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "lookup", + "description": "lookup", + }, + }, + }, + } + _, err := NormalizeOpenAIChatRequest(toolCallingTestConfig{enabled: false, behavior: "reject"}, req, "") + if err == nil || !strings.Contains(err.Error(), "tool calling is disabled") { + t.Fatalf("expected disabled tool calling error, got %v", err) + } +} + +func TestNormalizeOpenAIChatIgnoresToolsWhenConfigured(t *testing.T) { + req := map[string]any{ + "model": "deepseek-v4-flash", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "lookup", + "description": "lookup", + }, + }, + }, + } + out, err := NormalizeOpenAIChatRequest(toolCallingTestConfig{enabled: false, behavior: "ignore_tools"}, req, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out.ToolCallingEnabled || !out.ToolCallingDisabled { + t.Fatalf("expected tool calling disabled flags, got enabled=%v disabled=%v", out.ToolCallingEnabled, out.ToolCallingDisabled) + } + if out.ToolsRaw != nil || len(out.ToolNames) != 0 { + t.Fatalf("expected tools to be stripped, got raw=%#v names=%v", out.ToolsRaw, out.ToolNames) + } + if strings.Contains(out.FinalPrompt, "lookup") { + t.Fatalf("expected tool prompt to be omitted, got %q", out.FinalPrompt) + } +} + +func TestNormalizeOpenAIResponsesIgnoresRequiredToolChoiceWhenConfigured(t *testing.T) { + req := map[string]any{ + "model": "deepseek-v4-flash", + "input": "hello", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "lookup", + "description": "lookup", + }, + }, + }, + "tool_choice": "required", + } + out, err := NormalizeOpenAIResponsesRequest(toolCallingTestConfig{enabled: false, behavior: "ignore_tools"}, req, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out.ToolChoice.Mode != ToolChoiceNone { + t.Fatalf("expected tool_choice none after stripping tools, got %#v", out.ToolChoice) + } + if out.ToolsRaw != nil || len(out.ToolNames) != 0 { + t.Fatalf("expected tools to be stripped, got raw=%#v names=%v", out.ToolsRaw, out.ToolNames) + } +} diff --git a/webui/src/features/settings/BehaviorSection.jsx b/webui/src/features/settings/BehaviorSection.jsx index 6b907afb5..9ddd921d2 100644 --- a/webui/src/features/settings/BehaviorSection.jsx +++ b/webui/src/features/settings/BehaviorSection.jsx @@ -31,7 +31,7 @@ export default function BehaviorSection({ t, form, setForm }) {