From 40eaf35a806aad3ef46e3d532882ac508695c0ce Mon Sep 17 00:00:00 2001 From: Aswin Karumbunathan Date: Tue, 3 Mar 2026 17:30:28 -0800 Subject: [PATCH] Add support for available_providers in model spec For now, the router behavior is the same as it was before, but this opens things up for callers to choose providers when multiple are available for a single model. Related to https://github.com/braintrustdata/braintrust-proxy/pull/407 --- .../src/catalog/resolver.rs | 102 +++++-- .../braintrust-llm-router/src/catalog/spec.rs | 3 + crates/braintrust-llm-router/src/client.rs | 12 + crates/braintrust-llm-router/src/router.rs | 284 +++++++++++++++--- crates/braintrust-llm-router/tests/router.rs | 4 + 5 files changed, 343 insertions(+), 62 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index c506bf4d..2f083808 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -28,7 +28,7 @@ impl ModelResolver { Arc::clone(&self.catalog) } - pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, String)> { + pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { let spec = self .catalog .get(model) @@ -44,14 +44,20 @@ impl ModelResolver { } else { spec.format }; - let provider_alias = self.aliases.get(model).cloned().unwrap_or_else(|| { - if lingua::is_vertex_model(model) { - "vertex".to_string() - } else { - format_identifier(format) - } - }); - Ok((spec, format, provider_alias)) + let provider_aliases = self + .aliases + .get(model) + .map(|alias| vec![alias.clone()]) + .unwrap_or_else(|| { + if !spec.available_providers.is_empty() { + spec.available_providers.clone() + } else if lingua::is_vertex_model(model) { + vec!["vertex".to_string()] + } else { + vec![format_identifier(format)] + } + }); + Ok((spec, format, provider_aliases)) } } @@ -95,9 +101,20 @@ mod tests { max_output_tokens: None, 
supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } + fn spec_with_available_providers( + model: &str, + format: ProviderFormat, + providers: Vec, + ) -> ModelSpec { + let mut s = spec(model, format); + s.available_providers = providers; + s + } + #[test] fn resolve_returns_default_alias() { let mut catalog = ModelCatalog::empty(); @@ -107,9 +124,9 @@ mod tests { ); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve("model").expect("resolves"); + let (_, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(format, ProviderFormat::ChatCompletions); - assert_eq!(alias, "openai"); + assert_eq!(aliases, vec!["openai".to_string()]); } #[test] @@ -119,9 +136,9 @@ mod tests { let resolver = ModelResolver::new(Arc::new(catalog)) .with_aliases(HashMap::from([("model".into(), "custom".into())])); - let (_, format, alias) = resolver.resolve("model").expect("resolves"); + let (_, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(format, ProviderFormat::Anthropic); - assert_eq!(alias, "custom"); + assert_eq!(aliases, vec!["custom".to_string()]); } #[test] @@ -138,9 +155,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::BedrockAnthropic); - assert_eq!(alias, "bedrock"); + assert_eq!(aliases, vec!["bedrock".to_string()]); } #[test] @@ -150,9 +167,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Anthropic); - 
assert_eq!(alias, "anthropic"); + assert_eq!(aliases, vec!["anthropic".to_string()]); } #[test] @@ -162,9 +179,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Google)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Google); - assert_eq!(alias, "vertex"); + assert_eq!(aliases, vec!["vertex".to_string()]); } #[test] @@ -174,9 +191,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Google)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Google); - assert_eq!(alias, "google"); + assert_eq!(aliases, vec!["google".to_string()]); } #[test] @@ -186,8 +203,47 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::VertexAnthropic); - assert_eq!(alias, "vertex"); + assert_eq!(aliases, vec!["vertex".to_string()]); + } + + #[test] + fn resolve_returns_available_providers_when_no_alias() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + spec_with_available_providers( + model, + ProviderFormat::ChatCompletions, + vec!["openai".to_string(), "azure".to_string()], + ), + ); + let resolver = ModelResolver::new(Arc::new(catalog)); + + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); + assert_eq!(format, ProviderFormat::ChatCompletions); + assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); + } + + #[test] + fn 
resolve_custom_alias_overrides_available_providers() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + spec_with_available_providers( + model, + ProviderFormat::ChatCompletions, + vec!["openai".to_string(), "azure".to_string()], + ), + ); + let resolver = ModelResolver::new(Arc::new(catalog)) + .with_aliases(HashMap::from([(model.into(), "custom".into())])); + + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); + assert_eq!(format, ProviderFormat::ChatCompletions); + assert_eq!(aliases, vec!["custom".to_string()]); } } diff --git a/crates/braintrust-llm-router/src/catalog/spec.rs b/crates/braintrust-llm-router/src/catalog/spec.rs index c4c2ae33..4aeb2018 100644 --- a/crates/braintrust-llm-router/src/catalog/spec.rs +++ b/crates/braintrust-llm-router/src/catalog/spec.rs @@ -44,6 +44,8 @@ pub struct ModelSpec { pub supports_streaming: bool, #[serde(default)] pub extra: serde_json::Map, + #[serde(default)] + pub available_providers: Vec, } fn default_true() -> bool { @@ -115,6 +117,7 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: serde_json::Map::new(), + available_providers: vec![], }; assert!(spec.requires_responses_api()); } diff --git a/crates/braintrust-llm-router/src/client.rs b/crates/braintrust-llm-router/src/client.rs index 51a666c9..2facb4ff 100644 --- a/crates/braintrust-llm-router/src/client.rs +++ b/crates/braintrust-llm-router/src/client.rs @@ -57,3 +57,15 @@ pub fn set_override_client(client: ClientWithMiddleware) { pub fn clear_override_client() { *OVERRIDE_CLIENT.write() = None; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_middleware_client_with_no_override() { + clear_override_client(); + let client = build_middleware_client(&ClientSettings::default()); + assert!(client.is_ok()); + } +} diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 9692f5ea..7467a444 100644 --- 
a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -99,6 +99,7 @@ pub fn create_provider( /// Resolved route information from model resolution. type ResolvedRoute<'a> = ( + String, Arc, &'a AuthConfig, Arc, @@ -150,9 +151,13 @@ impl Router { output_format: ProviderFormat, client_headers: &ClientHeaders, ) -> Result { - let (provider, auth, spec, format, strategy) = - self.resolve_provider(model, output_format)?; - let payload = match lingua::transform_request(body.clone(), format, Some(&spec.model)) { + let routes = self.resolve_providers(model, output_format)?; + // Choose the first provider + let route = routes + .first() + .ok_or_else(|| Error::NoProvider(output_format))?; + let (_, provider, auth, spec, format, strategy) = route; + let payload = match lingua::transform_request(body.clone(), *format, Some(&spec.model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -163,10 +168,10 @@ impl Router { .execute_with_retry( provider.clone(), auth, - spec, - format, + spec.clone(), + *format, payload, - strategy, + strategy.clone(), client_headers, ) .await?; @@ -207,8 +212,12 @@ impl Router { output_format: ProviderFormat, client_headers: &ClientHeaders, ) -> Result { - let (provider, auth, spec, format, _) = self.resolve_provider(model, output_format)?; - let payload = match lingua::transform_request(body.clone(), format, Some(&spec.model)) { + let routes = self.resolve_providers(model, output_format)?; + let route = routes + .first() + .ok_or_else(|| Error::NoProvider(output_format))?; + let (_, provider, auth, spec, format, _) = route; + let payload = match lingua::transform_request(body.clone(), *format, Some(&spec.model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. 
}) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -216,47 +225,138 @@ impl Router { }; let raw_stream = provider - .complete_stream(payload, auth, &spec, format, client_headers) + .clone() + .complete_stream(payload, auth, spec.as_ref(), *format, client_headers) .await?; Ok(transform_stream(raw_stream, output_format)) } - pub fn provider_alias(&self, model: &str) -> Result { - let (_, format, alias) = self.resolver.resolve(model)?; - let alias = if self.providers.contains_key(&alias) { - alias - } else { - self.formats.get(&format).cloned().unwrap_or(alias) - }; - Ok(alias) + /// Get the aliases of the providers that can handle the given model and output format. + /// + /// # Arguments + /// + /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") + /// * `output_format` - The requested output format (a required `ProviderFormat`) + /// + /// # Returns + /// A vector of provider aliases that can handle the given model and output + /// format. The aliases are in priority order. Follows the same order as the + /// complete and complete_stream methods. + pub fn provider_aliases( + &self, + model: &str, + output_format: ProviderFormat, + ) -> Result> { + self.resolve_providers(model, output_format) + .map(|routes| routes.into_iter().map(|(alias, ..)| alias).collect()) } - fn resolve_provider( + /// Resolve all providers for a given model and output format. + /// + /// # Arguments + /// + /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") + /// * `output_format` - The requested output format (a required `ProviderFormat`) + /// + /// # Returns + /// A vector of resolved routes, one for each provider. Returns routes in + /// priority order. 
+ fn resolve_providers( &self, model: &str, output_format: ProviderFormat, + ) -> Result>> { + let (spec, catalog_format, aliases) = self.resolver.resolve(model)?; + let routes: Vec>> = aliases + .iter() + .map(|alias| { + self.resolve_provider( + output_format, + spec.clone(), + catalog_format, + alias.to_string(), + ) + }) + .collect(); + let mut first_error = None; + let successes: Vec> = routes + .into_iter() + .filter_map(|r| match r { + Ok(s) => Some(s), + Err(e) => { + #[cfg(feature = "tracing")] + tracing::warn!( + model = %model, + output_format = ?output_format, + all_aliases = ?aliases, + spec = ?spec, + catalog_format = ?catalog_format, + error = %e, + "error resolving provider, falling back to next alias", + ); + if first_error.is_none() { + first_error = Some(e); + } + None + } + }) + .collect(); + if successes.is_empty() { + if let Some(fallback_alias) = self.formats.get(&catalog_format).cloned() { + match self.resolve_provider( + output_format, + spec, + catalog_format, + fallback_alias.clone(), + ) { + Ok(route) => return Ok(vec![route]), + Err(fallback_error) => { + #[cfg(feature = "tracing")] + tracing::warn!( + model, + aliases = ?aliases, + fallback_alias = %fallback_alias, + error = %fallback_error, + "format fallback failed", + ); + return Err(fallback_error); + } + } + } + #[cfg(feature = "tracing")] + tracing::warn!( + model, + aliases = ?aliases, + "no providers found for model", + ); + return Err(first_error.unwrap_or_else(|| Error::NoProvider(catalog_format))); + } + Ok(successes) + } + + fn resolve_provider( + &self, + output_format: ProviderFormat, + spec: Arc, + catalog_format: ProviderFormat, + alias: String, ) -> Result> { - let (spec, catalog_format, alias) = self.resolver.resolve(model)?; #[cfg(feature = "tracing")] let registered: Vec<&str> = self.providers.keys().map(String::as_str).collect(); - let alias = if self.providers.contains_key(&alias) { - alias - } else { + if !self.providers.contains_key(alias.as_str()) { 
#[cfg(feature = "tracing")] tracing::debug!( - model, resolver_alias = %alias, format = ?catalog_format, registered = ?registered, - "resolver alias not found in providers, falling back to format slot" + "resolver alias not found in providers" ); - self.formats.get(&catalog_format).cloned().unwrap_or(alias) - }; + return Err(Error::NoProvider(catalog_format)); + } let provider = self.providers.get(&alias).cloned().ok_or_else(|| { #[cfg(feature = "tracing")] tracing::warn!( - model, alias = %alias, format = ?catalog_format, registered = ?registered, @@ -280,7 +380,7 @@ impl Router { .get(&alias) .ok_or_else(|| Error::NoAuth(alias.clone()))?; let strategy = self.retry_policy.strategy(); - Ok((provider, auth, spec, format, strategy)) + Ok((alias, provider, auth, spec, format, strategy)) } #[allow(clippy::too_many_arguments)] @@ -556,6 +656,7 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } @@ -575,9 +676,16 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } + fn openai_spec_with_available_providers(model: &str, flavor: ModelFlavor) -> ModelSpec { + let mut spec = openai_spec(model, flavor); + spec.available_providers = vec!["openai".into(), "azure".into(), "cerebras".into()]; + spec + } + fn dummy_auth() -> AuthConfig { AuthConfig::ApiKey { key: "test".into(), @@ -618,8 +726,18 @@ mod tests { .build() .expect("router builds"); - assert_eq!(router.provider_alias(vertex_model).unwrap(), "vertex"); - assert_eq!(router.provider_alias(google_model).unwrap(), "google"); + assert_eq!( + router + .provider_aliases(vertex_model, ProviderFormat::Google) + .unwrap(), + vec!["vertex".to_string()] + ); + assert_eq!( + router + .provider_aliases(google_model, ProviderFormat::Google) + .unwrap(), + vec!["google".to_string()] + ); } #[test] @@ -643,7 +761,12 @@ mod tests { .build() .expect("router builds"); 
- assert_eq!(router.provider_alias(vertex_model).unwrap(), "google"); + assert_eq!( + router + .provider_aliases(vertex_model, ProviderFormat::Google) + .unwrap(), + vec!["google".to_string()] + ); } #[test] @@ -665,9 +788,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::Responses); } @@ -690,9 +815,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::Responses); } @@ -715,9 +842,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::ChatCompletions); } @@ -740,9 +869,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::ChatCompletions); } @@ -841,4 +972,79 @@ mod tests { .expect("router builds"); assert!(router.catalog().get("any").is_none()); } + + #[test] + fn provider_aliases_returns_only_registered_available_providers() { + let model = "gpt-4o"; 
+ let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + openai_spec_with_available_providers(model, ModelFlavor::Chat), + ); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "azure", + FakeProvider { + name: "azure", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let aliases = router + .provider_aliases(model, ProviderFormat::ChatCompletions) + .expect("provider_aliases"); + assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); + } + + #[test] + fn resolve_providers_falls_back_to_format_slot_when_alias_not_registered() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + openai_spec_with_available_providers(model, ModelFlavor::Chat), + ); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "other_gpt", + FakeProvider { + name: "other_gpt", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) + .expect("resolves"); + assert!( + !routes.is_empty(), + "at least one route (unregistered openai falls back to format slot azure)" + ); + assert_eq!( + router + .provider_aliases(model, ProviderFormat::ChatCompletions) + .unwrap(), + vec!["other_gpt".to_string()], + "provider_aliases returns only registered providers from available_providers" + ); + } } diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index cb340329..86c283f9 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ 
b/crates/braintrust-llm-router/tests/router.rs @@ -104,6 +104,7 @@ async fn router_routes_to_stub_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog); @@ -172,6 +173,7 @@ async fn router_requires_auth_for_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog); @@ -234,6 +236,7 @@ async fn router_reports_missing_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); @@ -361,6 +364,7 @@ async fn router_retries_and_propagates_terminal_error() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog);