diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index c506bf4d..2f083808 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -28,7 +28,7 @@ impl ModelResolver { Arc::clone(&self.catalog) } - pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, String)> { + pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { let spec = self .catalog .get(model) @@ -44,14 +44,20 @@ impl ModelResolver { } else { spec.format }; - let provider_alias = self.aliases.get(model).cloned().unwrap_or_else(|| { - if lingua::is_vertex_model(model) { - "vertex".to_string() - } else { - format_identifier(format) - } - }); - Ok((spec, format, provider_alias)) + let provider_aliases = self + .aliases + .get(model) + .map(|alias| vec![alias.clone()]) + .unwrap_or_else(|| { + if !spec.available_providers.is_empty() { + spec.available_providers.clone() + } else if lingua::is_vertex_model(model) { + vec!["vertex".to_string()] + } else { + vec![format_identifier(format)] + } + }); + Ok((spec, format, provider_aliases)) } } @@ -95,9 +101,20 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } + fn spec_with_available_providers( + model: &str, + format: ProviderFormat, + providers: Vec, + ) -> ModelSpec { + let mut s = spec(model, format); + s.available_providers = providers; + s + } + #[test] fn resolve_returns_default_alias() { let mut catalog = ModelCatalog::empty(); @@ -107,9 +124,9 @@ mod tests { ); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve("model").expect("resolves"); + let (_, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(format, ProviderFormat::ChatCompletions); - assert_eq!(alias, "openai"); + assert_eq!(aliases, vec!["openai".to_string()]); } 
#[test] @@ -119,9 +136,9 @@ mod tests { let resolver = ModelResolver::new(Arc::new(catalog)) .with_aliases(HashMap::from([("model".into(), "custom".into())])); - let (_, format, alias) = resolver.resolve("model").expect("resolves"); + let (_, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(format, ProviderFormat::Anthropic); - assert_eq!(alias, "custom"); + assert_eq!(aliases, vec!["custom".to_string()]); } #[test] @@ -138,9 +155,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::BedrockAnthropic); - assert_eq!(alias, "bedrock"); + assert_eq!(aliases, vec!["bedrock".to_string()]); } #[test] @@ -150,9 +167,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Anthropic); - assert_eq!(alias, "anthropic"); + assert_eq!(aliases, vec!["anthropic".to_string()]); } #[test] @@ -162,9 +179,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Google)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Google); - assert_eq!(alias, "vertex"); + assert_eq!(aliases, vec!["vertex".to_string()]); } #[test] @@ -174,9 +191,9 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Google)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + 
let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::Google); - assert_eq!(alias, "google"); + assert_eq!(aliases, vec!["google".to_string()]); } #[test] @@ -186,8 +203,47 @@ mod tests { catalog.insert(model.into(), spec(model, ProviderFormat::Anthropic)); let resolver = ModelResolver::new(Arc::new(catalog)); - let (_, format, alias) = resolver.resolve(model).expect("resolves"); + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); assert_eq!(format, ProviderFormat::VertexAnthropic); - assert_eq!(alias, "vertex"); + assert_eq!(aliases, vec!["vertex".to_string()]); + } + + #[test] + fn resolve_returns_available_providers_when_no_alias() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + spec_with_available_providers( + model, + ProviderFormat::ChatCompletions, + vec!["openai".to_string(), "azure".to_string()], + ), + ); + let resolver = ModelResolver::new(Arc::new(catalog)); + + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); + assert_eq!(format, ProviderFormat::ChatCompletions); + assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); + } + + #[test] + fn resolve_custom_alias_overrides_available_providers() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + spec_with_available_providers( + model, + ProviderFormat::ChatCompletions, + vec!["openai".to_string(), "azure".to_string()], + ), + ); + let resolver = ModelResolver::new(Arc::new(catalog)) + .with_aliases(HashMap::from([(model.into(), "custom".into())])); + + let (_, format, aliases) = resolver.resolve(model).expect("resolves"); + assert_eq!(format, ProviderFormat::ChatCompletions); + assert_eq!(aliases, vec!["custom".to_string()]); } } diff --git a/crates/braintrust-llm-router/src/catalog/spec.rs b/crates/braintrust-llm-router/src/catalog/spec.rs index c4c2ae33..4aeb2018 100644 --- 
a/crates/braintrust-llm-router/src/catalog/spec.rs +++ b/crates/braintrust-llm-router/src/catalog/spec.rs @@ -44,6 +44,8 @@ pub struct ModelSpec { pub supports_streaming: bool, #[serde(default)] pub extra: serde_json::Map, + #[serde(default)] + pub available_providers: Vec, } fn default_true() -> bool { @@ -115,6 +117,7 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: serde_json::Map::new(), + available_providers: vec![], }; assert!(spec.requires_responses_api()); } diff --git a/crates/braintrust-llm-router/src/client.rs b/crates/braintrust-llm-router/src/client.rs index 51a666c9..2facb4ff 100644 --- a/crates/braintrust-llm-router/src/client.rs +++ b/crates/braintrust-llm-router/src/client.rs @@ -57,3 +57,15 @@ pub fn set_override_client(client: ClientWithMiddleware) { pub fn clear_override_client() { *OVERRIDE_CLIENT.write() = None; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_middleware_client_with_no_override() { + clear_override_client(); + let client = build_middleware_client(&ClientSettings::default()); + assert!(client.is_ok()); + } +} diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 9692f5ea..7467a444 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -99,6 +99,7 @@ pub fn create_provider( /// Resolved route information from model resolution. 
type ResolvedRoute<'a> = ( + String, Arc, &'a AuthConfig, Arc, @@ -150,9 +151,13 @@ impl Router { output_format: ProviderFormat, client_headers: &ClientHeaders, ) -> Result { - let (provider, auth, spec, format, strategy) = - self.resolve_provider(model, output_format)?; - let payload = match lingua::transform_request(body.clone(), format, Some(&spec.model)) { + let routes = self.resolve_providers(model, output_format)?; + // Choose the first provider + let route = routes + .first() + .ok_or_else(|| Error::NoProvider(output_format))?; + let (_, provider, auth, spec, format, strategy) = route; + let payload = match lingua::transform_request(body.clone(), *format, Some(&spec.model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -163,10 +168,10 @@ impl Router { .execute_with_retry( provider.clone(), auth, - spec, - format, + spec.clone(), + *format, payload, - strategy, + strategy.clone(), client_headers, ) .await?; @@ -207,8 +212,12 @@ impl Router { output_format: ProviderFormat, client_headers: &ClientHeaders, ) -> Result { - let (provider, auth, spec, format, _) = self.resolve_provider(model, output_format)?; - let payload = match lingua::transform_request(body.clone(), format, Some(&spec.model)) { + let routes = self.resolve_providers(model, output_format)?; + let route = routes + .first() + .ok_or_else(|| Error::NoProvider(output_format))?; + let (_, provider, auth, spec, format, _) = route; + let payload = match lingua::transform_request(body.clone(), *format, Some(&spec.model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. 
}) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -216,47 +225,138 @@ impl Router { }; let raw_stream = provider - .complete_stream(payload, auth, &spec, format, client_headers) + .clone() + .complete_stream(payload, auth, spec.as_ref(), *format, client_headers) .await?; Ok(transform_stream(raw_stream, output_format)) } - pub fn provider_alias(&self, model: &str) -> Result<String> { - let (_, format, alias) = self.resolver.resolve(model)?; - let alias = if self.providers.contains_key(&alias) { - alias - } else { - self.formats.get(&format).cloned().unwrap_or(alias) - }; - Ok(alias) + /// Get the aliases of the providers that can handle the given model and output format. + /// + /// # Arguments + /// + /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") + /// * `output_format` - The desired output format for the response + /// + /// # Returns + /// A vector of provider aliases that can handle the given model and output + /// format. The aliases are in priority order. Follows the same order as the + /// complete and complete_stream methods. + pub fn provider_aliases( + &self, + model: &str, + output_format: ProviderFormat, + ) -> Result<Vec<String>> { + self.resolve_providers(model, output_format) + .map(|routes| routes.into_iter().map(|(alias, ..)| alias).collect()) } - fn resolve_provider( + /// Resolve all providers for a given model and output format. + /// + /// # Arguments + /// + /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") + /// * `output_format` - The desired output format for the response + /// + /// # Returns + /// A vector of resolved routes, one for each provider. Returns routes in + /// priority order.
+ fn resolve_providers( &self, model: &str, output_format: ProviderFormat, + ) -> Result>> { + let (spec, catalog_format, aliases) = self.resolver.resolve(model)?; + let routes: Vec>> = aliases + .iter() + .map(|alias| { + self.resolve_provider( + output_format, + spec.clone(), + catalog_format, + alias.to_string(), + ) + }) + .collect(); + let mut first_error = None; + let successes: Vec> = routes + .into_iter() + .filter_map(|r| match r { + Ok(s) => Some(s), + Err(e) => { + #[cfg(feature = "tracing")] + tracing::warn!( + model = %model, + output_format = ?output_format, + all_aliases = ?aliases, + spec = ?spec, + catalog_format = ?catalog_format, + error = %e, + "error resolving provider, falling back to next alias", + ); + if first_error.is_none() { + first_error = Some(e); + } + None + } + }) + .collect(); + if successes.is_empty() { + if let Some(fallback_alias) = self.formats.get(&catalog_format).cloned() { + match self.resolve_provider( + output_format, + spec, + catalog_format, + fallback_alias.clone(), + ) { + Ok(route) => return Ok(vec![route]), + Err(fallback_error) => { + #[cfg(feature = "tracing")] + tracing::warn!( + model, + aliases = ?aliases, + fallback_alias = %fallback_alias, + error = %fallback_error, + "format fallback failed", + ); + return Err(fallback_error); + } + } + } + #[cfg(feature = "tracing")] + tracing::warn!( + model, + aliases = ?aliases, + "no providers found for model", + ); + return Err(first_error.unwrap_or_else(|| Error::NoProvider(catalog_format))); + } + Ok(successes) + } + + fn resolve_provider( + &self, + output_format: ProviderFormat, + spec: Arc, + catalog_format: ProviderFormat, + alias: String, ) -> Result> { - let (spec, catalog_format, alias) = self.resolver.resolve(model)?; #[cfg(feature = "tracing")] let registered: Vec<&str> = self.providers.keys().map(String::as_str).collect(); - let alias = if self.providers.contains_key(&alias) { - alias - } else { + if !self.providers.contains_key(alias.as_str()) { 
#[cfg(feature = "tracing")] tracing::debug!( - model, resolver_alias = %alias, format = ?catalog_format, registered = ?registered, - "resolver alias not found in providers, falling back to format slot" + "resolver alias not found in providers" ); - self.formats.get(&catalog_format).cloned().unwrap_or(alias) - }; + return Err(Error::NoProvider(catalog_format)); + } let provider = self.providers.get(&alias).cloned().ok_or_else(|| { #[cfg(feature = "tracing")] tracing::warn!( - model, alias = %alias, format = ?catalog_format, registered = ?registered, @@ -280,7 +380,7 @@ impl Router { .get(&alias) .ok_or_else(|| Error::NoAuth(alias.clone()))?; let strategy = self.retry_policy.strategy(); - Ok((provider, auth, spec, format, strategy)) + Ok((alias, provider, auth, spec, format, strategy)) } #[allow(clippy::too_many_arguments)] @@ -556,6 +656,7 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } @@ -575,9 +676,16 @@ mod tests { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), } } + fn openai_spec_with_available_providers(model: &str, flavor: ModelFlavor) -> ModelSpec { + let mut spec = openai_spec(model, flavor); + spec.available_providers = vec!["openai".into(), "azure".into(), "cerebras".into()]; + spec + } + fn dummy_auth() -> AuthConfig { AuthConfig::ApiKey { key: "test".into(), @@ -618,8 +726,18 @@ mod tests { .build() .expect("router builds"); - assert_eq!(router.provider_alias(vertex_model).unwrap(), "vertex"); - assert_eq!(router.provider_alias(google_model).unwrap(), "google"); + assert_eq!( + router + .provider_aliases(vertex_model, ProviderFormat::Google) + .unwrap(), + vec!["vertex".to_string()] + ); + assert_eq!( + router + .provider_aliases(google_model, ProviderFormat::Google) + .unwrap(), + vec!["google".to_string()] + ); } #[test] @@ -643,7 +761,12 @@ mod tests { .build() .expect("router builds"); 
- assert_eq!(router.provider_alias(vertex_model).unwrap(), "google"); + assert_eq!( + router + .provider_aliases(vertex_model, ProviderFormat::Google) + .unwrap(), + vec!["google".to_string()] + ); } #[test] @@ -665,9 +788,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::Responses); } @@ -690,9 +815,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::Responses); } @@ -715,9 +842,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::ChatCompletions); } @@ -740,9 +869,11 @@ mod tests { .build() .expect("router builds"); - let (_, _, _, format, _) = router - .resolve_provider(model, ProviderFormat::ChatCompletions) + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) .expect("resolves"); + assert_eq!(routes.len(), 1); + let (_, _, _, _, format, _) = routes[0]; assert_eq!(format, ProviderFormat::ChatCompletions); } @@ -841,4 +972,79 @@ mod tests { .expect("router builds"); assert!(router.catalog().get("any").is_none()); } + + #[test] + fn provider_aliases_returns_only_registered_available_providers() { + let model = "gpt-4o"; 
+ let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + openai_spec_with_available_providers(model, ModelFlavor::Chat), + ); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "azure", + FakeProvider { + name: "azure", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let aliases = router + .provider_aliases(model, ProviderFormat::ChatCompletions) + .expect("provider_aliases"); + assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); + } + + #[test] + fn resolve_providers_falls_back_to_format_slot_when_alias_not_registered() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + openai_spec_with_available_providers(model, ModelFlavor::Chat), + ); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "other_gpt", + FakeProvider { + name: "other_gpt", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_providers(model, ProviderFormat::ChatCompletions) + .expect("resolves"); + assert!( + !routes.is_empty(), + "at least one route (unregistered openai falls back to format slot azure)" + ); + assert_eq!( + router + .provider_aliases(model, ProviderFormat::ChatCompletions) + .unwrap(), + vec!["other_gpt".to_string()], + "provider_aliases returns only registered providers from available_providers" + ); + } } diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index cb340329..86c283f9 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ 
b/crates/braintrust-llm-router/tests/router.rs @@ -104,6 +104,7 @@ async fn router_routes_to_stub_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog); @@ -172,6 +173,7 @@ async fn router_requires_auth_for_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog); @@ -234,6 +236,7 @@ async fn router_reports_missing_provider() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); @@ -361,6 +364,7 @@ async fn router_retries_and_propagates_terminal_error() { max_output_tokens: None, supports_streaming: true, extra: Default::default(), + available_providers: Default::default(), }, ); let catalog = Arc::new(catalog);