Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions openless-all/app/src-tauri/src/commands/credentials.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::*;

const LLM_EXTRA_HEADERS_ACCOUNT: &str = "ark.extra_headers";

#[tauri::command]
pub fn get_credentials() -> CredentialsStatus {
let snap = CredentialsVault::snapshot();
Expand Down Expand Up @@ -153,6 +155,11 @@ pub(crate) async fn release_sherpa_runtime_if_inactive(
#[tauri::command]
pub fn set_credential(window: Window, account: String, value: String) -> Result<(), String> {
ensure_main_window(&window)?;
if account == LLM_EXTRA_HEADERS_ACCOUNT {
CredentialsVault::set_active_llm_extra_headers_json(&value).map_err(|e| e.to_string())?;
let _ = window.emit("credentials:changed", ());
return Ok(());
}
let acc = parse_account(&account)?;
if value.is_empty() {
CredentialsVault::remove(acc).map_err(|e| e.to_string())?;
Expand Down Expand Up @@ -236,6 +243,9 @@ pub fn set_active_llm_provider(provider: String) -> Result<(), String> {
#[tauri::command]
pub fn read_credential(window: Window, account: String) -> Result<Option<String>, String> {
ensure_main_window(&window)?;
if account == LLM_EXTRA_HEADERS_ACCOUNT {
return CredentialsVault::get_active_llm_extra_headers_json().map_err(|e| e.to_string());
}
let acc = parse_account(&account)?;
CredentialsVault::get(acc).map_err(|e| e.to_string())
}
Expand Down
52 changes: 52 additions & 0 deletions openless-all/app/src-tauri/src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,7 @@ mod tests {
let models = fetch_provider_models(&ProviderConfig {
base_url: format!("http://{}", addr),
api_key: String::new(),
extra_headers: Default::default(),
})
.await
.unwrap();
Expand All @@ -1332,6 +1333,57 @@ mod tests {
server.join().unwrap();
}

#[tokio::test]
async fn fetch_provider_models_sends_extra_headers() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let server = thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
let mut buf = [0u8; 8192];
let mut request = Vec::new();
loop {
let n = stream.read(&mut buf).unwrap();
if n == 0 {
break;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let request_text = String::from_utf8_lossy(&request);
assert!(request_text
.to_ascii_lowercase()
.contains("ocp-apim-subscription-key: secret"));
assert!(!request_text.contains("Authorization: Bearer"));

let body = r#"{"data":[{"id":"m1"}]}"#;
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
);
stream.write_all(response.as_bytes()).unwrap();
});

let models = fetch_provider_models(&ProviderConfig {
base_url: format!("http://{}", addr),
api_key: String::new(),
extra_headers: [(
"custom-head".to_string(),
"secret".to_string(),
)]
.into_iter()
.collect(),
})
.await
.unwrap();

assert_eq!(models, vec!["m1".to_string()]);
server.join().unwrap();
}

#[test]
fn is_valid_session_id_accepts_canonical_uuid_v4() {
// canonical UUID-v4 字面:8-4-4-4-12,全小写、全大写、混合都接受。
Expand Down
19 changes: 17 additions & 2 deletions openless-all/app/src-tauri/src/commands/providers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use std::collections::HashMap;

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -55,6 +56,7 @@ pub async fn list_provider_models(kind: String) -> Result<ProviderModelsResult,
pub(crate) struct ProviderConfig {
pub(crate) base_url: String,
pub(crate) api_key: String,
pub(crate) extra_headers: HashMap<String, String>,
}

fn read_openai_provider_config(kind: &str) -> Result<ProviderConfig, String> {
Expand All @@ -77,6 +79,11 @@ fn read_openai_provider_config(kind: &str) -> Result<ProviderConfig, String> {
let base_url = CredentialsVault::get(endpoint_account)
.map_err(|e| e.to_string())?
.unwrap_or_default();
let extra_headers = if kind == "llm" {
CredentialsVault::get_active_llm_extra_headers()
} else {
HashMap::new()
};
if api_key_required && api_key.trim().is_empty() {
return Err("API Key 为空".to_string());
}
Expand All @@ -90,7 +97,11 @@ fn read_openai_provider_config(kind: &str) -> Result<ProviderConfig, String> {
// (asr/llm) 连通性测试与 list_provider_models 模型列表两条 HTTP 路径。
crate::coordinator::validate_llm_endpoint(&base_url)
.map_err(|_| "endpointInvalid".to_string())?;
Ok(ProviderConfig { base_url, api_key })
Ok(ProviderConfig {
base_url,
api_key,
extra_headers,
})
}

async fn validate_llm_provider() -> Result<(), String> {
Expand Down Expand Up @@ -142,7 +153,8 @@ async fn validate_llm_provider() -> Result<(), String> {
config.api_key,
model,
)
.with_thinking_enabled(llm_thinking_enabled),
.with_thinking_enabled(llm_thinking_enabled)
.with_extra_headers(config.extra_headers),
);
provider
.polish(
Expand Down Expand Up @@ -419,6 +431,9 @@ pub(crate) async fn fetch_provider_models(config: &ProviderConfig) -> Result<Vec
request = request.header("Authorization", format!("Bearer {}", config.api_key));
}
}
for (k, v) in &config.extra_headers {
request = request.header(k.as_str(), v.as_str());
}
let response = request.send().await.map_err(|e| {
if e.is_timeout() {
"请求超时".to_string()
Expand Down
3 changes: 2 additions & 1 deletion openless-all/app/src-tauri/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2003,7 +2003,8 @@ fn build_active_llm_provider(llm_thinking_enabled: bool) -> anyhow::Result<Activ
.trim_end_matches('/')
.to_string();
let config = OpenAICompatibleConfig::new(active, "OpenLess LLM", base_url, api_key, model)
.with_thinking_enabled(llm_thinking_enabled);
.with_thinking_enabled(llm_thinking_enabled)
.with_extra_headers(CredentialsVault::get_active_llm_extra_headers());
Ok(ActiveLLMProvider::OpenAI(OpenAICompatibleLLMProvider::new(
config,
)))
Expand Down
98 changes: 97 additions & 1 deletion openless-all/app/src-tauri/src/persistence/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//!
//! "ark.api_key"/"volcengine.app_key" 等账户名按 Swift 语义路由到 active provider。

use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

Expand Down Expand Up @@ -202,6 +202,79 @@ impl CredsLlmEntry {
}
}

fn active_llm_extra_headers(root: &CredsRoot) -> HashMap<String, String> {
root.providers
.llm
.get(&root.active.llm)
.and_then(|entry| entry.extraHeaders.clone())
.unwrap_or_default()
}

fn active_llm_extra_headers_json(root: &CredsRoot) -> Result<Option<String>> {
let headers = active_llm_extra_headers(root);
if headers.is_empty() {
return Ok(None);
}
let ordered = headers.into_iter().collect::<BTreeMap<_, _>>();
serde_json::to_string(&ordered)
.map(Some)
.context("encode LLM extra headers")
}

fn parse_extra_headers_json(value: &str) -> Result<HashMap<String, String>> {
let trimmed = value.trim();
if trimmed.is_empty() {
return Ok(HashMap::new());
}

let raw: HashMap<String, serde_json::Value> =
serde_json::from_str(trimmed).context("extra headers must be a JSON object")?;
let mut headers = HashMap::new();
for (key, value) in raw {
let key = key.trim();
if key.is_empty() {
anyhow::bail!("extra header name cannot be empty");
}
if !is_valid_header_name(key) {
anyhow::bail!("invalid extra header name: {key}");
}
let Some(value) = value.as_str() else {
anyhow::bail!("extra header value for {key} must be a string");
};
if value.contains('\r') || value.contains('\n') {
anyhow::bail!("extra header value for {key} cannot contain line breaks");
}
headers.insert(key.to_string(), value.to_string());
}
Ok(headers)
}

fn is_valid_header_name(name: &str) -> bool {
!name.is_empty()
&& name.bytes().all(|b| {
matches!(
b,
b'!' | b'#'
| b'$'
| b'%'
| b'&'
| b'\''
| b'*'
| b'+'
| b'-'
| b'.'
| b'^'
| b'_'
| b'`'
| b'|'
| b'~'
| b'0'..=b'9'
| b'a'..=b'z'
| b'A'..=b'Z'
)
})
}

fn credentials_path() -> Result<PathBuf> {
// macOS / Linux: ~/.openless/credentials.json (与 Swift 同源)
// Windows: %APPDATA%\OpenLess\credentials.json (Windows 没有标准 HOME 环境变量)
Expand Down Expand Up @@ -841,6 +914,29 @@ impl CredentialsVault {
load_credentials().active.llm
}

pub fn get_active_llm_extra_headers() -> HashMap<String, String> {
let _guard = credentials_lock().lock();
active_llm_extra_headers(&load_credentials())
}

pub fn get_active_llm_extra_headers_json() -> Result<Option<String>> {
let _guard = credentials_lock().lock();
active_llm_extra_headers_json(&load_credentials())
}

pub fn set_active_llm_extra_headers_json(value: &str) -> Result<()> {
let _guard = credentials_lock().lock();
let headers = parse_extra_headers_json(value)?;
let mut root = load_credentials_for_update()?;
let entry = root.providers.llm.entry(root.active.llm.clone()).or_default();
entry.extraHeaders = if headers.is_empty() {
None
} else {
Some(headers)
};
save_credentials(&root)
}

pub fn snapshot() -> CredentialsSnapshot {
let _guard = credentials_lock().lock();
let root = load_credentials();
Expand Down
5 changes: 5 additions & 0 deletions openless-all/app/src-tauri/src/polish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ impl OpenAICompatibleConfig {
self.thinking_enabled = enabled;
self
}

pub fn with_extra_headers(mut self, extra_headers: HashMap<String, String>) -> Self {
self.extra_headers = extra_headers;
self
}
}

#[derive(Debug, Error)]
Expand Down
2 changes: 2 additions & 0 deletions openless-all/app/src/i18n/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ export const en: typeof zhCN = {
apiKeyLabel: 'API Key',
baseUrlLabel: 'Base URL',
modelLabel: 'Model',
extraHeadersLabel: 'Extra headers',
extraHeadersPlaceholder: '{"custom-head":"..."}',
thinkingModeLabel: 'Thinking',
thinkingModeOn: 'On',
thinkingModeOff: 'Off',
Expand Down
2 changes: 2 additions & 0 deletions openless-all/app/src/i18n/ja.ts
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,8 @@ export const ja: typeof zhCN = {
apiKeyLabel: 'API キー',
baseUrlLabel: 'エンドポイント',
modelLabel: 'モデル',
extraHeadersLabel: '追加 Headers',
extraHeadersPlaceholder: '{"custom-head":"..."}',
thinkingModeLabel: '思考',
thinkingModeOn: 'オン',
thinkingModeOff: 'オフ',
Expand Down
2 changes: 2 additions & 0 deletions openless-all/app/src/i18n/ko.ts
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,8 @@ export const ko: typeof zhCN = {
apiKeyLabel: 'API 키',
baseUrlLabel: '엔드포인트',
modelLabel: '모델',
extraHeadersLabel: '추가 Headers',
extraHeadersPlaceholder: '{"custom-head":"..."}',
thinkingModeLabel: '사고',
thinkingModeOn: '켜짐',
thinkingModeOff: '꺼짐',
Expand Down
2 changes: 2 additions & 0 deletions openless-all/app/src/i18n/zh-CN.ts
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,8 @@ export const zhCN = {
apiKeyLabel: 'API 密钥',
baseUrlLabel: '接口地址',
modelLabel: '模型',
extraHeadersLabel: '额外 Headers',
extraHeadersPlaceholder: '{"custom-head":"..."}',
thinkingModeLabel: '思考',
thinkingModeOn: '开启',
thinkingModeOff: '关闭',
Expand Down
2 changes: 2 additions & 0 deletions openless-all/app/src/i18n/zh-TW.ts
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ export const zhTW: typeof zhCN = {
apiKeyLabel: 'API 密鑰',
baseUrlLabel: '接口地址',
modelLabel: '模型',
extraHeadersLabel: '額外 Headers',
extraHeadersPlaceholder: '{"custom-head":"..."}',
thinkingModeLabel: '思考',
thinkingModeOn: '開啟',
thinkingModeOff: '關閉',
Expand Down
10 changes: 10 additions & 0 deletions openless-all/app/src/pages/settings/ProvidersSection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ export function ProvidersSection({ kind = 'all' }: ProvidersSectionProps = {}) {
<CredentialField key={`${committedLlmProvider}:api_key`} label={t('settings.providers.apiKeyLabel')} account="ark.api_key" mono mask />
<CredentialField key={`${committedLlmProvider}:endpoint`} label={t('settings.providers.baseUrlLabel')} account="ark.endpoint"
placeholder={preset.baseUrl || 'https://your-endpoint/v1'} />
{committedLlmProvider === 'custom' && (
<CredentialField
key={`${committedLlmProvider}:extra_headers`}
label={t('settings.providers.extraHeadersLabel')}
account="ark.extra_headers"
placeholder={t('settings.providers.extraHeadersPlaceholder')}
mono
mask
/>
)}
</>
)}
<CredentialField key={`${committedLlmProvider}:model:${llmModelRevision}`} label={t('settings.providers.modelLabel')} account="ark.model_id"
Expand Down