From 4235a249c8cb8249e5fa8dd86775b380c1c2170e Mon Sep 17 00:00:00 2001 From: justcodebruh Date: Tue, 17 Mar 2026 16:19:25 -0400 Subject: [PATCH] Add custom CA bundle support across bt HTTP flows --- Cargo.lock | 163 +++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 4 +- src/args.rs | 9 +++ src/auth.rs | 169 +++++++++++++++++++++++++++++++++++++------- src/eval.rs | 9 +-- src/http.rs | 44 ++++++++++-- src/main.rs | 3 +- src/projects/mod.rs | 7 +- src/self_update.rs | 29 ++++---- src/setup/docs.rs | 13 ++-- src/setup/mod.rs | 7 +- src/switch.rs | 1 + src/sync.rs | 21 ++++-- src/traces.rs | 1 + 14 files changed, 406 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c7e002c..a80bfc9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,15 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "arc-swap" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5" +dependencies = [ + "rustversion", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -393,16 +402,48 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +dependencies = [ + "darling 0.21.3", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "braintrust-sdk-rust" version = "0.1.0-alpha.2" -source = "git+https://github.com/braintrustdata/braintrust-sdk-rust?rev=c8c7c7a8d9189164584adef71449a26d4ae8be2b#c8c7c7a8d9189164584adef71449a26d4ae8be2b" +source = "git+https://github.com/braintrustdata/braintrust-sdk-rust?rev=6330e78aca76f39870e5bdc947fb04d52e0ae7ca#6330e78aca76f39870e5bdc947fb04d52e0ae7ca" dependencies = [ "anyhow", + "arc-swap", "async-trait", "backoff", + "base64 0.22.1", + "bon", + "bytes", "chrono", + "crossbeam", "futures", + "indexmap 2.13.0", + "regex", "reqwest 0.12.28", "serde", "serde_json 1.0.149", @@ -677,6 +718,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -701,6 +752,62 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crossterm" version = "0.28.1" @@ -1358,6 +1465,7 @@ dependencies = [ "hyper 1.8.1", "hyper-util", "rustls 0.23.37", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls 0.26.4", @@ -1927,6 +2035,12 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + [[package]] name = "option-ext" version = "0.2.0" @@ -2429,6 +2543,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.37", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json 1.0.149", @@ -2529,6 +2644,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2581,6 +2708,15 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "0.9.0" @@ -2621,6 +2757,29 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.27" @@ -2973,7 +3132,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] diff --git a/Cargo.toml b/Cargo.toml index 64cc407..1028660 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,13 +12,13 @@ repository = "https://github.com/braintrustdata/bt" actix-web = "4.11.0" anyhow = "1.0.89" backoff = { version = "0.4.0", features = ["tokio"] } -braintrust-sdk-rust = { git = "https://github.com/braintrustdata/braintrust-sdk-rust", rev = "c8c7c7a8d9189164584adef71449a26d4ae8be2b" } +braintrust-sdk-rust = { git = "https://github.com/braintrustdata/braintrust-sdk-rust", rev = "6330e78aca76f39870e5bdc947fb04d52e0ae7ca" } clap = { version = "4.5.20", features = ["derive", "env"] } crossterm = "0.28.1" futures-util = "0.3.31" indicatif = "0.17.8" ratatui = "0.29.0" -reqwest = { version = "0.12.7", default-features = false, features = ["json", "rustls-tls"] } +reqwest = { version = "0.12.7", default-features = false, features = ["json", "rustls-tls-native-roots"] } serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.128" sha2 = "0.10.8" diff --git a/src/args.rs b/src/args.rs index 3748e8b..51d415a 100644 --- a/src/args.rs +++ b/src/args.rs @@ -66,6 +66,15 @@ pub struct BaseArgs { )] pub app_url: Option, + /// Path to a PEM-encoded CA bundle used for HTTPS requests. + #[arg( + long, + env = "BRAINTRUST_CA_BUNDLE", + hide_env_values = true, + global = true + )] + pub ca_bundle: Option, + /// Path to a .env file to load before running commands. #[arg(long, env = "BRAINTRUST_ENV_FILE", hide_env_values = true)] pub env_file: Option, diff --git a/src/auth.rs b/src/auth.rs index 29a70be..eba3c15 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -13,19 +13,18 @@ use clap::{Args, Subcommand}; use crossterm::event::{self, Event, KeyCode, KeyEventKind}; use dialoguer::{Confirm, Input, Password}; use oauth2::basic::{BasicClient, BasicTokenType}; -use oauth2::reqwest::async_http_client; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse, TokenResponse, TokenUrl, }; -use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use crate::{ args::{BaseArgs, DEFAULT_API_URL, DEFAULT_APP_URL}, + http::{build_http_client, build_http_client_from_builder}, ui, }; @@ -40,6 +39,7 @@ pub struct LoginContext { pub login: LoginState, pub api_url: String, pub app_url: String, + pub ca_bundle: Option, } #[derive(Debug, Clone)] @@ -318,6 +318,9 @@ pub async fn login(base: &BaseArgs) -> Result { if let Some(project) = &project { builder = builder.default_project(project); } + if let Some(ca_bundle) = &base.ca_bundle { + builder = builder.ca_bundle(ca_bundle.clone()); + } let login = match builder.build().await { Ok(client) => client.wait_for_login().await?, @@ -326,19 +329,25 @@ pub async fn login(base: &BaseArgs) -> Result { .org_name .clone() .ok_or_else(|| anyhow::anyhow!("oauth profile is missing org_name: {err}"))?; - LoginState { - api_key: api_key.clone(), - org_id: String::new(), + let login = LoginState::new(); + login.set( + api_key.clone(), + String::new(), org_name, - api_url: auth.api_url.clone(), - } + auth.api_url + .clone() + .unwrap_or_else(|| DEFAULT_API_URL.to_string()), + auth.app_url + .clone() + .unwrap_or_else(|| DEFAULT_APP_URL.to_string()), + ); + login } Err(err) => return Err(err.into()), }; let api_url = login - .api_url - .clone() + .api_url() .or(auth.api_url.clone()) .unwrap_or_else(|| DEFAULT_API_URL.to_string()); @@ -351,6 +360,7 @@ pub async fn login(base: &BaseArgs) -> Result { login, api_url, app_url, + ca_bundle: base.ca_bundle.clone(), }) } @@ -436,7 +446,13 @@ pub async fn resolve_auth(base: &BaseArgs) -> Result { "oauth refresh token missing for profile '{profile_name}'; re-run `bt auth login --oauth --profile {profile_name}`" ) })?; - let refreshed = refresh_oauth_access_token(&api_url, &refresh_token, client_id).await?; + let refreshed = refresh_oauth_access_token( + &api_url, + &refresh_token, + client_id, + base.ca_bundle.as_deref(), + ) + .await?; save_profile_oauth_access_token(&profile_name, &refreshed.access_token)?; if let Some(next_refresh_token) = refreshed.refresh_token.as_ref() { if next_refresh_token != &refresh_token { @@ -604,7 +620,7 @@ async fn run_login_set(base: &BaseArgs, args: AuthLoginArgs) -> Result<()> { .app_url .clone() .unwrap_or_else(|| DEFAULT_APP_URL.to_string()); - let login_orgs = fetch_login_orgs(&api_key, &login_app_url).await?; + let login_orgs = fetch_login_orgs(&api_key, &login_app_url, base.ca_bundle.as_deref()).await?; let selected_org = select_login_org( login_orgs.clone(), base.org_name.as_deref(), @@ -709,9 +725,15 @@ async fn run_login_oauth(base: &BaseArgs, args: AuthLoginArgs) -> Result<()> { &redirect_uri, &auth_code, pkce_verifier, + base.ca_bundle.as_deref(), + ) + .await?; + let login_orgs = fetch_login_orgs( + &oauth_tokens.access_token, + &app_url, + base.ca_bundle.as_deref(), ) .await?; - let login_orgs = fetch_login_orgs(&oauth_tokens.access_token, &app_url).await?; let selected_org = select_login_org( login_orgs.clone(), base.org_name.as_deref(), @@ -762,7 +784,7 @@ async fn login_interactive_api_key(base: &mut BaseArgs) -> Result { .app_url .clone() .unwrap_or_else(|| DEFAULT_APP_URL.to_string()); - let login_orgs = fetch_login_orgs(&api_key, &login_app_url).await?; + let login_orgs = fetch_login_orgs(&api_key, &login_app_url, base.ca_bundle.as_deref()).await?; let selected_org = select_login_org( login_orgs.clone(), base.org_name.as_deref(), @@ -850,10 +872,16 @@ async fn login_interactive_oauth(base: &mut BaseArgs) -> Result { &redirect_uri, &auth_code, pkce_verifier, + base.ca_bundle.as_deref(), ) .await?; - let login_orgs = fetch_login_orgs(&oauth_tokens.access_token, &app_url).await?; + let login_orgs = fetch_login_orgs( + &oauth_tokens.access_token, + &app_url, + base.ca_bundle.as_deref(), + ) + .await?; let selected_org = select_login_org( login_orgs.clone(), base.org_name.as_deref(), @@ -1004,7 +1032,13 @@ async fn run_login_refresh(base: &BaseArgs) -> Result<()> { println!("Cached access token expiry before refresh: unknown"); } - let refreshed = refresh_oauth_access_token(&api_url, &refresh_token, &client_id).await?; + let refreshed = refresh_oauth_access_token( + &api_url, + &refresh_token, + &client_id, + base.ca_bundle.as_deref(), + ) + .await?; save_profile_oauth_access_token(profile_name.as_str(), &refreshed.access_token)?; let mut refresh_rotated = false; if let Some(next_refresh_token) = refreshed.refresh_token.as_ref() { @@ -1134,7 +1168,7 @@ async fn run_profiles(base: &BaseArgs, args: AuthProfilesArgs) -> Result<()> { return Ok(()); } - let verifications = verify_all_profiles_from_store(&store).await; + let verifications = verify_all_profiles_from_store(&store, base.ca_bundle.clone()).await; let all_network_errors = verifications .iter() .all(|v| v.status == "error" && !v.error.as_deref().unwrap_or("").contains("invalid")); @@ -1315,7 +1349,11 @@ fn build_verification( } } -async fn verify_profile_full(name: &str, profile: &AuthProfile) -> ProfileVerification { +async fn verify_profile_full( + name: &str, + profile: &AuthProfile, + ca_bundle: Option, +) -> ProfileVerification { let app_url = profile.app_url.as_deref().unwrap_or(DEFAULT_APP_URL); let auth_kind = match profile.auth_kind { AuthKind::ApiKey => "api_key", @@ -1344,7 +1382,7 @@ async fn verify_profile_full(name: &str, profile: &AuthProfile) -> ProfileVerifi AuthKind::ApiKey => (None, profile.api_key_hint.clone()), }; - match fetch_login_orgs(&credential, app_url).await { + match fetch_login_orgs(&credential, app_url, ca_bundle.as_deref()).await { Ok(_) => mk(ProfileStatus::Ok, jwt_id, hint), Err(e) => { let msg = e.to_string(); @@ -1362,12 +1400,16 @@ async fn verify_profile_full(name: &str, profile: &AuthProfile) -> ProfileVerifi } } -async fn verify_all_profiles_from_store(store: &AuthStore) -> Vec { +async fn verify_all_profiles_from_store( + store: &AuthStore, + ca_bundle: Option, +) -> Vec { let mut set = tokio::task::JoinSet::new(); for (name, profile) in store.profiles.iter() { let name = name.clone(); let profile = profile.clone(); - set.spawn(async move { verify_profile_full(&name, &profile).await }); + let ca_bundle = ca_bundle.clone(); + set.spawn(async move { verify_profile_full(&name, &profile, ca_bundle).await }); } let mut results = Vec::new(); @@ -1451,11 +1493,13 @@ fn print_saved_profiles(store: &AuthStore, json: bool) -> Result<()> { Ok(()) } -async fn fetch_login_orgs(api_key: &str, app_url: &str) -> Result> { +async fn fetch_login_orgs( + api_key: &str, + app_url: &str, + ca_bundle: Option<&Path>, +) -> Result> { let login_url = format!("{}/api/apikey/login", app_url.trim_end_matches('/')); - let client = Client::builder() - .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) - .build() + let client = build_http_client(crate::http::DEFAULT_HTTP_TIMEOUT, ca_bundle) .context("failed to initialize HTTP client")?; let response = client .post(&login_url) @@ -1849,12 +1893,20 @@ async fn exchange_oauth_authorization_code( redirect_uri: &str, code: &str, code_verifier: PkceCodeVerifier, + ca_bundle: Option<&Path>, ) -> Result { let oauth_client = build_oauth_client(api_url, client_id, Some(redirect_uri))?; + let http_client = build_http_client_from_builder( + reqwest::Client::builder() + .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) + .redirect(reqwest::redirect::Policy::none()), + ca_bundle, + ) + .context("failed to initialize oauth HTTP client")?; let token_response = oauth_client .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(code_verifier) - .request_async(async_http_client) + .request_async(move |request| oauth_async_http_client(request, http_client)) .await .with_context(|| { format!( @@ -1869,11 +1921,19 @@ async fn refresh_oauth_access_token( api_url: &str, refresh_token: &str, client_id: &str, + ca_bundle: Option<&Path>, ) -> Result { let oauth_client = build_oauth_client(api_url, client_id, None)?; + let http_client = build_http_client_from_builder( + reqwest::Client::builder() + .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) + .redirect(reqwest::redirect::Policy::none()), + ca_bundle, + ) + .context("failed to initialize oauth HTTP client")?; let token_response = oauth_client .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) - .request_async(async_http_client) + .request_async(move |request| oauth_async_http_client(request, http_client)) .await .with_context(|| { format!( @@ -1886,6 +1946,62 @@ async fn refresh_oauth_access_token( type OAuth2StdTokenResponse = StandardTokenResponse; +#[derive(Debug)] +struct OAuthHttpClientError(String); + +impl std::fmt::Display for OAuthHttpClientError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for OAuthHttpClientError {} + +async fn oauth_async_http_client( + request: oauth2::HttpRequest, + client: reqwest::Client, +) -> std::result::Result { + let method = reqwest::Method::from_bytes(request.method.as_str().as_bytes()) + .map_err(|err| OAuthHttpClientError(format!("invalid oauth request method: {err}")))?; + let mut request_builder = client + .request(method, request.url.as_str()) + .body(request.body); + for (name, value) in &request.headers { + request_builder = request_builder.header(name.as_str(), value.as_bytes()); + } + let request = request_builder + .build() + .map_err(|err| OAuthHttpClientError(format!("failed to build oauth request: {err}")))?; + + let response = client + .execute(request) + .await + .map_err(|err| OAuthHttpClientError(format!("failed to send oauth request: {err}")))?; + let status_code = oauth2::http::StatusCode::from_u16(response.status().as_u16()) + .map_err(|err| OAuthHttpClientError(format!("invalid oauth response status: {err}")))?; + let mut headers = oauth2::http::HeaderMap::with_capacity(response.headers().len()); + for (name, value) in response.headers() { + let name = oauth2::http::header::HeaderName::from_bytes(name.as_str().as_bytes()).map_err( + |err| OAuthHttpClientError(format!("invalid oauth response header name: {err}")), + )?; + let value = oauth2::http::HeaderValue::from_bytes(value.as_bytes()).map_err(|err| { + OAuthHttpClientError(format!("invalid oauth response header value: {err}")) + })?; + headers.append(name, value); + } + let body = response + .bytes() + .await + .map_err(|err| OAuthHttpClientError(format!("failed to read oauth response body: {err}")))? + .to_vec(); + + Ok(oauth2::HttpResponse { + status_code, + headers, + body, + }) +} + fn build_oauth_client( api_url: &str, client_id: &str, @@ -2555,6 +2671,7 @@ mod tests { no_input: false, api_url: None, app_url: None, + ca_bundle: None, env_file: None, } } diff --git a/src/eval.rs b/src/eval.rs index 2373ec7..7d24b91 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -403,10 +403,11 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { allowed_org_name: args.dev_org_name.clone(), allowed_origins: collect_allowed_dev_origins(&args.dev_allowed_origin, &app_url), app_url, - http_client: Client::builder() - .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) - .build() - .context("failed to create dev server HTTP client")?, + http_client: crate::http::build_http_client( + crate::http::DEFAULT_HTTP_TIMEOUT, + base.ca_bundle.as_deref(), + ) + .context("failed to create dev server HTTP client")?, }; return run_dev_server(state).await; } diff --git a/src/http.rs b/src/http.rs index 564ac4a..a629cbf 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,7 @@ +use std::path::Path; + use anyhow::{Context, Result}; -use reqwest::Client; +use reqwest::{Client, ClientBuilder}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -8,6 +10,37 @@ use crate::auth::LoginContext; pub const DEFAULT_HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +pub fn build_http_client(timeout: std::time::Duration, ca_bundle: Option<&Path>) -> Result { + build_http_client_from_builder(Client::builder().timeout(timeout), ca_bundle) +} + +pub fn build_http_client_from_builder( + mut builder: ClientBuilder, + ca_bundle: Option<&Path>, +) -> Result { + if let Some(ca_bundle) = ca_bundle { + let pem = std::fs::read(ca_bundle) + .with_context(|| format!("failed to read CA bundle {}", ca_bundle.display()))?; + let certs = reqwest::Certificate::from_pem_bundle(&pem).with_context(|| { + format!( + "failed to parse PEM certificates from {}", + ca_bundle.display() + ) + })?; + if certs.is_empty() { + anyhow::bail!( + "CA bundle {} did not contain any PEM certificates", + ca_bundle.display() + ); + } + for cert in certs { + builder = builder.add_root_certificate(cert); + } + } + + builder.build().context("failed to build HTTP client") +} + #[derive(Clone)] pub struct ApiClient { http: Client, @@ -37,16 +70,13 @@ pub struct BtqlResponse { impl ApiClient { pub fn new(ctx: &LoginContext) -> Result { - let http = Client::builder() - .timeout(DEFAULT_HTTP_TIMEOUT) - .build() - .context("failed to build HTTP client")?; + let http = build_http_client(DEFAULT_HTTP_TIMEOUT, ctx.ca_bundle.as_deref())?; Ok(Self { http, base_url: ctx.api_url.trim_end_matches('/').to_string(), - api_key: ctx.login.api_key.clone(), - org_name: ctx.login.org_name.clone(), + api_key: ctx.login.api_key().context("login state missing API key")?, + org_name: ctx.login.org_name().unwrap_or_default(), }) } diff --git a/src/main.rs b/src/main.rs index 0c2361a..4866a7f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -83,6 +83,7 @@ Flags --no-input Disable all interactive prompts --api-url Override API URL [env: BRAINTRUST_API_URL] --app-url Override app URL [env: BRAINTRUST_APP_URL] + --ca-bundle PEM-encoded CA bundle for HTTPS [env: BRAINTRUST_CA_BUNDLE] --env-file Path to a .env file to load -h, --help Print help -V, --version Print version @@ -224,7 +225,7 @@ async fn try_main() -> Result<()> { Commands::Experiments(cmd) => experiments::run(cmd.base, cmd.args).await?, Commands::Sync(cmd) => sync::run(cmd.base, cmd.args).await?, Commands::Util(cmd) => util_cmd::run(cmd.base, cmd.args).await?, - Commands::SelfCommand(cmd) => self_update::run(cmd.args).await?, + Commands::SelfCommand(cmd) => self_update::run(cmd.base, cmd.args).await?, Commands::Switch(cmd) => switch::run(cmd.base, cmd.args).await?, Commands::Status(cmd) => status::run(cmd.base, cmd.args).await?, } diff --git a/src/projects/mod.rs b/src/projects/mod.rs index 10fd85d..c89d2d6 100644 --- a/src/projects/mod.rs +++ b/src/projects/mod.rs @@ -73,14 +73,13 @@ struct DeleteArgs { pub async fn run(base: BaseArgs, args: ProjectsArgs) -> Result<()> { let ctx = login(&base).await?; let client = ApiClient::new(&ctx)?; + let org_name = ctx.login.org_name().unwrap_or_default(); match args.command { - None | Some(ProjectsCommands::List) => { - list::run(&client, &ctx.login.org_name, base.json).await - } + None | Some(ProjectsCommands::List) => list::run(&client, &org_name, base.json).await, Some(ProjectsCommands::Create(a)) => create::run(&client, a.name.as_deref()).await, Some(ProjectsCommands::View(a)) => { - view::run(&client, &ctx.app_url, &ctx.login.org_name, a.name()).await + view::run(&client, &ctx.app_url, &org_name, a.name()).await } Some(ProjectsCommands::Delete(a)) => delete::run(&client, a.name.as_deref(), a.force).await, } diff --git a/src/self_update.rs b/src/self_update.rs index 8de3170..2921337 100644 --- a/src/self_update.rs +++ b/src/self_update.rs @@ -7,6 +7,7 @@ use clap::{Args, Subcommand, ValueEnum}; use reqwest::Client; use serde::Deserialize; +use crate::args::BaseArgs; use crate::http::DEFAULT_HTTP_TIMEOUT; #[derive(Debug, Clone, Args)] @@ -82,25 +83,25 @@ struct GitHubRelease { tag_name: String, } -pub async fn run(args: SelfArgs) -> Result<()> { +pub async fn run(base: BaseArgs, args: SelfArgs) -> Result<()> { match args.command { - SelfSubcommand::Update(args) => run_update(args).await, + SelfSubcommand::Update(args) => run_update(&base, args).await, } } -async fn run_update(args: UpdateArgs) -> Result<()> { +async fn run_update(base: &BaseArgs, args: UpdateArgs) -> Result<()> { ensure_installer_managed_install()?; let channel = args .channel .unwrap_or_else(|| inferred_update_channel(BUILD_UPDATE_CHANNEL)); if args.check { - check_for_update(channel).await?; + check_for_update(base, channel).await?; return Ok(()); } if channel == UpdateChannel::Stable { - match fetch_release(channel).await { + match fetch_release(base, channel).await { Ok(release) => { let current = env!("CARGO_PKG_VERSION"); if stable_is_up_to_date(current, &release.tag_name) { @@ -135,8 +136,8 @@ fn ensure_installer_managed_install() -> Result<()> { ); } -async fn check_for_update(channel: UpdateChannel) -> Result<()> { - let release = fetch_release(channel).await?; +async fn check_for_update(base: &BaseArgs, channel: UpdateChannel) -> Result<()> { + let release = fetch_release(base, channel).await?; let current = env!("CARGO_PKG_VERSION"); match channel { @@ -151,12 +152,14 @@ async fn check_for_update(channel: UpdateChannel) -> Result<()> { Ok(()) } -async fn fetch_release(channel: UpdateChannel) -> Result { - let client = Client::builder() - .user_agent("bt-self-update") - .timeout(DEFAULT_HTTP_TIMEOUT) - .build() - .context("failed to initialize HTTP client")?; +async fn fetch_release(base: &BaseArgs, channel: UpdateChannel) -> Result { + let client = crate::http::build_http_client_from_builder( + Client::builder() + .user_agent("bt-self-update") + .timeout(DEFAULT_HTTP_TIMEOUT), + base.ca_bundle.as_deref(), + ) + .context("failed to initialize HTTP client")?; let mut request = client .get(channel.github_release_api_url()) diff --git a/src/setup/docs.rs b/src/setup/docs.rs index 41f6f80..4a329cf 100644 --- a/src/setup/docs.rs +++ b/src/setup/docs.rs @@ -5,7 +5,6 @@ use std::path::{Path, PathBuf}; use anyhow::{bail, Context, Result}; use clap::{Args, Subcommand}; use regex::Regex; -use reqwest::Client; use serde::Serialize; use crate::args::BaseArgs; @@ -117,7 +116,7 @@ async fn run_docs(base: BaseArgs, args: DocsArgs) -> Result<()> { async fn run_docs_fetch(base: BaseArgs, args: DocsFetchArgs) -> Result<()> { let selected_workflows = resolve_workflow_selection(&args.workflows); - let fetch_result = fetch_docs_pages(&args, &selected_workflows).await?; + let fetch_result = fetch_docs_pages(&base, &args, &selected_workflows).await?; if base.json { let report = DocsFetchJsonReport { @@ -176,6 +175,7 @@ async fn run_docs_fetch(base: BaseArgs, args: DocsFetchArgs) -> Result<()> { } pub(super) async fn fetch_docs_pages( + base: &BaseArgs, args: &DocsFetchArgs, selected_workflows: &[WorkflowArg], ) -> Result { @@ -194,10 +194,11 @@ pub(super) async fn fetch_docs_pages( Regex::new(r#"(?m)\b(https?://[^\s<>"')]+)"#).context("failed to build URL regex")?; let llms_base = reqwest::Url::parse(&args.llms_url) .with_context(|| format!("invalid llms URL: {}", args.llms_url))?; - let client = Client::builder() - .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) - .build() - .context("failed to build HTTP client")?; + let client = crate::http::build_http_client( + crate::http::DEFAULT_HTTP_TIMEOUT, + base.ca_bundle.as_deref(), + ) + .context("failed to build HTTP client")?; let index_response = client .get(&args.llms_url) diff --git a/src/setup/mod.rs b/src/setup/mod.rs index d89aa83..16a9765 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -686,6 +686,7 @@ async fn execute_skills_setup( notes.push("Skipped workflow docs prefetch (no workflows selected).".to_string()); } else { prefetch_workflow_docs( + &base, show_progress, scope, local_root.as_deref(), @@ -768,6 +769,7 @@ async fn run_instrument_setup(base: BaseArgs, args: InstrumentSetupArgs) -> Resu }); notes.push("Skipped skills setup (already configured).".to_string()); prefetch_workflow_docs( + &base, show_progress, InstallScope::Local, Some(&root), @@ -941,6 +943,7 @@ fn skill_config_path( #[allow(clippy::too_many_arguments)] async fn prefetch_workflow_docs( + base: &BaseArgs, show_progress: bool, scope: InstallScope, local_root: Option<&Path>, @@ -972,11 +975,11 @@ async fn prefetch_workflow_docs( let fetch_result = if show_progress { with_spinner( "Prefetching workflow docs...", - docs::fetch_docs_pages(&docs_args, selected_workflows), + docs::fetch_docs_pages(base, &docs_args, selected_workflows), ) .await } else { - docs::fetch_docs_pages(&docs_args, selected_workflows).await + docs::fetch_docs_pages(base, &docs_args, selected_workflows).await }; match fetch_result { Ok(fetch_result) => { diff --git a/src/switch.rs b/src/switch.rs index cc9bf3a..e3734b1 100644 --- a/src/switch.rs +++ b/src/switch.rs @@ -255,6 +255,7 @@ mod tests { no_input: false, api_url: None, app_url: None, + ca_bundle: None, env_file: None, } } diff --git a/src/sync.rs b/src/sync.rs index e62a856..e6ca4ad 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -760,14 +760,15 @@ async fn run_pull( )?; if json_output { + let login_org_name = ctx.login.org_name().unwrap_or_default(); let warning = if state.items_done == 0 { Some(format!( "no rows found for {} in org '{}'; verify object id and active credentials", spec.object_ref, - if ctx.login.org_name.trim().is_empty() { + if login_org_name.trim().is_empty() { "(default)".to_string() } else { - ctx.login.org_name.clone() + login_org_name.clone() } )) } else { @@ -795,10 +796,11 @@ async fn run_pull( let spans_per_sec = spans_done as f64 / elapsed_secs as f64; let bytes_per_sec = state.bytes_written as f64 / elapsed_secs as f64; if state.items_done == 0 { - let org_label = if ctx.login.org_name.trim().is_empty() { + let login_org_name = ctx.login.org_name().unwrap_or_default(); + let org_label = if login_org_name.trim().is_empty() { "(default)".to_string() } else { - ctx.login.org_name.clone() + login_org_name }; println!( "Warning: no rows found for {} in org '{}'; verify object id and active credentials.", @@ -1696,10 +1698,15 @@ async fn run_push( } }); + let login_api_key = ctx + .login + .api_key() + .context("login state missing API key for upload")?; + let login_org_name = ctx.login.org_name().unwrap_or_default(); let uploader_template = Logs3BatchUploader::new( ctx.api_url.clone(), - ctx.login.api_key.clone(), - (!ctx.login.org_name.trim().is_empty()).then_some(ctx.login.org_name.clone()), + login_api_key, + (!login_org_name.trim().is_empty()).then_some(login_org_name.clone()), ) .context("failed to initialize logs3 uploader")?; @@ -2058,7 +2065,7 @@ async fn execute_btql_query( "fmt": "json", "query_source": "bt_sync_9f4b1e6d7c2a4a7b8d4f9a6c2b1e7f3d", }); - let org_name = ctx.login.org_name.clone(); + let org_name = ctx.login.org_name().unwrap_or_default(); let client = client.clone(); let attempt_counter = Arc::new(AtomicUsize::new(0)); diff --git a/src/traces.rs b/src/traces.rs index 201b7f9..5645dd9 100644 --- a/src/traces.rs +++ b/src/traces.rs @@ -6068,6 +6068,7 @@ mod tests { no_input: false, api_url: None, app_url: None, + ca_bundle: None, env_file: None, } }