From d148109224d539b0e1c57c41fcb52150b956866e Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Tue, 23 Jun 2026 09:41:44 +0800 Subject: [PATCH] Add custom LLM extra headers --- .../app/src-tauri/src/commands/credentials.rs | 10 ++ .../app/src-tauri/src/commands/mod.rs | 52 ++++++++++ .../app/src-tauri/src/commands/providers.rs | 19 +++- openless-all/app/src-tauri/src/coordinator.rs | 3 +- .../src-tauri/src/persistence/credentials.rs | 98 ++++++++++++++++++- openless-all/app/src-tauri/src/polish.rs | 5 + openless-all/app/src/i18n/en.ts | 2 + openless-all/app/src/i18n/ja.ts | 2 + openless-all/app/src/i18n/ko.ts | 2 + openless-all/app/src/i18n/zh-CN.ts | 2 + openless-all/app/src/i18n/zh-TW.ts | 2 + .../src/pages/settings/ProvidersSection.tsx | 10 ++ 12 files changed, 203 insertions(+), 4 deletions(-) diff --git a/openless-all/app/src-tauri/src/commands/credentials.rs b/openless-all/app/src-tauri/src/commands/credentials.rs index 164b6ac2..d5b2980b 100644 --- a/openless-all/app/src-tauri/src/commands/credentials.rs +++ b/openless-all/app/src-tauri/src/commands/credentials.rs @@ -1,5 +1,7 @@ use super::*; +const LLM_EXTRA_HEADERS_ACCOUNT: &str = "ark.extra_headers"; + #[tauri::command] pub fn get_credentials() -> CredentialsStatus { let snap = CredentialsVault::snapshot(); @@ -153,6 +155,11 @@ pub(crate) async fn release_sherpa_runtime_if_inactive( #[tauri::command] pub fn set_credential(window: Window, account: String, value: String) -> Result<(), String> { ensure_main_window(&window)?; + if account == LLM_EXTRA_HEADERS_ACCOUNT { + CredentialsVault::set_active_llm_extra_headers_json(&value).map_err(|e| e.to_string())?; + let _ = window.emit("credentials:changed", ()); + return Ok(()); + } let acc = parse_account(&account)?; if value.is_empty() { CredentialsVault::remove(acc).map_err(|e| e.to_string())?; @@ -236,6 +243,9 @@ pub fn set_active_llm_provider(provider: String) -> Result<(), String> { #[tauri::command] pub fn read_credential(window: Window, account: String) -> Result, String> { ensure_main_window(&window)?; + if account == LLM_EXTRA_HEADERS_ACCOUNT { + return CredentialsVault::get_active_llm_extra_headers_json().map_err(|e| e.to_string()); + } let acc = parse_account(&account)?; CredentialsVault::get(acc).map_err(|e| e.to_string()) } diff --git a/openless-all/app/src-tauri/src/commands/mod.rs b/openless-all/app/src-tauri/src/commands/mod.rs index 2c128ba5..47105f7c 100644 --- a/openless-all/app/src-tauri/src/commands/mod.rs +++ b/openless-all/app/src-tauri/src/commands/mod.rs @@ -1324,6 +1324,7 @@ mod tests { let models = fetch_provider_models(&ProviderConfig { base_url: format!("http://{}", addr), api_key: String::new(), + extra_headers: Default::default(), }) .await .unwrap(); @@ -1332,6 +1333,57 @@ mod tests { server.join().unwrap(); } + #[tokio::test] + async fn fetch_provider_models_sends_extra_headers() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut buf = [0u8; 8192]; + let mut request = Vec::new(); + loop { + let n = stream.read(&mut buf).unwrap(); + if n == 0 { + break; + } + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let request_text = String::from_utf8_lossy(&request); + assert!(request_text + .to_ascii_lowercase() + .contains("ocp-apim-subscription-key: secret")); + assert!(!request_text.contains("Authorization: Bearer")); + + let body = r#"{"data":[{"id":"m1"}]}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).unwrap(); + }); + + let models = fetch_provider_models(&ProviderConfig { + base_url: format!("http://{}", addr), + api_key: String::new(), + extra_headers: [( + "custom-head".to_string(), + "secret".to_string(), + )] + .into_iter() + .collect(), + }) + .await + .unwrap(); + + assert_eq!(models, vec!["m1".to_string()]); + server.join().unwrap(); + } + #[test] fn is_valid_session_id_accepts_canonical_uuid_v4() { // canonical UUID-v4 字面:8-4-4-4-12,全小写、全大写、混合都接受。 diff --git a/openless-all/app/src-tauri/src/commands/providers.rs b/openless-all/app/src-tauri/src/commands/providers.rs index faa830dd..d6e29f37 100644 --- a/openless-all/app/src-tauri/src/commands/providers.rs +++ b/openless-all/app/src-tauri/src/commands/providers.rs @@ -1,4 +1,5 @@ use super::*; +use std::collections::HashMap; #[derive(Serialize)] #[serde(rename_all = "camelCase")] @@ -55,6 +56,7 @@ pub async fn list_provider_models(kind: String) -> Result, } fn read_openai_provider_config(kind: &str) -> Result { @@ -77,6 +79,11 @@ fn read_openai_provider_config(kind: &str) -> Result { let base_url = CredentialsVault::get(endpoint_account) .map_err(|e| e.to_string())? .unwrap_or_default(); + let extra_headers = if kind == "llm" { + CredentialsVault::get_active_llm_extra_headers() + } else { + HashMap::new() + }; if api_key_required && api_key.trim().is_empty() { return Err("API Key 为空".to_string()); } @@ -90,7 +97,11 @@ fn read_openai_provider_config(kind: &str) -> Result { // (asr/llm) 连通性测试与 list_provider_models 模型列表两条 HTTP 路径。 crate::coordinator::validate_llm_endpoint(&base_url) .map_err(|_| "endpointInvalid".to_string())?; - Ok(ProviderConfig { base_url, api_key }) + Ok(ProviderConfig { + base_url, + api_key, + extra_headers, + }) } async fn validate_llm_provider() -> Result<(), String> { @@ -142,7 +153,8 @@ async fn validate_llm_provider() -> Result<(), String> { config.api_key, model, ) - .with_thinking_enabled(llm_thinking_enabled), + .with_thinking_enabled(llm_thinking_enabled) + .with_extra_headers(config.extra_headers), ); provider .polish( @@ -419,6 +431,9 @@ pub(crate) async fn fetch_provider_models(config: &ProviderConfig) -> Result anyhow::Result HashMap { + root.providers + .llm + .get(&root.active.llm) + .and_then(|entry| entry.extraHeaders.clone()) + .unwrap_or_default() +} + +fn active_llm_extra_headers_json(root: &CredsRoot) -> Result> { + let headers = active_llm_extra_headers(root); + if headers.is_empty() { + return Ok(None); + } + let ordered = headers.into_iter().collect::>(); + serde_json::to_string(&ordered) + .map(Some) + .context("encode LLM extra headers") +} + +fn parse_extra_headers_json(value: &str) -> Result> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Ok(HashMap::new()); + } + + let raw: HashMap = + serde_json::from_str(trimmed).context("extra headers must be a JSON object")?; + let mut headers = HashMap::new(); + for (key, value) in raw { + let key = key.trim(); + if key.is_empty() { + anyhow::bail!("extra header name cannot be empty"); + } + if !is_valid_header_name(key) { + anyhow::bail!("invalid extra header name: {key}"); + } + let Some(value) = value.as_str() else { + anyhow::bail!("extra header value for {key} must be a string"); + }; + if value.contains('\r') || value.contains('\n') { + anyhow::bail!("extra header value for {key} cannot contain line breaks"); + } + headers.insert(key.to_string(), value.to_string()); + } + Ok(headers) +} + +fn is_valid_header_name(name: &str) -> bool { + !name.is_empty() + && name.bytes().all(|b| { + matches!( + b, + b'!' | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' + | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + | b'0'..=b'9' + | b'a'..=b'z' + | b'A'..=b'Z' + ) + }) +} + fn credentials_path() -> Result { // macOS / Linux: ~/.openless/credentials.json (与 Swift 同源) // Windows: %APPDATA%\OpenLess\credentials.json (Windows 没有标准 HOME 环境变量) @@ -841,6 +914,29 @@ impl CredentialsVault { load_credentials().active.llm } + pub fn get_active_llm_extra_headers() -> HashMap { + let _guard = credentials_lock().lock(); + active_llm_extra_headers(&load_credentials()) + } + + pub fn get_active_llm_extra_headers_json() -> Result> { + let _guard = credentials_lock().lock(); + active_llm_extra_headers_json(&load_credentials()) + } + + pub fn set_active_llm_extra_headers_json(value: &str) -> Result<()> { + let _guard = credentials_lock().lock(); + let headers = parse_extra_headers_json(value)?; + let mut root = load_credentials_for_update()?; + let entry = root.providers.llm.entry(root.active.llm.clone()).or_default(); + entry.extraHeaders = if headers.is_empty() { + None + } else { + Some(headers) + }; + save_credentials(&root) + } + pub fn snapshot() -> CredentialsSnapshot { let _guard = credentials_lock().lock(); let root = load_credentials(); diff --git a/openless-all/app/src-tauri/src/polish.rs b/openless-all/app/src-tauri/src/polish.rs index 2f7c9a66..0b5056fe 100644 --- a/openless-all/app/src-tauri/src/polish.rs +++ b/openless-all/app/src-tauri/src/polish.rs @@ -69,6 +69,11 @@ impl OpenAICompatibleConfig { self.thinking_enabled = enabled; self } + + pub fn with_extra_headers(mut self, extra_headers: HashMap) -> Self { + self.extra_headers = extra_headers; + self + } } #[derive(Debug, Error)] diff --git a/openless-all/app/src/i18n/en.ts b/openless-all/app/src/i18n/en.ts index 6a17fecf..7ec93bf1 100644 --- a/openless-all/app/src/i18n/en.ts +++ b/openless-all/app/src/i18n/en.ts @@ -767,6 +767,8 @@ export const en: typeof zhCN = { apiKeyLabel: 'API Key', baseUrlLabel: 'Base URL', modelLabel: 'Model', + extraHeadersLabel: 'Extra headers', + extraHeadersPlaceholder: '{"custom-head":"..."}', thinkingModeLabel: 'Thinking', thinkingModeOn: 'On', thinkingModeOff: 'Off', diff --git a/openless-all/app/src/i18n/ja.ts b/openless-all/app/src/i18n/ja.ts index f4c25e04..f5e6c1f5 100644 --- a/openless-all/app/src/i18n/ja.ts +++ b/openless-all/app/src/i18n/ja.ts @@ -769,6 +769,8 @@ export const ja: typeof zhCN = { apiKeyLabel: 'API キー', baseUrlLabel: 'エンドポイント', modelLabel: 'モデル', + extraHeadersLabel: '追加 Headers', + extraHeadersPlaceholder: '{"custom-head":"..."}', thinkingModeLabel: '思考', thinkingModeOn: 'オン', thinkingModeOff: 'オフ', diff --git a/openless-all/app/src/i18n/ko.ts b/openless-all/app/src/i18n/ko.ts index 95b74594..86f588b4 100644 --- a/openless-all/app/src/i18n/ko.ts +++ b/openless-all/app/src/i18n/ko.ts @@ -769,6 +769,8 @@ export const ko: typeof zhCN = { apiKeyLabel: 'API 키', baseUrlLabel: '엔드포인트', modelLabel: '모델', + extraHeadersLabel: '추가 Headers', + extraHeadersPlaceholder: '{"custom-head":"..."}', thinkingModeLabel: '사고', thinkingModeOn: '켜짐', thinkingModeOff: '꺼짐', diff --git a/openless-all/app/src/i18n/zh-CN.ts b/openless-all/app/src/i18n/zh-CN.ts index 2d87e8ab..fbd3753e 100644 --- a/openless-all/app/src/i18n/zh-CN.ts +++ b/openless-all/app/src/i18n/zh-CN.ts @@ -765,6 +765,8 @@ export const zhCN = { apiKeyLabel: 'API 密钥', baseUrlLabel: '接口地址', modelLabel: '模型', + extraHeadersLabel: '额外 Headers', + extraHeadersPlaceholder: '{"custom-head":"..."}', thinkingModeLabel: '思考', thinkingModeOn: '开启', thinkingModeOff: '关闭', diff --git a/openless-all/app/src/i18n/zh-TW.ts b/openless-all/app/src/i18n/zh-TW.ts index 4390c7d4..c9a7362e 100644 --- a/openless-all/app/src/i18n/zh-TW.ts +++ b/openless-all/app/src/i18n/zh-TW.ts @@ -767,6 +767,8 @@ export const zhTW: typeof zhCN = { apiKeyLabel: 'API 密鑰', baseUrlLabel: '接口地址', modelLabel: '模型', + extraHeadersLabel: '額外 Headers', + extraHeadersPlaceholder: '{"custom-head":"..."}', thinkingModeLabel: '思考', thinkingModeOn: '開啟', thinkingModeOff: '關閉', diff --git a/openless-all/app/src/pages/settings/ProvidersSection.tsx b/openless-all/app/src/pages/settings/ProvidersSection.tsx index cf5d9a1d..1a4f878d 100644 --- a/openless-all/app/src/pages/settings/ProvidersSection.tsx +++ b/openless-all/app/src/pages/settings/ProvidersSection.tsx @@ -376,6 +376,16 @@ export function ProvidersSection({ kind = 'all' }: ProvidersSectionProps = {}) { + {committedLlmProvider === 'custom' && ( + + )} )}