17 changes: 17 additions & 0 deletions cli/planoai/config_generator.py
@@ -426,6 +426,23 @@ def validate_and_render_schema():
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)

# Validate at most one model listener
model_listeners = [l for l in listeners if l.get("type") == "model"]
if len(model_listeners) > 1:
raise Exception(
f"Only one model listener is allowed, found {len(model_listeners)}"
)

# Validate that filter_chain IDs on listeners reference valid agent/filter IDs
for listener in listeners:
listener_filter_chain = listener.get("filter_chain", [])
for fc_id in listener_filter_chain:
if fc_id not in agent_id_keys:
raise Exception(
f"Listener '{listener.get('name', 'unknown')}' references filter_chain id '{fc_id}' "
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
)

# Validate model aliases if present
if "model_aliases" in config_yaml:
model_aliases = config_yaml["model_aliases"]
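
For reference, a minimal sketch of a config the new checks reject (listener and agent names are illustrative, not from this PR):

listeners:
  - name: egress_traffic
    type: model
    port: 12000
    filter_chain:
      - missing_filter   # not defined under agents/filters; raises and lists the available ids
agents:
  - id: guardrail
    description: Screens outbound prompts

Declaring a second `type: model` listener fails the same way, with "Only one model listener is allowed, found 2".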
2 changes: 1 addition & 1 deletion cli/planoai/core.py
@@ -165,7 +165,7 @@ def _resolve_cli_agent_endpoint(plano_config_yaml: dict) -> tuple[str, int]:

if isinstance(listeners, list):
for listener in listeners:
if listener.get("type") in ["model", "model_listener"]:
if listener.get("type") == "model":
host = listener.get("host") or listener.get("address") or "0.0.0.0"
port = listener.get("port", 12000)
return host, port
4 changes: 2 additions & 2 deletions cli/planoai/utils.py
@@ -89,7 +89,7 @@ def convert_legacy_listeners(
) -> tuple[list, dict | None, dict | None]:
llm_gateway_listener = {
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"address": "0.0.0.0",
"timeout": "30s",
@@ -98,7 +98,7 @@

prompt_gateway_listener = {
"name": "ingress_traffic",
"type": "prompt_listener",
"type": "prompt",
"port": 10000,
"address": "0.0.0.0",
"timeout": "30s",
12 changes: 6 additions & 6 deletions cli/test/test_config_generator.py
@@ -376,15 +376,15 @@ def test_convert_legacy_llm_providers():
assert updated_providers == [
{
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"address": "0.0.0.0",
"timeout": "30s",
"model_providers": [{"model": "openai/gpt-4o", "access_key": "test_key"}],
},
{
"name": "ingress_traffic",
"type": "prompt_listener",
"type": "prompt",
"port": 10000,
"address": "0.0.0.0",
"timeout": "30s",
@@ -400,7 +400,7 @@ def test_convert_legacy_llm_providers():
},
],
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"timeout": "30s",
}
@@ -410,7 +410,7 @@ def test_convert_legacy_llm_providers():
"name": "ingress_traffic",
"port": 10000,
"timeout": "30s",
"type": "prompt_listener",
"type": "prompt",
}


@@ -449,7 +449,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
"name": "egress_traffic",
"port": 12000,
"timeout": "30s",
"type": "model_listener",
"type": "model",
}
]
assert llm_gateway == {
@@ -461,7 +461,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
},
],
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"timeout": "30s",
}
4 changes: 4 additions & 0 deletions config/plano_config_schema.yaml
@@ -93,6 +93,10 @@ properties:
required:
- id
- description
filter_chain:
type: array
items:
type: string
additionalProperties: false
required:
- type
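
The new filter_chain property is optional (it is not added to the listener's required list), so existing listener configs remain valid. A hedged sketch of a conforming listener entry, with illustrative field values:

- name: egress_traffic
  type: model
  port: 12000
  address: "0.0.0.0"
  timeout: 30s
  filter_chain:
    - guardrail   # each entry must be the id of an agent/filter defined elsewhere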
4 changes: 3 additions & 1 deletion crates/brightstaff/src/handlers/agent_selector.rs
@@ -172,7 +172,7 @@ impl AgentSelector {
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
use common::configuration::{AgentFilterChain, Listener, ListenerType};

fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@@ -192,8 +192,10 @@ mod tests {

fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
listener_type: ListenerType::Agent,
name: name.to_string(),
agents: Some(agents),
filter_chain: None,
port: 8080,
router: None,
}
4 changes: 3 additions & 1 deletion crates/brightstaff/src/handlers/integration_tests.rs
@@ -17,7 +17,7 @@ use hyper::StatusCode;
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};

fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@@ -72,8 +72,10 @@
};

let listener = Listener {
listener_type: ListenerType::Agent,
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
filter_chain: None,
port: 8080,
router: None,
};
82 changes: 79 additions & 3 deletions crates/brightstaff/src/handlers/llm.rs
@@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, SpanAttributes};
use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@@ -8,9 +8,9 @@ use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response};
use hyper::{Request, Response, StatusCode};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
@@ -19,6 +19,8 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};

use super::pipeline_processor::PipelineProcessor;

use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
@@ -34,6 +36,7 @@ use crate::tracing::{

use common::errors::BrightStaffError;

#[allow(clippy::too_many_arguments)]
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
@@ -42,6 +45,7 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
model_filter_chain: Arc<Option<ModelFilterChain>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
@@ -81,6 +85,7 @@
request_id,
request_path,
request_headers,
model_filter_chain,
)
.instrument(request_span)
.await
@@ -98,6 +103,7 @@ async fn llm_chat_inner(
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
model_filter_chain: Arc<Option<ModelFilterChain>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@@ -250,6 +256,70 @@ async fn llm_chat_inner(
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}

// === Filter chain processing for model listener ===
if let Some(ref mfc) = *model_filter_chain {
{
debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain");

let temp_filter_chain = AgentFilterChain {
id: "model_listener".to_string(),
default: None,
description: None,
filter_chain: Some(mfc.filter_ids.clone()),
};

let mut pipeline_processor = PipelineProcessor::default();
let messages = client_request.get_messages();
match pipeline_processor
.process_filter_chain(&messages, &temp_filter_chain, &mfc.agents, &request_headers)
.await
{
Ok(filtered_messages) => {
client_request.set_messages(&filtered_messages);
info!(
original_count = messages.len(),
filtered_count = filtered_messages.len(),
"filter chain processed successfully"
);
}
Err(super::pipeline_processor::PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"client error from filter chain"
);
let error_json = serde_json::json!({
"error": "FilterChainError",
"agent": agent,
"status": status,
"agent_response": body
});
let mut error_response = Response::new(full(error_json.to_string()));
*error_response.status_mut() =
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
error_response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(error_response);
}
Err(err) => {
warn!(error = %err, "filter chain processing failed");
let err_msg = format!("Filter chain processing failed: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
}
}
}

if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
@@ -570,3 +640,9 @@ async fn get_provider_info(
(hermesllm::ProviderId::OpenAI, None)
}
}

fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
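
When a filter agent rejects a request, llm_chat_inner now short-circuits instead of forwarding upstream. The response body is JSON with the keys from the serde_json::json! call above; the values in this sketch are illustrative:

{
  "error": "FilterChainError",
  "agent": "guardrail",
  "status": 422,
  "agent_response": "prompt contains disallowed content"
}

The agent's status code is propagated when it parses as a valid HTTP status (falling back to 400 Bad Request), and any other pipeline error yields a 500 with a plain-text message.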
33 changes: 32 additions & 1 deletion crates/brightstaff/src/main.rs
@@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
@@ -24,6 +24,7 @@ use hyper_util::rt::TokioIo;
use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
@@ -80,11 +81,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.cloned()
.collect();

// Build global agent map for resolving filter chain references
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();

// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));

// Resolve model listener filter chain and agents at startup
let model_listener = plano_config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let model_filter_chain: Arc<Option<ModelFilterChain>> = Arc::new(
model_listener
.and_then(|l| l.filter_chain.clone())
.filter(|fc| !fc.is_empty())
.map(|fc| {
let agents = fc
.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect();
ModelFilterChain {
filter_ids: fc,
agents,
}
}),
);
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@@ -179,6 +207,7 @@

let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let model_filter_chain = model_filter_chain.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@@ -190,6 +219,7 @@
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let model_filter_chain = model_filter_chain.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
Expand Down Expand Up @@ -248,6 +278,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_providers,
span_attributes,
state_storage,
model_filter_chain,
)
.with_context(parent_cx)
.await
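
End to end, the wiring is resolved once at startup: the single model listener's filter_chain ids are looked up in global_agent_map and captured in a ModelFilterChain that llm_chat receives on every request. A hedged sketch of a config that exercises this path (agent id illustrative):

listeners:
  - name: egress_traffic
    type: model
    port: 12000
    filter_chain:
      - guardrail
agents:
  - id: guardrail
    description: Rejects prompts that contain secrets

An absent or empty filter_chain resolves to None via the .filter(|fc| !fc.is_empty()) guard, so the request hot path skips filter processing entirely.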