diff --git a/src/eval.rs b/src/eval.rs index aba1c34..9a62601 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,61 +1,38 @@ use std::collections::{BTreeSet, HashMap, VecDeque}; use std::ffi::{OsStr, OsString}; -use std::io::IsTerminal; use std::path::{Path, PathBuf}; use std::process::{ExitStatus, Stdio}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use actix_web::dev::Service; -use actix_web::http::header::{ - HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, - ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, - ACCESS_CONTROL_MAX_AGE, AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_TYPE, ORIGIN, VARY, -}; -use actix_web::{guard, web, App, HttpRequest, HttpResponse, HttpServer}; use anyhow::{Context, Result}; use clap::{Args, ValueEnum}; -use crossterm::queue; -use crossterm::style::{ - Attribute, Color as CtColor, ResetColor, SetAttribute, SetBackgroundColor, SetForegroundColor, - Stylize, -}; -use futures_util::stream; -use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; use reqwest::Client; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use strip_ansi_escapes::strip; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::net::UnixListener; use tokio::process::Command; use tokio::sync::mpsc; -use unicode_width::UnicodeWidthStr; - -use ratatui::backend::TestBackend; -use ratatui::layout::{Alignment, Constraint}; -use ratatui::style::{Color, Modifier, Style}; -use ratatui::text::{Line, Span}; -use ratatui::widgets::{Cell, Row, Table}; -use ratatui::Terminal; use crate::args::BaseArgs; use crate::auth::resolved_auth_env; -use crate::ui::{animations_enabled, is_quiet}; + +mod dev_server; +mod events; +mod ui; + +use self::dev_server::{ + collect_allowed_dev_origins, resolve_app_url, run_dev_server, DevServerState, +}; +use self::events::{ + EvalErrorPayload, EvalEvent, ExperimentStart, ExperimentSummary, ProcessingEventData, + SseConsoleEventData, SseDependenciesEventData, SseProgressEventData, +}; +use self::ui::EvalUi; const MAX_NAME_LENGTH: usize = 40; const WATCH_POLL_INTERVAL: Duration = Duration::from_millis(500); -const MAIN_ORIGIN: &str = "https://www.braintrust.dev"; -const BRAINTRUSTDATA_ORIGIN: &str = "https://www.braintrustdata.com"; -const CORS_METHODS: &str = "GET, PATCH, POST, PUT, DELETE, OPTIONS"; -const CORS_ALLOWED_HEADERS: &str = "Content-Type, X-Amz-Date, Authorization, X-Api-Key, X-Amz-Security-Token, x-bt-auth-token, x-bt-parent, x-bt-org-name, x-bt-project-id, x-bt-stream-fmt, x-bt-use-cache, x-bt-use-gateway, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch"; -const CORS_EXPOSED_HEADERS: &str = - "x-bt-cursor, x-bt-found-existing-experiment, x-bt-span-id, x-bt-span-export"; -const HEADER_BT_AUTH_TOKEN: &str = "x-bt-auth-token"; -const HEADER_BT_ORG_NAME: &str = "x-bt-org-name"; -const HEADER_CORS_REQ_PRIVATE_NETWORK: &str = "access-control-request-private-network"; -const HEADER_CORS_ALLOW_PRIVATE_NETWORK: &str = "access-control-allow-private-network"; const SSE_SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; const EVAL_NODE_MAX_OLD_SPACE_SIZE_MB: usize = 8192; const MAX_DEFERRED_EVAL_ERRORS: usize = 8; @@ -108,81 +85,6 @@ enum RunnerKind { Other, } -#[derive(Debug, Clone, Serialize, Deserialize)] -struct EvalRequest { - name: String, - #[serde(default)] - parameters: Option, - data: Value, - #[serde(default)] - scores: Option>, - #[serde(default)] - experiment_name: Option, - #[serde(default)] - project_id: Option, - #[serde(default)] - parent: Option, - #[serde(default)] - stream: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct EvalScore { - name: String, - function_id: Value, -} - -#[derive(Debug, Deserialize)] -struct DatasetLookupRow { - project_id: String, - name: String, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(untagged)] -enum DatasetIdField { - String(String), - Other(Value), -} - -#[derive(Debug, Clone, Deserialize)] -struct DatasetEvalDataInput { - #[serde(default)] - dataset_id: Option, - #[serde(default)] - _internal_btql: Option, -} - -#[derive(Debug, Clone, Serialize)] -struct ResolvedDatasetEvalData { - project_id: String, - dataset_name: String, - #[serde(skip_serializing_if = "Option::is_none")] - _internal_btql: Option, -} - -#[derive(Clone)] -struct DevServerState { - base: BaseArgs, - language_override: Option, - runner_override: Option, - files: Vec, - no_send_logs: bool, - options: EvalRunOptions, - host: String, - port: u16, - allowed_org_name: Option, - allowed_origins: Vec, - app_url: String, - http_client: Client, -} - -#[derive(Debug)] -struct DevAuthContext { - token: String, - org_name: String, -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] struct RunnerFilter { path: Vec, @@ -922,739 +824,6 @@ fn is_esm_interop_error(message: &str) -> bool { PATTERNS.iter().any(|pattern| message.contains(pattern)) } -fn resolve_app_url(base: &BaseArgs) -> String { - if let Some(app_url) = base.app_url.as_ref() { - return app_url.clone(); - } - "https://www.braintrust.dev".to_string() -} - -fn app_origin_from_url(url: &str) -> Option { - reqwest::Url::parse(url).ok().and_then(|parsed| { - let origin = parsed.origin(); - if origin.is_tuple() { - Some(origin.ascii_serialization()) - } else { - None - } - }) -} - -fn collect_allowed_dev_origins(explicit: &[String], app_url: &str) -> Vec { - let mut deduped = BTreeSet::new(); - for origin in explicit { - let trimmed = origin.trim(); - if !trimmed.is_empty() { - deduped.insert(trimmed.to_string()); - } - } - if let Some(origin) = app_origin_from_url(app_url) { - deduped.insert(origin); - } - deduped.into_iter().collect() -} - -fn join_app_url(app_url: &str, path: &str) -> Result { - let base = format!("{}/", app_url.trim_end_matches('/')); - let base_url = reqwest::Url::parse(&base).context("invalid app URL")?; - let joined = base_url - .join(path.trim_start_matches('/')) - .context("failed to join app URL path")?; - Ok(joined.to_string()) -} - -fn json_error_response(status: actix_web::http::StatusCode, message: &str) -> HttpResponse { - HttpResponse::build(status).json(json!({ "error": message })) -} - -fn parse_auth_token(req: &HttpRequest) -> Option { - if let Some(token) = req.headers().get(HEADER_BT_AUTH_TOKEN) { - if let Ok(value) = token.to_str() { - if !value.trim().is_empty() { - return Some(value.trim().to_string()); - } - } - } - - let auth = req.headers().get(AUTHORIZATION)?; - let auth = auth.to_str().ok()?.trim(); - if auth.is_empty() { - return None; - } - if let Some(token) = auth.strip_prefix("Bearer ") { - let token = token.trim(); - if token.is_empty() { - None - } else { - Some(token.to_string()) - } - } else { - Some(auth.to_string()) - } -} - -async fn authenticate_dev_request( - req: &HttpRequest, - state: &DevServerState, -) -> std::result::Result { - let token = match parse_auth_token(req) { - Some(token) if !token.eq_ignore_ascii_case("null") => token, - _ => { - return Err(json_error_response( - actix_web::http::StatusCode::UNAUTHORIZED, - "Unauthorized", - )); - } - }; - - let org_name = match req - .headers() - .get(HEADER_BT_ORG_NAME) - .and_then(|value| value.to_str().ok()) - { - Some(value) if !value.trim().is_empty() => value.trim().to_string(), - _ => { - return Err(json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!("Missing {HEADER_BT_ORG_NAME} header"), - )); - } - }; - - if let Some(allowed_org_name) = state.allowed_org_name.as_ref() { - if allowed_org_name != &org_name { - let message = format!( - "Org '{org_name}' is not allowed. Only org '{allowed_org_name}' is allowed." - ); - return Err(json_error_response( - actix_web::http::StatusCode::FORBIDDEN, - &message, - )); - } - } - - let login_url = match join_app_url(&state.app_url, "api/apikey/login") { - Ok(url) => url, - Err(err) => { - return Err(json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - )); - } - }; - let response = state - .http_client - .post(login_url) - .bearer_auth(&token) - .send() - .await - .map_err(|_| { - json_error_response(actix_web::http::StatusCode::UNAUTHORIZED, "Unauthorized") - })?; - if !response.status().is_success() { - return Err(json_error_response( - actix_web::http::StatusCode::UNAUTHORIZED, - "Unauthorized", - )); - } - - let payload = response.json::().await.unwrap_or(Value::Null); - if let Some(orgs) = payload.get("org_info").and_then(|value| value.as_array()) { - let matched = orgs.iter().any(|org| { - org.get("name") - .and_then(|name| name.as_str()) - .map(|name| name == org_name) - .unwrap_or(false) - }); - if !matched { - return Err(json_error_response( - actix_web::http::StatusCode::UNAUTHORIZED, - "Unauthorized", - )); - } - } else { - return Err(json_error_response( - actix_web::http::StatusCode::UNAUTHORIZED, - "Unauthorized", - )); - } - - Ok(DevAuthContext { token, org_name }) -} - -async fn resolve_dataset_ref_for_eval_request( - state: &DevServerState, - auth: &DevAuthContext, - eval_request: &mut EvalRequest, -) -> std::result::Result<(), HttpResponse> { - let input = match serde_json::from_value::(eval_request.data.clone()) { - Ok(value) => value, - Err(_) => return Ok(()), - }; - - let dataset_id = match input.dataset_id { - Some(DatasetIdField::String(dataset_id)) => dataset_id, - Some(DatasetIdField::Other(value)) => { - let received_type = match value { - Value::Null => "null", - Value::Bool(_) => "boolean", - Value::Number(_) => "number", - Value::String(_) => "string", - Value::Array(_) => "array", - Value::Object(_) => "object", - }; - return Err(json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!("Invalid dataset_id: expected a string, got {received_type}."), - )); - } - None => { - return Ok(()); - } - }; - if dataset_id.trim().is_empty() { - return Err(json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - "Invalid dataset_id: expected a non-empty string.", - )); - } - - let lookup_url = match join_app_url(&state.app_url, "api/dataset/get") { - Ok(url) => url, - Err(err) => { - return Err(json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - )); - } - }; - let response = state - .http_client - .post(lookup_url) - .bearer_auth(&auth.token) - .header(HEADER_BT_ORG_NAME, auth.org_name.clone()) - .json(&json!({ "id": dataset_id })) - .send() - .await - .map_err(|err| { - json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!("Failed to load dataset '{dataset_id}': {err}"), - ) - })?; - if !response.status().is_success() { - return Err(json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!( - "Failed to load dataset '{dataset_id}' (status {}).", - response.status() - ), - )); - } - - let datasets = response - .json::>() - .await - .map_err(|err| { - json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!("Failed to parse dataset response for '{dataset_id}': {err}"), - ) - })?; - let Some(dataset) = datasets.first() else { - return Err(json_error_response( - actix_web::http::StatusCode::BAD_REQUEST, - &format!("Dataset '{dataset_id}' not found."), - )); - }; - - let resolved = ResolvedDatasetEvalData { - project_id: dataset.project_id.clone(), - dataset_name: dataset.name.clone(), - _internal_btql: input._internal_btql, - }; - eval_request.data = serde_json::to_value(resolved).map_err(|err| { - json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("Failed to serialize resolved dataset reference: {err}"), - ) - })?; - Ok(()) -} - -fn make_dev_mode_env( - auth: &DevAuthContext, - state: &DevServerState, - request: Option<&EvalRequest>, - dev_mode: &str, -) -> Result> { - let mut env = vec![ - ("BRAINTRUST_API_KEY".to_string(), auth.token.clone()), - ("BRAINTRUST_ORG_NAME".to_string(), auth.org_name.clone()), - ("BRAINTRUST_APP_URL".to_string(), state.app_url.clone()), - ("BT_EVAL_DEV_MODE".to_string(), dev_mode.to_string()), - ]; - if let Some(request) = request { - let serialized = - serde_json::to_string(request).context("failed to serialize eval request payload")?; - env.push(("BT_EVAL_DEV_REQUEST_JSON".to_string(), serialized)); - } - Ok(env) -} - -fn serialize_sse_event(event: &str, data: &str) -> String { - format!("event: {event}\ndata: {data}\n\n") -} - -fn is_eval_progress_payload(progress: &SseProgressEventData) -> bool { - serde_json::from_str::(&progress.data) - .map(|payload| payload.kind_type == "eval_progress") - .unwrap_or(false) -} - -fn encode_eval_event_for_http(event: &EvalEvent) -> Option { - match event { - EvalEvent::Processing(payload) => serde_json::to_string(payload) - .ok() - .map(|data| serialize_sse_event("processing", &data)), - EvalEvent::Start(start) => serde_json::to_string(start) - .ok() - .map(|data| serialize_sse_event("start", &data)), - EvalEvent::Summary(summary) => serde_json::to_string(summary) - .ok() - .map(|data| serialize_sse_event("summary", &data)), - EvalEvent::Progress(progress) => { - if is_eval_progress_payload(progress) { - None - } else { - serde_json::to_string(progress) - .ok() - .map(|data| serialize_sse_event("progress", &data)) - } - } - EvalEvent::Dependencies { .. } => None, - EvalEvent::Done => Some(serialize_sse_event("done", "")), - EvalEvent::Error { - message, - stack, - status, - } => serde_json::to_string(&json!({ - "message": message, - "stack": stack, - "status": status, - })) - .ok() - .map(|data| serialize_sse_event("error", &data)), - EvalEvent::Console { .. } => None, - } -} - -async fn dev_server_index() -> HttpResponse { - HttpResponse::Ok().body("Hello, world!") -} - -async fn dev_server_options() -> HttpResponse { - HttpResponse::Ok().finish() -} - -fn is_allowed_preview_origin(origin: &str) -> bool { - origin.starts_with("https://") && origin.ends_with(".preview.braintrust.dev") -} - -fn is_allowed_origin(origin: &str, allowed_origins: &[String]) -> bool { - if origin == MAIN_ORIGIN || origin == BRAINTRUSTDATA_ORIGIN || is_allowed_preview_origin(origin) - { - return true; - } - allowed_origins.iter().any(|value| value == origin) -} - -fn apply_cors_headers( - headers: &mut actix_web::http::header::HeaderMap, - request_origin: Option<&str>, - allow_private_network: bool, - allowed_origins: &[String], -) { - if let Some(origin) = request_origin { - if is_allowed_origin(origin, allowed_origins) { - if let Ok(origin_value) = HeaderValue::from_str(origin) { - headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin_value); - headers.insert( - ACCESS_CONTROL_ALLOW_METHODS, - HeaderValue::from_static(CORS_METHODS), - ); - headers.insert( - ACCESS_CONTROL_ALLOW_HEADERS, - HeaderValue::from_static(CORS_ALLOWED_HEADERS), - ); - headers.insert( - ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::from_static(CORS_EXPOSED_HEADERS), - ); - headers.insert( - ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - headers.insert(ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("86400")); - headers.insert(VARY, HeaderValue::from_static("Origin")); - } - } - } - - if allow_private_network { - headers.insert( - HeaderName::from_static(HEADER_CORS_ALLOW_PRIVATE_NETWORK), - HeaderValue::from_static("true"), - ); - } -} - -async fn dev_server_list(state: web::Data, req: HttpRequest) -> HttpResponse { - let auth = match authenticate_dev_request(&req, &state).await { - Ok(auth) => auth, - Err(response) => return response, - }; - let extra_env = match make_dev_mode_env(&auth, &state, None, "list") { - Ok(extra_env) => extra_env, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - let language = match detect_eval_language(&state.files, state.language_override) { - Ok(language) => language, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - let spawned = match spawn_eval_runner( - &state.base, - language, - state.runner_override.as_deref(), - &state.files, - state.no_send_logs, - &state.options, - &extra_env, - JsMode::Auto, - ) - .await - { - Ok(value) => value, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - let mut stdout_lines = Vec::new(); - let mut errors: Vec<(String, Option)> = Vec::new(); - let output = - match drive_eval_runner( - spawned.process, - ConsolePolicy::Forward, - |event| match event { - EvalEvent::Console { stream, message } if stream == "stdout" => { - stdout_lines.push(message); - } - EvalEvent::Error { - message, - stack: _, - status, - } => errors.push((message, status)), - _ => {} - }, - ) - .await - { - Ok(output) => output, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - if let Some((message, status)) = errors.first() { - let status = status - .and_then(|status| actix_web::http::StatusCode::from_u16(status).ok()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - return json_error_response(status, message); - } - if !output.status.success() { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - "Eval runner exited with an error.", - ); - } - - let mut parsed_manifest: Option = None; - for line in stdout_lines.iter().rev() { - if let Ok(value) = serde_json::from_str::(line) { - parsed_manifest = Some(value); - break; - } - } - if parsed_manifest.is_none() { - let joined = stdout_lines.join("\n"); - if let Ok(value) = serde_json::from_str::(&joined) { - parsed_manifest = Some(value); - } - } - - match parsed_manifest { - Some(manifest) => HttpResponse::Ok().json(manifest), - None => json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - "Failed to parse evaluator manifest from runner output.", - ), - } -} - -async fn dev_server_eval( - state: web::Data, - req: HttpRequest, - body: web::Bytes, -) -> HttpResponse { - let auth = match authenticate_dev_request(&req, &state).await { - Ok(auth) => auth, - Err(response) => return response, - }; - - let mut eval_request: EvalRequest = match serde_json::from_slice(&body) { - Ok(eval_request) => eval_request, - Err(err) => { - return json_error_response(actix_web::http::StatusCode::BAD_REQUEST, &err.to_string()); - } - }; - if let Err(response) = - resolve_dataset_ref_for_eval_request(&state, &auth, &mut eval_request).await - { - return response; - } - let stream_requested = eval_request.stream.unwrap_or(false); - let extra_env = match make_dev_mode_env(&auth, &state, Some(&eval_request), "eval") { - Ok(extra_env) => extra_env, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - let language = match detect_eval_language(&state.files, state.language_override) { - Ok(language) => language, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - let spawned = match spawn_eval_runner( - &state.base, - language, - state.runner_override.as_deref(), - &state.files, - state.no_send_logs, - &state.options, - &extra_env, - JsMode::Auto, - ) - .await - { - Ok(value) => value, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - if stream_requested { - let (tx, rx) = mpsc::unbounded_channel::(); - tokio::spawn(async move { - let mut saw_error = false; - let mut stderr_lines: Vec = Vec::new(); - let output = drive_eval_runner(spawned.process, ConsolePolicy::Forward, |event| { - if matches!(event, EvalEvent::Error { .. }) { - saw_error = true; - } - if matches!(event, EvalEvent::Done) { - return; - } - if let EvalEvent::Console { - ref stream, - ref message, - } = event - { - for line in message.lines() { - let _ = tx.send(format!(": [{stream}] {line}\n")); - } - if stream == "stderr" { - stderr_lines.push(message.clone()); - } - return; - } - if let Some(encoded) = encode_eval_event_for_http(&event) { - let _ = tx.send(encoded); - } - }) - .await; - - match output { - Ok(output) => { - if !output.status.success() && !saw_error { - let mut detail = format!("Eval runner exited with {}.", output.status); - for line in stderr_lines.iter() { - detail.push('\n'); - detail.push_str(line); - } - let error = - serialize_sse_event("error", &json!({ "message": detail }).to_string()); - let _ = tx.send(error); - } - } - Err(err) => { - let error = serialize_sse_event( - "error", - &json!({ "message": format!("{err:#}") }).to_string(), - ); - let _ = tx.send(error); - } - } - - let _ = tx.send(serialize_sse_event("done", "")); - }); - - let response_stream = stream::unfold(rx, |mut rx| async { - rx.recv() - .await - .map(|chunk| (Ok::<_, actix_web::Error>(web::Bytes::from(chunk)), rx)) - }); - return HttpResponse::Ok() - .append_header((CONTENT_TYPE, "text/event-stream")) - .append_header((CACHE_CONTROL, "no-cache")) - .append_header((CONNECTION, "keep-alive")) - .streaming(response_stream); - } - - let mut summary: Option = None; - let mut errors: Vec<(String, Option)> = Vec::new(); - let output = - match drive_eval_runner( - spawned.process, - ConsolePolicy::Forward, - |event| match event { - EvalEvent::Summary(current) => summary = Some(current), - EvalEvent::Error { - message, - stack: _, - status, - } => errors.push((message, status)), - _ => {} - }, - ) - .await - { - Ok(output) => output, - Err(err) => { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - &format!("{err:#}"), - ); - } - }; - - if let Some((message, status)) = errors.first() { - let status = status - .and_then(|status| actix_web::http::StatusCode::from_u16(status).ok()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - return json_error_response(status, message); - } - if let Some(summary) = summary { - return HttpResponse::Ok().json(summary); - } - if !output.status.success() { - return json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - "Eval runner exited with an error.", - ); - } - json_error_response( - actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - "Eval runner did not return a summary.", - ) -} - -async fn run_dev_server(state: DevServerState) -> Result<()> { - println!( - "Starting eval dev server on http://{}:{}", - state.host, state.port - ); - let host = state.host.clone(); - let port = state.port; - HttpServer::new(move || { - let allowed_origins = state.allowed_origins.clone(); - App::new() - .wrap_fn({ - let allowed_origins = allowed_origins.clone(); - move |req, srv| { - let allowed_origins = allowed_origins.clone(); - let request_origin = req - .headers() - .get(ORIGIN) - .and_then(|value| value.to_str().ok()) - .map(str::to_owned); - let allow_private_network = - req.headers().contains_key(HEADER_CORS_REQ_PRIVATE_NETWORK); - let fut = srv.call(req); - async move { - let mut res = fut.await?; - apply_cors_headers( - res.headers_mut(), - request_origin.as_deref(), - allow_private_network, - &allowed_origins, - ); - Ok::<_, actix_web::Error>(res) - } - } - }) - .app_data(web::Data::new(state.clone())) - .route("/", web::get().to(dev_server_index)) - .route( - "/", - web::route().guard(guard::Options()).to(dev_server_options), - ) - .route("/list", web::get().to(dev_server_list)) - .route( - "/list", - web::route().guard(guard::Options()).to(dev_server_options), - ) - .route("/eval", web::post().to(dev_server_eval)) - .route( - "/eval", - web::route().guard(guard::Options()).to(dev_server_options), - ) - }) - .bind((host.as_str(), port)) - .with_context(|| format!("failed to bind eval dev server on {host}:{port}"))? - .run() - .await - .context("eval dev server exited unexpectedly") -} - #[derive(Debug, Clone, Eq, PartialEq)] struct WatchEntry { modified: Option, @@ -2466,128 +1635,6 @@ fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Ok(path) } -#[derive(Debug)] -enum EvalEvent { - Processing(ProcessingEventData), - Start(ExperimentStart), - Summary(ExperimentSummary), - Progress(SseProgressEventData), - Dependencies { - files: Vec, - }, - Done, - Error { - message: String, - stack: Option, - status: Option, - }, - Console { - stream: String, - message: String, - }, -} - -#[derive(Debug, Deserialize, Serialize)] -struct ProcessingEventData { - #[serde(default)] - evaluators: usize, -} - -#[derive(Debug, Deserialize, Serialize, Default)] -#[serde(rename_all = "camelCase")] -struct ExperimentStart { - #[serde(default, alias = "project_name")] - project_name: Option, - #[serde(default, alias = "experiment_name")] - experiment_name: Option, - #[serde(default, alias = "project_id")] - project_id: Option, - #[serde(default, alias = "experiment_id")] - experiment_id: Option, - #[serde(default, alias = "project_url")] - project_url: Option, - #[serde(default, alias = "experiment_url")] - experiment_url: Option, -} - -#[allow(dead_code)] -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -struct ExperimentSummary { - project_name: String, - experiment_name: String, - project_id: Option, - experiment_id: Option, - project_url: Option, - experiment_url: Option, - comparison_experiment_name: Option, - scores: HashMap, - metrics: Option>, -} - -#[derive(Debug, Deserialize, Serialize)] -struct ScoreSummary { - name: String, - score: f64, - diff: Option, - #[serde(default)] - improvements: i64, - #[serde(default)] - regressions: i64, -} - -#[derive(Debug, Deserialize)] -struct EvalErrorPayload { - message: String, - stack: Option, - status: Option, -} - -#[derive(Debug, Deserialize, Serialize)] -struct MetricSummary { - name: String, - metric: f64, - #[serde(default)] - unit: String, - diff: Option, - #[serde(default)] - improvements: i64, - #[serde(default)] - regressions: i64, -} - -#[allow(dead_code)] -#[derive(Debug, Deserialize, Serialize)] -struct SseProgressEventData { - id: String, - object_type: String, - origin: Option, - format: String, - output_type: String, - name: String, - event: String, - data: String, -} - -#[derive(Debug, Deserialize)] -struct EvalProgressData { - #[serde(rename = "type")] - kind_type: String, - kind: String, - total: Option, -} - -#[derive(Debug, Deserialize)] -struct SseConsoleEventData { - stream: String, - message: String, -} - -#[derive(Debug, Deserialize)] -struct SseDependenciesEventData { - files: Vec, -} - async fn forward_stream( stream: T, name: &'static str, @@ -2699,687 +1746,6 @@ fn handle_sse_event(event: Option, data: String, tx: &mpsc::UnboundedSen } } -struct EvalUi { - progress: MultiProgress, - bars: HashMap, - bar_style: ProgressStyle, - spinner_style: ProgressStyle, - jsonl: bool, - list: bool, - verbose: bool, - deferred_errors: Vec, - suppressed_stderr_lines: usize, - finished: bool, -} - -impl EvalUi { - fn new(jsonl: bool, list: bool, verbose: bool) -> Self { - let draw_target = if std::io::stderr().is_terminal() && animations_enabled() && !is_quiet() - { - ProgressDrawTarget::stderr_with_hz(10) - } else { - ProgressDrawTarget::stderr() - }; - let progress = MultiProgress::with_draw_target(draw_target); - let bar_style = - ProgressStyle::with_template("{bar:10.blue} {msg} {percent}% {pos}/{len} {eta}") - .unwrap(); - let spinner_style = ProgressStyle::with_template("{spinner} {msg}").unwrap(); - Self { - progress, - bars: HashMap::new(), - bar_style, - spinner_style, - jsonl, - list, - verbose, - deferred_errors: Vec::new(), - suppressed_stderr_lines: 0, - finished: false, - } - } - - fn finish(&mut self) { - if self.finished { - return; - } - for (_, bar) in self.bars.drain() { - bar.finish_and_clear(); - } - let _ = self.progress.clear(); - self.progress.set_draw_target(ProgressDrawTarget::hidden()); - self.print_deferred_error_footnote(); - self.finished = true; - } - - fn handle(&mut self, event: EvalEvent) { - match event { - EvalEvent::Processing(payload) => { - self.print_persistent_line(format_processing_line(payload.evaluators)); - } - EvalEvent::Start(start) => { - if let Some(line) = format_start_line(&start) { - self.print_persistent_line(line); - } - } - EvalEvent::Summary(summary) => { - if self.jsonl { - if let Ok(line) = serde_json::to_string(&summary) { - println!("{line}"); - } - } else { - let rendered = format_experiment_summary(&summary); - self.print_persistent_multiline(rendered); - } - } - EvalEvent::Progress(progress) => { - self.handle_progress(progress); - } - EvalEvent::Dependencies { .. } => {} - EvalEvent::Console { stream, message } => { - if stream == "stdout" && (self.list || self.jsonl) { - println!("{message}"); - } else if stream == "stderr" && !self.verbose { - self.suppressed_stderr_lines += 1; - } else { - let _ = self.progress.println(message); - } - } - EvalEvent::Error { message, stack, .. } => { - let show_hint = message.contains("Please specify an api key"); - if self.verbose { - let line = message.as_str().red().to_string(); - let _ = self.progress.println(line); - if let Some(stack) = stack { - for line in stack.lines() { - let _ = self.progress.println(line.dark_grey().to_string()); - } - } - } else { - self.record_deferred_error(message); - } - if show_hint { - let hint = "Hint: pass --api-key, set BRAINTRUST_API_KEY, run `bt auth login`/`bt auth login --oauth`, or use --no-send-logs for local evals."; - if self.verbose { - let _ = self.progress.println(hint.dark_grey().to_string()); - } else { - self.record_deferred_error(hint.to_string()); - } - } - } - EvalEvent::Done => { - self.finish(); - } - } - } - - fn handle_progress(&mut self, progress: SseProgressEventData) { - let payload = match serde_json::from_str::(&progress.data) { - Ok(payload) if payload.kind_type == "eval_progress" => payload, - _ => return, - }; - - match payload.kind.as_str() { - "start" => { - let bar = if let Some(total) = payload.total { - if total > 0 { - let bar = self.progress.add(ProgressBar::new(total)); - bar.set_style(self.bar_style.clone()); - bar - } else { - let bar = self.progress.add(ProgressBar::new_spinner()); - bar.set_style(self.spinner_style.clone()); - bar - } - } else { - let bar = self.progress.add(ProgressBar::new_spinner()); - bar.set_style(self.spinner_style.clone()); - bar - }; - bar.set_message(fit_name_to_spaces(&progress.name, MAX_NAME_LENGTH)); - self.bars.insert(progress.name.clone(), bar); - } - "increment" => { - if let Some(bar) = self.bars.get(&progress.name) { - bar.inc(1); - bar.set_message(fit_name_to_spaces(&progress.name, MAX_NAME_LENGTH)); - } - } - "set_total" => { - if let Some(bar) = self.bars.get(&progress.name) { - if let Some(total) = payload.total { - bar.set_length(total); - bar.set_style(self.bar_style.clone()); - } - } - } - "stop" => { - if let Some(bar) = self.bars.remove(&progress.name) { - bar.finish_and_clear(); - } - } - _ => {} - } - } - - fn print_persistent_line(&self, line: String) { - self.progress.suspend(|| { - eprintln!("{line}"); - }); - } - - fn print_persistent_multiline(&self, text: String) { - self.progress.suspend(|| { - for line in text.lines() { - eprintln!("{line}"); - } - }); - } - - fn record_deferred_error(&mut self, message: String) { - let trimmed = message.trim(); - if trimmed.is_empty() { - return; - } - if self - .deferred_errors - .iter() - .any(|existing| existing == trimmed) - { - return; - } - if self.deferred_errors.len() < MAX_DEFERRED_EVAL_ERRORS { - self.deferred_errors.push(trimmed.to_string()); - } - } - - fn print_deferred_error_footnote(&self) { - if self.verbose { - return; - } - if self.deferred_errors.is_empty() && self.suppressed_stderr_lines == 0 { - return; - } - - eprintln!(); - if !self.deferred_errors.is_empty() { - let noun = if self.deferred_errors.len() == 1 { - "error" - } else { - "errors" - }; - eprintln!( - "Encountered {} evaluator {noun}:", - self.deferred_errors.len() - ); - for message in &self.deferred_errors { - eprintln!(" - {message}"); - } - } - if self.suppressed_stderr_lines > 0 { - eprintln!( - "Suppressed {} stderr line(s). Re-run with `bt eval --verbose ...` to inspect details.", - self.suppressed_stderr_lines - ); - } - } -} - -impl Drop for EvalUi { - fn drop(&mut self) { - self.finish(); - } -} - -fn fit_name_to_spaces(name: &str, length: usize) -> String { - let char_count = name.chars().count(); - if char_count < length { - let mut padded = name.to_string(); - padded.push_str(&" ".repeat(length - char_count)); - return padded; - } - if char_count == length { - return name.to_string(); - } - if length <= 3 { - return name.chars().take(length).collect(); - } - if length <= 5 { - let truncated: String = name.chars().take(length - 3).collect(); - return format!("{truncated}..."); - } - - // Keep both prefix and suffix so similarly named evaluators remain distinguishable. - let keep_total = length - 3; - let head_len = keep_total / 2; - let tail_len = keep_total - head_len; - let head: String = name.chars().take(head_len).collect(); - let tail: String = name - .chars() - .rev() - .take(tail_len) - .collect::() - .chars() - .rev() - .collect(); - format!("{head}...{tail}") -} - -fn format_processing_line(evaluators: usize) -> String { - let noun = if evaluators == 1 { - "evaluator" - } else { - "evaluators" - }; - format!("Processing {evaluators} {noun}...") -} - -fn format_start_line(start: &ExperimentStart) -> Option { - let experiment_name = start - .experiment_name - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()); - let experiment_url = start - .experiment_url - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()); - let arrow = "▶".cyan(); - - match (experiment_name, experiment_url) { - (Some(name), Some(url)) => Some(format!( - "{arrow} Experiment {} is running at {url}", - name.bold() - )), - (Some(name), None) => Some(format!( - "{arrow} Experiment {} is running at locally", - name.bold() - )), - (None, Some(url)) => Some(format!("{arrow} Experiment is running at {url}")), - (None, None) => None, - } -} - -fn format_experiment_summary(summary: &ExperimentSummary) -> String { - let mut parts: Vec = Vec::new(); - - if let Some(comparison) = summary.comparison_experiment_name.as_deref() { - let line = format!( - "{baseline} {baseline_tag} ← {comparison_name} {comparison_tag}", - baseline = comparison, - baseline_tag = "(baseline)".dark_grey(), - comparison_name = summary.experiment_name, - comparison_tag = "(comparison)".dark_grey(), - ); - parts.push(line); - } - - let has_scores = !summary.scores.is_empty(); - let has_metrics = summary - .metrics - .as_ref() - .map(|metrics| !metrics.is_empty()) - .unwrap_or(false); - - if has_scores || has_metrics { - let has_comparison = summary.comparison_experiment_name.is_some(); - let mut rows: Vec> = Vec::new(); - - let header = if has_comparison { - Some(vec![ - header_line("Name"), - header_line("Value"), - header_line("Change"), - header_line("Improvements"), - header_line("Regressions"), - ]) - } else { - None - }; - - let mut score_values: Vec<_> = summary.scores.values().collect(); - score_values.sort_by(|a, b| a.name.cmp(&b.name)); - for score in score_values { - let score_percent = - Line::from(format!("{:.2}%", score.score * 100.0)).alignment(Alignment::Right); - let diff = format_diff_line(score.diff); - let improvements = format_improvements_line(score.improvements); - let regressions = format_regressions_line(score.regressions); - let name = truncate_plain(&score.name, MAX_NAME_LENGTH); - let name = Line::from(vec![ - Span::styled("◯", Style::default().fg(Color::Blue)), - Span::raw(" "), - Span::raw(name), - ]); - if has_comparison { - rows.push(vec![name, score_percent, diff, improvements, regressions]); - } else { - rows.push(vec![name, score_percent]); - } - } - - if let Some(metrics) = &summary.metrics { - let mut metric_values: Vec<_> = metrics.values().collect(); - metric_values.sort_by(|a, b| a.name.cmp(&b.name)); - for metric in metric_values { - let formatted_value = Line::from(format_metric_value(metric.metric, &metric.unit)) - .alignment(Alignment::Right); - let diff = format_diff_line(metric.diff); - let improvements = format_improvements_line(metric.improvements); - let regressions = format_regressions_line(metric.regressions); - let name = truncate_plain(&metric.name, MAX_NAME_LENGTH); - let name = Line::from(vec![ - Span::styled("◯", Style::default().fg(Color::Magenta)), - Span::raw(" "), - Span::raw(name), - ]); - if has_comparison { - rows.push(vec![name, formatted_value, diff, improvements, regressions]); - } else { - rows.push(vec![name, formatted_value]); - } - } - } - - parts.push(render_table_ratatui(header, rows, has_comparison)); - } - - if let Some(url) = &summary.experiment_url { - parts.push(format!("See results at {url}")); - } - - let content = parts.join("\n\n"); - box_with_title("Experiment summary", &content) -} - -fn format_diff_line(diff: Option) -> Line<'static> { - match diff { - Some(value) => { - let sign = if value > 0.0 { "+" } else { "" }; - let percent = format!("{sign}{:.2}%", value * 100.0); - let style = if value > 0.0 { - Style::default().fg(Color::Green) - } else { - Style::default().fg(Color::Red) - }; - Line::from(Span::styled(percent, style)).alignment(Alignment::Right) - } - None => Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) - .alignment(Alignment::Right), - } -} - -fn format_improvements_line(value: i64) -> Line<'static> { - if value > 0 { - Line::from(Span::styled( - value.to_string(), - Style::default() - .fg(Color::Green) - .add_modifier(Modifier::DIM), - )) - .alignment(Alignment::Right) - } else { - Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) - .alignment(Alignment::Right) - } -} - -fn format_regressions_line(value: i64) -> Line<'static> { - if value > 0 { - Line::from(Span::styled( - value.to_string(), - Style::default().fg(Color::Red).add_modifier(Modifier::DIM), - )) - .alignment(Alignment::Right) - } else { - Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) - .alignment(Alignment::Right) - } -} - -fn format_metric_value(metric: f64, unit: &str) -> String { - let formatted = if metric.fract() == 0.0 { - format!("{metric:.0}") - } else { - format!("{metric:.2}") - }; - if unit == "$" { - format!("{unit}{formatted}") - } else { - format!("{formatted}{unit}") - } -} - -fn render_table_ratatui( - header: Option>>, - rows: Vec>>, - has_comparison: bool, -) -> String { - if rows.is_empty() { - return String::new(); - } - - let columns = if has_comparison { 5 } else { 2 }; - let mut widths = vec![0usize; columns]; - - if let Some(header_row) = &header { - for (idx, line) in header_row.iter().enumerate().take(columns) { - widths[idx] = widths[idx].max(line.width()); - } - } - - for row in &rows { - for (idx, line) in row.iter().enumerate().take(columns) { - widths[idx] = widths[idx].max(line.width()); - } - } - - let column_spacing = 2; - let total_width = widths.iter().sum::() + column_spacing * (columns - 1); - let mut height = rows.len(); - if header.is_some() { - height += 1; - } - let backend = TestBackend::new(total_width as u16, height as u16); - let mut terminal = Terminal::new(backend).expect("failed to create table backend"); - - let table_rows = rows.into_iter().map(|row| { - let cells = row.into_iter().map(Cell::new).collect::>(); - Row::new(cells) - }); - - let mut table = Table::new( - table_rows, - widths.iter().map(|w| Constraint::Length(*w as u16)), - ) - .column_spacing(column_spacing as u16); - - if let Some(header_row) = header { - let header_cells = header_row.into_iter().map(Cell::new).collect::>(); - table = table.header(Row::new(header_cells)); - } - - terminal - .draw(|frame| { - let area = frame.area(); - frame.render_widget(table, area); - }) - .expect("failed to render table"); - - let buffer = terminal.backend().buffer(); - buffer_to_ansi_lines(buffer).join("\n") -} - -fn header_line(text: &str) -> Line<'static> { - Line::from(Span::styled( - text.to_string(), - Style::default() - .fg(Color::DarkGray) - .add_modifier(Modifier::BOLD), - )) -} - -fn truncate_plain(text: &str, max_len: usize) -> String { - if text.chars().count() <= max_len { - return text.to_string(); - } - if max_len <= 3 { - return text.chars().take(max_len).collect(); - } - let truncated: String = text.chars().take(max_len - 3).collect(); - format!("{truncated}...") -} - -fn box_with_title(title: &str, content: &str) -> String { - let lines: Vec<&str> = content.lines().collect(); - let content_width = lines - .iter() - .map(|line| visible_width(line)) - .max() - .unwrap_or(0); - let padding = 1; - let inner_width = content_width + padding * 2; - - let title_plain = format!(" {title} "); - let title_width = visible_width(&title_plain); - let mut top = String::from("╭"); - top.push_str(&title_plain.dark_grey().to_string()); - if inner_width > title_width { - top.push_str(&"─".repeat(inner_width - title_width)); - } - top.push('╮'); - - let mut boxed = vec![top]; - for line in lines { - let line_width = visible_width(line); - // Defensive: if width accounting ever drifts (e.g. escape-sequence parsing), - // avoid underflow and render without extra trailing padding. - let right_padding = inner_width.saturating_sub(line_width + padding); - let mut row = String::from("│"); - row.push_str(&" ".repeat(padding)); - row.push_str(line); - row.push_str(&" ".repeat(right_padding)); - row.push('│'); - boxed.push(row); - } - - let bottom = format!("╰{}╯", "─".repeat(inner_width)); - boxed.push(bottom); - - format!("\n{}", boxed.join("\n")) -} - -fn visible_width(text: &str) -> usize { - let stripped = strip(text.as_bytes()); - let stripped = String::from_utf8_lossy(&stripped); - UnicodeWidthStr::width(stripped.as_ref()) -} - -fn buffer_to_ansi_lines(buffer: &ratatui::buffer::Buffer) -> Vec { - let width = buffer.area.width as usize; - let height = buffer.area.height as usize; - let mut lines = Vec::with_capacity(height); - let mut current_style = Style::reset(); - - for y in 0..height { - let mut line = String::new(); - let mut skip = 0usize; - for x in 0..width { - let cell = &buffer[(x as u16, y as u16)]; - let symbol = cell.symbol(); - let symbol_width = UnicodeWidthStr::width(symbol); - if skip > 0 { - skip -= 1; - continue; - } - - let style = Style { - fg: Some(cell.fg), - bg: Some(cell.bg), - add_modifier: cell.modifier, - ..Style::default() - }; - - if style != current_style { - line.push_str(&style_to_ansi(style)); - current_style = style; - } - - line.push_str(symbol); - skip = symbol_width.saturating_sub(1); - } - line.push_str(&style_to_ansi(Style::reset())); - lines.push(line.trim_end().to_string()); - } - - lines -} - -fn style_to_ansi(style: Style) -> String { - let mut buf = Vec::new(); - let _ = queue!(buf, SetAttribute(Attribute::Reset), ResetColor); - - if let Some(fg) = style.fg { - let _ = queue!(buf, SetForegroundColor(convert_color(fg))); - } - if let Some(bg) = style.bg { - let _ = queue!(buf, SetBackgroundColor(convert_color(bg))); - } - - let mods = style.add_modifier; - if mods.contains(Modifier::BOLD) { - let _ = queue!(buf, SetAttribute(Attribute::Bold)); - } - if mods.contains(Modifier::DIM) { - let _ = queue!(buf, SetAttribute(Attribute::Dim)); - } - if mods.contains(Modifier::ITALIC) { - let _ = queue!(buf, SetAttribute(Attribute::Italic)); - } - if mods.contains(Modifier::UNDERLINED) { - let _ = queue!(buf, SetAttribute(Attribute::Underlined)); - } - if mods.contains(Modifier::REVERSED) { - let _ = queue!(buf, SetAttribute(Attribute::Reverse)); - } - if mods.contains(Modifier::CROSSED_OUT) { - let _ = queue!(buf, SetAttribute(Attribute::CrossedOut)); - } - if mods.contains(Modifier::SLOW_BLINK) { - let _ = queue!(buf, SetAttribute(Attribute::SlowBlink)); - } - if mods.contains(Modifier::RAPID_BLINK) { - let _ = queue!(buf, SetAttribute(Attribute::RapidBlink)); - } - - String::from_utf8_lossy(&buf).to_string() -} - -fn convert_color(color: Color) -> CtColor { - match color { - Color::Reset => CtColor::Reset, - Color::Black => CtColor::Black, - Color::Red => CtColor::Red, - Color::Green => CtColor::Green, - Color::Yellow => CtColor::Yellow, - Color::Blue => CtColor::Blue, - Color::Magenta => CtColor::Magenta, - Color::Cyan => CtColor::Cyan, - Color::Gray => CtColor::Grey, - Color::DarkGray => CtColor::DarkGrey, - Color::LightRed => CtColor::Red, - Color::LightGreen => CtColor::Green, - Color::LightYellow => CtColor::Yellow, - Color::LightBlue => CtColor::Blue, - Color::LightMagenta => CtColor::Magenta, - Color::LightCyan => CtColor::Cyan, - Color::White => CtColor::White, - Color::Indexed(value) => CtColor::AnsiValue(value), - Color::Rgb(r, g, b) => CtColor::Rgb { r, g, b }, - } -} - #[cfg(test)] mod tests { use super::*; @@ -3450,38 +1816,6 @@ mod tests { path } - #[test] - fn join_app_url_normalizes_slashes() { - let joined = - join_app_url("https://www.braintrust.dev/", "/api/dataset/get").expect("join app url"); - assert_eq!(joined, "https://www.braintrust.dev/api/dataset/get"); - } - - #[test] - fn collect_allowed_dev_origins_includes_app_origin_and_dedupes() { - let origins = collect_allowed_dev_origins( - &[ - "https://example.com".to_string(), - "https://example.com".to_string(), - ], - "https://app.example.dev/some/path", - ); - assert_eq!( - origins, - vec![ - "https://app.example.dev".to_string(), - "https://example.com".to_string() - ] - ); - } - - #[test] - fn is_allowed_origin_accepts_configured_origin() { - let allowed = vec!["https://example.com".to_string()]; - assert!(is_allowed_origin("https://example.com", &allowed)); - assert!(!is_allowed_origin("https://evil.example", &allowed)); - } - #[test] fn materialize_runner_script_writes_file() { let dir = make_temp_dir("write"); @@ -3838,15 +2172,6 @@ mod tests { let _ = fs::remove_dir_all(&dir); } - #[test] - fn box_with_title_handles_ansi_content_without_panicking() { - let content = "plain line\n\x1b[38;5;196mred text\x1b[0m"; - let boxed = box_with_title("Summary", content); - assert!(boxed.contains("Summary")); - assert!(boxed.contains("plain line")); - assert!(boxed.contains("red text")); - } - #[test] fn detect_watch_changes_detects_file_create() { let dir = make_temp_dir("create"); @@ -4020,75 +2345,6 @@ mod tests { assert_ne!(first, second); } - #[test] - fn encode_eval_event_for_http_filters_internal_eval_progress() { - let event = EvalEvent::Progress(SseProgressEventData { - id: "id-1".to_string(), - object_type: "task".to_string(), - origin: None, - format: "global".to_string(), - output_type: "any".to_string(), - name: "My evaluation".to_string(), - event: "progress".to_string(), - data: r#"{"type":"eval_progress","kind":"start","total":1}"#.to_string(), - }); - - assert!(encode_eval_event_for_http(&event).is_none()); - } - - #[test] - fn encode_eval_event_for_http_keeps_external_progress_events() { - let event = EvalEvent::Progress(SseProgressEventData { - id: "id-2".to_string(), - object_type: "task".to_string(), - origin: None, - format: "code".to_string(), - output_type: "completion".to_string(), - name: "My evaluation".to_string(), - event: "json_delta".to_string(), - data: "\"China\"".to_string(), - }); - - let encoded = encode_eval_event_for_http(&event).expect("progress should be forwarded"); - assert!(encoded.contains("event: progress")); - assert!(encoded.contains("json_delta")); - } - - #[test] - fn format_processing_line_handles_pluralization() { - assert_eq!(format_processing_line(1), "Processing 1 evaluator..."); - assert_eq!(format_processing_line(2), "Processing 2 evaluators..."); - } - - #[test] - fn format_start_line_handles_partial_payload() { - let start = ExperimentStart { - experiment_name: Some("my-exp".to_string()), - experiment_url: Some("https://example.dev/exp".to_string()), - ..Default::default() - }; - let line = format_start_line(&start).expect("line should be rendered"); - assert!(line.contains("my-exp")); - assert!(line.contains("https://example.dev/exp")); - - assert!(format_start_line(&ExperimentStart::default()).is_none()); - } - - #[test] - fn fit_name_to_spaces_preserves_suffix_when_truncating() { - let rendered = - fit_name_to_spaces("Topics [experimentName=facets-real-world-30b-f5a78312]", 40); - assert_eq!(rendered.chars().count(), 40); - assert!(rendered.contains("...")); - assert!(rendered.contains("f5a78312]")); - } - - #[test] - fn fit_name_to_spaces_pads_short_names() { - let rendered = fit_name_to_spaces("short", 10); - assert_eq!(rendered, "short "); - } - #[test] fn handle_sse_event_parses_processing_and_start_payloads() { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); diff --git a/src/eval/dev_server.rs b/src/eval/dev_server.rs new file mode 100644 index 0000000..c1e6293 --- /dev/null +++ b/src/eval/dev_server.rs @@ -0,0 +1,911 @@ +use actix_web::dev::Service; +use actix_web::http::header::{ + HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, + ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, + ACCESS_CONTROL_MAX_AGE, AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_TYPE, ORIGIN, VARY, +}; +use actix_web::{guard, web, App, HttpRequest, HttpResponse, HttpServer}; +use anyhow::{Context, Result}; +use futures_util::stream; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use tokio::sync::mpsc; + +use crate::args::BaseArgs; + +use super::events::{EvalEvent, EvalProgressData, ExperimentSummary, SseProgressEventData}; +use super::{ + detect_eval_language, drive_eval_runner, spawn_eval_runner, ConsolePolicy, EvalLanguage, + EvalRunOptions, JsMode, +}; + +const MAIN_ORIGIN: &str = "https://www.braintrust.dev"; +const BRAINTRUSTDATA_ORIGIN: &str = "https://www.braintrustdata.com"; +const CORS_METHODS: &str = "GET, PATCH, POST, PUT, DELETE, OPTIONS"; +const CORS_ALLOWED_HEADERS: &str = "Content-Type, X-Amz-Date, Authorization, X-Api-Key, X-Amz-Security-Token, x-bt-auth-token, x-bt-parent, x-bt-org-name, x-bt-project-id, x-bt-stream-fmt, x-bt-use-cache, x-bt-use-gateway, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch"; +const CORS_EXPOSED_HEADERS: &str = + "x-bt-cursor, x-bt-found-existing-experiment, x-bt-span-id, x-bt-span-export"; +const HEADER_BT_AUTH_TOKEN: &str = "x-bt-auth-token"; +const HEADER_BT_ORG_NAME: &str = "x-bt-org-name"; +const HEADER_CORS_REQ_PRIVATE_NETWORK: &str = "access-control-request-private-network"; +const HEADER_CORS_ALLOW_PRIVATE_NETWORK: &str = "access-control-allow-private-network"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EvalRequest { + name: String, + #[serde(default)] + parameters: Option, + data: Value, + #[serde(default)] + scores: Option>, + #[serde(default)] + experiment_name: Option, + #[serde(default)] + project_id: Option, + #[serde(default)] + parent: Option, + #[serde(default)] + stream: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EvalScore { + name: String, + function_id: Value, +} + +#[derive(Debug, Deserialize)] +struct DatasetLookupRow { + project_id: String, + name: String, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +enum DatasetIdField { + String(String), + Other(Value), +} + +#[derive(Debug, Clone, Deserialize)] +struct DatasetEvalDataInput { + #[serde(default)] + dataset_id: Option, + #[serde(default)] + _internal_btql: Option, +} + +#[derive(Debug, Clone, Serialize)] +struct ResolvedDatasetEvalData { + project_id: String, + dataset_name: String, + #[serde(skip_serializing_if = "Option::is_none")] + _internal_btql: Option, +} + +#[derive(Clone)] +pub(super) struct DevServerState { + pub(super) base: BaseArgs, + pub(super) language_override: Option, + pub(super) runner_override: Option, + pub(super) files: Vec, + pub(super) no_send_logs: bool, + pub(super) options: EvalRunOptions, + pub(super) host: String, + pub(super) port: u16, + pub(super) allowed_org_name: Option, + pub(super) allowed_origins: Vec, + pub(super) app_url: String, + pub(super) http_client: Client, +} + +#[derive(Debug)] +struct DevAuthContext { + token: String, + org_name: String, +} + +pub(super) fn resolve_app_url(base: &BaseArgs) -> String { + if let Some(app_url) = base.app_url.as_ref() { + return app_url.clone(); + } + "https://www.braintrust.dev".to_string() +} + +fn app_origin_from_url(url: &str) -> Option { + reqwest::Url::parse(url).ok().and_then(|parsed| { + let origin = parsed.origin(); + if origin.is_tuple() { + Some(origin.ascii_serialization()) + } else { + None + } + }) +} + +pub(super) fn collect_allowed_dev_origins(explicit: &[String], app_url: &str) -> Vec { + let mut deduped = std::collections::BTreeSet::new(); + for origin in explicit { + let trimmed = origin.trim(); + if !trimmed.is_empty() { + deduped.insert(trimmed.to_string()); + } + } + if let Some(origin) = app_origin_from_url(app_url) { + deduped.insert(origin); + } + deduped.into_iter().collect() +} + +fn join_app_url(app_url: &str, path: &str) -> Result { + let base = format!("{}/", app_url.trim_end_matches('/')); + let base_url = reqwest::Url::parse(&base).context("invalid app URL")?; + let joined = base_url + .join(path.trim_start_matches('/')) + .context("failed to join app URL path")?; + Ok(joined.to_string()) +} + +fn json_error_response(status: actix_web::http::StatusCode, message: &str) -> HttpResponse { + HttpResponse::build(status).json(json!({ "error": message })) +} + +fn parse_auth_token(req: &HttpRequest) -> Option { + if let Some(token) = req.headers().get(HEADER_BT_AUTH_TOKEN) { + if let Ok(value) = token.to_str() { + if !value.trim().is_empty() { + return Some(value.trim().to_string()); + } + } + } + + let auth = req.headers().get(AUTHORIZATION)?; + let auth = auth.to_str().ok()?.trim(); + if auth.is_empty() { + return None; + } + if let Some(token) = auth.strip_prefix("Bearer ") { + let token = token.trim(); + if token.is_empty() { + None + } else { + Some(token.to_string()) + } + } else { + Some(auth.to_string()) + } +} + +async fn authenticate_dev_request( + req: &HttpRequest, + state: &DevServerState, +) -> std::result::Result { + let token = match parse_auth_token(req) { + Some(token) if !token.eq_ignore_ascii_case("null") => token, + _ => { + return Err(json_error_response( + actix_web::http::StatusCode::UNAUTHORIZED, + "Unauthorized", + )); + } + }; + + let org_name = match req + .headers() + .get(HEADER_BT_ORG_NAME) + .and_then(|value| value.to_str().ok()) + { + Some(value) if !value.trim().is_empty() => value.trim().to_string(), + _ => { + return Err(json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!("Missing {HEADER_BT_ORG_NAME} header"), + )); + } + }; + + if let Some(allowed_org_name) = state.allowed_org_name.as_ref() { + if allowed_org_name != &org_name { + let message = format!( + "Org '{org_name}' is not allowed. Only org '{allowed_org_name}' is allowed." + ); + return Err(json_error_response( + actix_web::http::StatusCode::FORBIDDEN, + &message, + )); + } + } + + let login_url = match join_app_url(&state.app_url, "api/apikey/login") { + Ok(url) => url, + Err(err) => { + return Err(json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + )); + } + }; + let response = state + .http_client + .post(login_url) + .bearer_auth(&token) + .send() + .await + .map_err(|_| { + json_error_response(actix_web::http::StatusCode::UNAUTHORIZED, "Unauthorized") + })?; + if !response.status().is_success() { + return Err(json_error_response( + actix_web::http::StatusCode::UNAUTHORIZED, + "Unauthorized", + )); + } + + let payload = response.json::().await.unwrap_or(Value::Null); + if let Some(orgs) = payload.get("org_info").and_then(|value| value.as_array()) { + let matched = orgs.iter().any(|org| { + org.get("name") + .and_then(|name| name.as_str()) + .map(|name| name == org_name) + .unwrap_or(false) + }); + if !matched { + return Err(json_error_response( + actix_web::http::StatusCode::UNAUTHORIZED, + "Unauthorized", + )); + } + } else { + return Err(json_error_response( + actix_web::http::StatusCode::UNAUTHORIZED, + "Unauthorized", + )); + } + + Ok(DevAuthContext { token, org_name }) +} + +async fn resolve_dataset_ref_for_eval_request( + state: &DevServerState, + auth: &DevAuthContext, + eval_request: &mut EvalRequest, +) -> std::result::Result<(), HttpResponse> { + let input = match serde_json::from_value::(eval_request.data.clone()) { + Ok(value) => value, + Err(_) => return Ok(()), + }; + + let dataset_id = match input.dataset_id { + Some(DatasetIdField::String(dataset_id)) => dataset_id, + Some(DatasetIdField::Other(value)) => { + let received_type = match value { + Value::Null => "null", + Value::Bool(_) => "boolean", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + }; + return Err(json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!("Invalid dataset_id: expected a string, got {received_type}."), + )); + } + None => { + return Ok(()); + } + }; + if dataset_id.trim().is_empty() { + return Err(json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + "Invalid dataset_id: expected a non-empty string.", + )); + } + + let lookup_url = match join_app_url(&state.app_url, "api/dataset/get") { + Ok(url) => url, + Err(err) => { + return Err(json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + )); + } + }; + let response = state + .http_client + .post(lookup_url) + .bearer_auth(&auth.token) + .header(HEADER_BT_ORG_NAME, auth.org_name.clone()) + .json(&json!({ "id": dataset_id })) + .send() + .await + .map_err(|err| { + json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!("Failed to load dataset '{dataset_id}': {err}"), + ) + })?; + if !response.status().is_success() { + return Err(json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!( + "Failed to load dataset '{dataset_id}' (status {}).", + response.status() + ), + )); + } + + let datasets = response + .json::>() + .await + .map_err(|err| { + json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!("Failed to parse dataset response for '{dataset_id}': {err}"), + ) + })?; + let Some(dataset) = datasets.first() else { + return Err(json_error_response( + actix_web::http::StatusCode::BAD_REQUEST, + &format!("Dataset '{dataset_id}' not found."), + )); + }; + + let resolved = ResolvedDatasetEvalData { + project_id: dataset.project_id.clone(), + dataset_name: dataset.name.clone(), + _internal_btql: input._internal_btql, + }; + eval_request.data = serde_json::to_value(resolved).map_err(|err| { + json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("Failed to serialize resolved dataset reference: {err}"), + ) + })?; + Ok(()) +} + +fn make_dev_mode_env( + auth: &DevAuthContext, + state: &DevServerState, + request: Option<&EvalRequest>, + dev_mode: &str, +) -> Result> { + let mut env = vec![ + ("BRAINTRUST_API_KEY".to_string(), auth.token.clone()), + ("BRAINTRUST_ORG_NAME".to_string(), auth.org_name.clone()), + ("BRAINTRUST_APP_URL".to_string(), state.app_url.clone()), + ("BT_EVAL_DEV_MODE".to_string(), dev_mode.to_string()), + ]; + if let Some(request) = request { + let serialized = + serde_json::to_string(request).context("failed to serialize eval request payload")?; + env.push(("BT_EVAL_DEV_REQUEST_JSON".to_string(), serialized)); + } + Ok(env) +} + +fn serialize_sse_event(event: &str, data: &str) -> String { + format!("event: {event}\ndata: {data}\n\n") +} + +fn is_eval_progress_payload(progress: &SseProgressEventData) -> bool { + serde_json::from_str::(&progress.data) + .map(|payload| payload.kind_type == "eval_progress") + .unwrap_or(false) +} + +fn encode_eval_event_for_http(event: &EvalEvent) -> Option { + match event { + EvalEvent::Processing(payload) => serde_json::to_string(payload) + .ok() + .map(|data| serialize_sse_event("processing", &data)), + EvalEvent::Start(start) => serde_json::to_string(start) + .ok() + .map(|data| serialize_sse_event("start", &data)), + EvalEvent::Summary(summary) => serde_json::to_string(summary) + .ok() + .map(|data| serialize_sse_event("summary", &data)), + EvalEvent::Progress(progress) => { + if is_eval_progress_payload(progress) { + None + } else { + serde_json::to_string(progress) + .ok() + .map(|data| serialize_sse_event("progress", &data)) + } + } + EvalEvent::Dependencies { .. } => None, + EvalEvent::Done => Some(serialize_sse_event("done", "")), + EvalEvent::Error { + message, + stack, + status, + } => serde_json::to_string(&json!({ + "message": message, + "stack": stack, + "status": status, + })) + .ok() + .map(|data| serialize_sse_event("error", &data)), + EvalEvent::Console { .. } => None, + } +} + +async fn dev_server_index() -> HttpResponse { + HttpResponse::Ok().body("Hello, world!") +} + +async fn dev_server_options() -> HttpResponse { + HttpResponse::Ok().finish() +} + +fn is_allowed_preview_origin(origin: &str) -> bool { + origin.starts_with("https://") && origin.ends_with(".preview.braintrust.dev") +} + +fn is_allowed_origin(origin: &str, allowed_origins: &[String]) -> bool { + if origin == MAIN_ORIGIN || origin == BRAINTRUSTDATA_ORIGIN || is_allowed_preview_origin(origin) + { + return true; + } + allowed_origins.iter().any(|value| value == origin) +} + +fn apply_cors_headers( + headers: &mut actix_web::http::header::HeaderMap, + request_origin: Option<&str>, + allow_private_network: bool, + allowed_origins: &[String], +) { + if let Some(origin) = request_origin { + if is_allowed_origin(origin, allowed_origins) { + if let Ok(origin_value) = HeaderValue::from_str(origin) { + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin_value); + headers.insert( + ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static(CORS_METHODS), + ); + headers.insert( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static(CORS_ALLOWED_HEADERS), + ); + headers.insert( + ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::from_static(CORS_EXPOSED_HEADERS), + ); + headers.insert( + ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + headers.insert(ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("86400")); + headers.insert(VARY, HeaderValue::from_static("Origin")); + } + } + } + + if allow_private_network { + headers.insert( + HeaderName::from_static(HEADER_CORS_ALLOW_PRIVATE_NETWORK), + HeaderValue::from_static("true"), + ); + } +} + +async fn dev_server_list(state: web::Data, req: HttpRequest) -> HttpResponse { + let auth = match authenticate_dev_request(&req, &state).await { + Ok(auth) => auth, + Err(response) => return response, + }; + let extra_env = match make_dev_mode_env(&auth, &state, None, "list") { + Ok(extra_env) => extra_env, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + let language = match detect_eval_language(&state.files, state.language_override) { + Ok(language) => language, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + let spawned = match spawn_eval_runner( + &state.base, + language, + state.runner_override.as_deref(), + &state.files, + state.no_send_logs, + &state.options, + &extra_env, + JsMode::Auto, + ) + .await + { + Ok(value) => value, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + let mut stdout_lines = Vec::new(); + let mut errors: Vec<(String, Option)> = Vec::new(); + let output = + match drive_eval_runner( + spawned.process, + ConsolePolicy::Forward, + |event| match event { + EvalEvent::Console { stream, message } if stream == "stdout" => { + stdout_lines.push(message); + } + EvalEvent::Error { + message, + stack: _, + status, + } => errors.push((message, status)), + _ => {} + }, + ) + .await + { + Ok(output) => output, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + if let Some((message, status)) = errors.first() { + let status = status + .and_then(|status| actix_web::http::StatusCode::from_u16(status).ok()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + return json_error_response(status, message); + } + if !output.status.success() { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + "Eval runner exited with an error.", + ); + } + + let mut parsed_manifest: Option = None; + for line in stdout_lines.iter().rev() { + if let Ok(value) = serde_json::from_str::(line) { + parsed_manifest = Some(value); + break; + } + } + if parsed_manifest.is_none() { + let joined = stdout_lines.join("\n"); + if let Ok(value) = serde_json::from_str::(&joined) { + parsed_manifest = Some(value); + } + } + + match parsed_manifest { + Some(manifest) => HttpResponse::Ok().json(manifest), + None => json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + "Failed to parse evaluator manifest from runner output.", + ), + } +} + +async fn dev_server_eval( + state: web::Data, + req: HttpRequest, + body: web::Bytes, +) -> HttpResponse { + let auth = match authenticate_dev_request(&req, &state).await { + Ok(auth) => auth, + Err(response) => return response, + }; + + let mut eval_request: EvalRequest = match serde_json::from_slice(&body) { + Ok(eval_request) => eval_request, + Err(err) => { + return json_error_response(actix_web::http::StatusCode::BAD_REQUEST, &err.to_string()); + } + }; + if let Err(response) = + resolve_dataset_ref_for_eval_request(&state, &auth, &mut eval_request).await + { + return response; + } + let stream_requested = eval_request.stream.unwrap_or(false); + let extra_env = match make_dev_mode_env(&auth, &state, Some(&eval_request), "eval") { + Ok(extra_env) => extra_env, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + let language = match detect_eval_language(&state.files, state.language_override) { + Ok(language) => language, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + let spawned = match spawn_eval_runner( + &state.base, + language, + state.runner_override.as_deref(), + &state.files, + state.no_send_logs, + &state.options, + &extra_env, + JsMode::Auto, + ) + .await + { + Ok(value) => value, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + if stream_requested { + let (tx, rx) = mpsc::unbounded_channel::(); + tokio::spawn(async move { + let mut saw_error = false; + let mut stderr_lines: Vec = Vec::new(); + let output = drive_eval_runner(spawned.process, ConsolePolicy::Forward, |event| { + if matches!(event, EvalEvent::Error { .. }) { + saw_error = true; + } + if matches!(event, EvalEvent::Done) { + return; + } + if let EvalEvent::Console { + ref stream, + ref message, + } = event + { + for line in message.lines() { + let _ = tx.send(format!(": [{stream}] {line}\n")); + } + if stream == "stderr" { + stderr_lines.push(message.clone()); + } + return; + } + if let Some(encoded) = encode_eval_event_for_http(&event) { + let _ = tx.send(encoded); + } + }) + .await; + + match output { + Ok(output) => { + if !output.status.success() && !saw_error { + let mut detail = format!("Eval runner exited with {}.", output.status); + for line in stderr_lines.iter() { + detail.push('\n'); + detail.push_str(line); + } + let error = + serialize_sse_event("error", &json!({ "message": detail }).to_string()); + let _ = tx.send(error); + } + } + Err(err) => { + let error = serialize_sse_event( + "error", + &json!({ "message": format!("{err:#}") }).to_string(), + ); + let _ = tx.send(error); + } + } + + let _ = tx.send(serialize_sse_event("done", "")); + }); + + let response_stream = stream::unfold(rx, |mut rx| async { + rx.recv() + .await + .map(|chunk| (Ok::<_, actix_web::Error>(web::Bytes::from(chunk)), rx)) + }); + return HttpResponse::Ok() + .append_header((CONTENT_TYPE, "text/event-stream")) + .append_header((CACHE_CONTROL, "no-cache")) + .append_header((CONNECTION, "keep-alive")) + .streaming(response_stream); + } + + let mut summary: Option = None; + let mut errors: Vec<(String, Option)> = Vec::new(); + let output = + match drive_eval_runner( + spawned.process, + ConsolePolicy::Forward, + |event| match event { + EvalEvent::Summary(current) => summary = Some(current), + EvalEvent::Error { + message, + stack: _, + status, + } => errors.push((message, status)), + _ => {} + }, + ) + .await + { + Ok(output) => output, + Err(err) => { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + &format!("{err:#}"), + ); + } + }; + + if let Some((message, status)) = errors.first() { + let status = status + .and_then(|status| actix_web::http::StatusCode::from_u16(status).ok()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + return json_error_response(status, message); + } + if let Some(summary) = summary { + return HttpResponse::Ok().json(summary); + } + if !output.status.success() { + return json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + "Eval runner exited with an error.", + ); + } + json_error_response( + actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, + "Eval runner did not return a summary.", + ) +} + +pub(super) async fn run_dev_server(state: DevServerState) -> Result<()> { + println!( + "Starting eval dev server on http://{}:{}", + state.host, state.port + ); + let host = state.host.clone(); + let port = state.port; + HttpServer::new(move || { + let allowed_origins = state.allowed_origins.clone(); + App::new() + .wrap_fn({ + let allowed_origins = allowed_origins.clone(); + move |req, srv| { + let allowed_origins = allowed_origins.clone(); + let request_origin = req + .headers() + .get(ORIGIN) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + let allow_private_network = + req.headers().contains_key(HEADER_CORS_REQ_PRIVATE_NETWORK); + let fut = srv.call(req); + async move { + let mut res = fut.await?; + apply_cors_headers( + res.headers_mut(), + request_origin.as_deref(), + allow_private_network, + &allowed_origins, + ); + Ok::<_, actix_web::Error>(res) + } + } + }) + .app_data(web::Data::new(state.clone())) + .route("/", web::get().to(dev_server_index)) + .route( + "/", + web::route().guard(guard::Options()).to(dev_server_options), + ) + .route("/list", web::get().to(dev_server_list)) + .route( + "/list", + web::route().guard(guard::Options()).to(dev_server_options), + ) + .route("/eval", web::post().to(dev_server_eval)) + .route( + "/eval", + web::route().guard(guard::Options()).to(dev_server_options), + ) + }) + .bind((host.as_str(), port)) + .with_context(|| format!("failed to bind eval dev server on {host}:{port}"))? + .run() + .await + .context("eval dev server exited unexpectedly") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn join_app_url_normalizes_slashes() { + let joined = + join_app_url("https://www.braintrust.dev/", "/api/dataset/get").expect("join app url"); + assert_eq!(joined, "https://www.braintrust.dev/api/dataset/get"); + } + + #[test] + fn collect_allowed_dev_origins_includes_app_origin_and_dedupes() { + let origins = collect_allowed_dev_origins( + &[ + "https://example.com".to_string(), + "https://example.com".to_string(), + ], + "https://app.example.dev/some/path", + ); + assert_eq!( + origins, + vec![ + "https://app.example.dev".to_string(), + "https://example.com".to_string() + ] + ); + } + + #[test] + fn is_allowed_origin_accepts_configured_origin() { + let allowed = vec!["https://example.com".to_string()]; + assert!(is_allowed_origin("https://example.com", &allowed)); + assert!(!is_allowed_origin("https://evil.example", &allowed)); + } + + #[test] + fn encode_eval_event_for_http_filters_internal_eval_progress() { + let event = EvalEvent::Progress(SseProgressEventData { + id: "id-1".to_string(), + object_type: "task".to_string(), + origin: None, + format: "global".to_string(), + output_type: "any".to_string(), + name: "My evaluation".to_string(), + event: "progress".to_string(), + data: r#"{"type":"eval_progress","kind":"start","total":1}"#.to_string(), + }); + + assert!(encode_eval_event_for_http(&event).is_none()); + } + + #[test] + fn encode_eval_event_for_http_keeps_external_progress_events() { + let event = EvalEvent::Progress(SseProgressEventData { + id: "id-2".to_string(), + object_type: "task".to_string(), + origin: None, + format: "code".to_string(), + output_type: "completion".to_string(), + name: "My evaluation".to_string(), + event: "json_delta".to_string(), + data: "\"China\"".to_string(), + }); + + let encoded = encode_eval_event_for_http(&event).expect("progress should be forwarded"); + assert!(encoded.contains("event: progress")); + assert!(encoded.contains("json_delta")); + } +} diff --git a/src/eval/events.rs b/src/eval/events.rs new file mode 100644 index 0000000..b95f737 --- /dev/null +++ b/src/eval/events.rs @@ -0,0 +1,126 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug)] +pub(super) enum EvalEvent { + Processing(ProcessingEventData), + Start(ExperimentStart), + Summary(ExperimentSummary), + Progress(SseProgressEventData), + Dependencies { + files: Vec, + }, + Done, + Error { + message: String, + stack: Option, + status: Option, + }, + Console { + stream: String, + message: String, + }, +} + +#[derive(Debug, Deserialize, Serialize)] +pub(super) struct ProcessingEventData { + #[serde(default)] + pub(super) evaluators: usize, +} + +#[derive(Debug, Deserialize, Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub(super) struct ExperimentStart { + #[serde(default, alias = "project_name")] + pub(super) project_name: Option, + #[serde(default, alias = "experiment_name")] + pub(super) experiment_name: Option, + #[serde(default, alias = "project_id")] + pub(super) project_id: Option, + #[serde(default, alias = "experiment_id")] + pub(super) experiment_id: Option, + #[serde(default, alias = "project_url")] + pub(super) project_url: Option, + #[serde(default, alias = "experiment_url")] + pub(super) experiment_url: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub(super) struct ExperimentSummary { + pub(super) project_name: String, + pub(super) experiment_name: String, + pub(super) project_id: Option, + pub(super) experiment_id: Option, + pub(super) project_url: Option, + pub(super) experiment_url: Option, + pub(super) comparison_experiment_name: Option, + pub(super) scores: HashMap, + pub(super) metrics: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub(super) struct ScoreSummary { + pub(super) name: String, + pub(super) score: f64, + pub(super) diff: Option, + #[serde(default)] + pub(super) improvements: i64, + #[serde(default)] + pub(super) regressions: i64, +} + +#[derive(Debug, Deserialize)] +pub(super) struct EvalErrorPayload { + pub(super) message: String, + pub(super) stack: Option, + pub(super) status: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub(super) struct MetricSummary { + pub(super) name: String, + pub(super) metric: f64, + #[serde(default)] + pub(super) unit: String, + pub(super) diff: Option, + #[serde(default)] + pub(super) improvements: i64, + #[serde(default)] + pub(super) regressions: i64, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize, Serialize)] +pub(super) struct SseProgressEventData { + pub(super) id: String, + pub(super) object_type: String, + pub(super) origin: Option, + pub(super) format: String, + pub(super) output_type: String, + pub(super) name: String, + pub(super) event: String, + pub(super) data: String, +} + +#[derive(Debug, Deserialize)] +pub(super) struct EvalProgressData { + #[serde(rename = "type")] + pub(super) kind_type: String, + pub(super) kind: String, + pub(super) total: Option, +} + +#[derive(Debug, Deserialize)] +pub(super) struct SseConsoleEventData { + pub(super) stream: String, + pub(super) message: String, +} + +#[derive(Debug, Deserialize)] +pub(super) struct SseDependenciesEventData { + pub(super) files: Vec, +} diff --git a/src/eval/ui.rs b/src/eval/ui.rs new file mode 100644 index 0000000..cc37b6f --- /dev/null +++ b/src/eval/ui.rs @@ -0,0 +1,751 @@ +use std::collections::HashMap; +use std::io::IsTerminal; + +use crossterm::queue; +use crossterm::style::{ + Attribute, Color as CtColor, ResetColor, SetAttribute, SetBackgroundColor, SetForegroundColor, + Stylize, +}; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use ratatui::backend::TestBackend; +use ratatui::layout::{Alignment, Constraint}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Cell, Row, Table}; +use ratatui::Terminal; +use strip_ansi_escapes::strip; +use unicode_width::UnicodeWidthStr; + +use crate::ui::{animations_enabled, is_quiet}; + +use super::events::{ + EvalEvent, EvalProgressData, ExperimentStart, ExperimentSummary, SseProgressEventData, +}; +use super::{MAX_DEFERRED_EVAL_ERRORS, MAX_NAME_LENGTH}; + +pub(super) struct EvalUi { + progress: MultiProgress, + bars: HashMap, + bar_style: ProgressStyle, + spinner_style: ProgressStyle, + jsonl: bool, + list: bool, + verbose: bool, + deferred_errors: Vec, + suppressed_stderr_lines: usize, + finished: bool, +} + +impl EvalUi { + pub(super) fn new(jsonl: bool, list: bool, verbose: bool) -> Self { + let draw_target = if std::io::stderr().is_terminal() && animations_enabled() && !is_quiet() + { + ProgressDrawTarget::stderr_with_hz(10) + } else { + ProgressDrawTarget::stderr() + }; + let progress = MultiProgress::with_draw_target(draw_target); + let bar_style = + ProgressStyle::with_template("{bar:10.blue} {msg} {percent}% {pos}/{len} {eta}") + .unwrap(); + let spinner_style = ProgressStyle::with_template("{spinner} {msg}").unwrap(); + Self { + progress, + bars: HashMap::new(), + bar_style, + spinner_style, + jsonl, + list, + verbose, + deferred_errors: Vec::new(), + suppressed_stderr_lines: 0, + finished: false, + } + } + + pub(super) fn finish(&mut self) { + if self.finished { + return; + } + for (_, bar) in self.bars.drain() { + bar.finish_and_clear(); + } + let _ = self.progress.clear(); + self.progress.set_draw_target(ProgressDrawTarget::hidden()); + self.print_deferred_error_footnote(); + self.finished = true; + } + + pub(super) fn handle(&mut self, event: EvalEvent) { + match event { + EvalEvent::Processing(payload) => { + self.print_persistent_line(format_processing_line(payload.evaluators)); + } + EvalEvent::Start(start) => { + if let Some(line) = format_start_line(&start) { + self.print_persistent_line(line); + } + } + EvalEvent::Summary(summary) => { + if self.jsonl { + if let Ok(line) = serde_json::to_string(&summary) { + println!("{line}"); + } + } else { + let rendered = format_experiment_summary(&summary); + self.print_persistent_multiline(rendered); + } + } + EvalEvent::Progress(progress) => { + self.handle_progress(progress); + } + EvalEvent::Dependencies { .. } => {} + EvalEvent::Console { stream, message } => { + if stream == "stdout" && (self.list || self.jsonl) { + println!("{message}"); + } else if stream == "stderr" && !self.verbose { + self.suppressed_stderr_lines += 1; + } else { + let _ = self.progress.println(message); + } + } + EvalEvent::Error { message, stack, .. } => { + let show_hint = message.contains("Please specify an api key"); + if self.verbose { + let line = message.as_str().red().to_string(); + let _ = self.progress.println(line); + if let Some(stack) = stack { + for line in stack.lines() { + let _ = self.progress.println(line.dark_grey().to_string()); + } + } + } else { + self.record_deferred_error(message); + } + if show_hint { + let hint = "Hint: pass --api-key, set BRAINTRUST_API_KEY, run `bt auth login`/`bt auth login --oauth`, or use --no-send-logs for local evals."; + if self.verbose { + let _ = self.progress.println(hint.dark_grey().to_string()); + } else { + self.record_deferred_error(hint.to_string()); + } + } + } + EvalEvent::Done => { + self.finish(); + } + } + } + + fn handle_progress(&mut self, progress: SseProgressEventData) { + let payload = match serde_json::from_str::(&progress.data) { + Ok(payload) if payload.kind_type == "eval_progress" => payload, + _ => return, + }; + + match payload.kind.as_str() { + "start" => { + let bar = if let Some(total) = payload.total { + if total > 0 { + let bar = self.progress.add(ProgressBar::new(total)); + bar.set_style(self.bar_style.clone()); + bar + } else { + let bar = self.progress.add(ProgressBar::new_spinner()); + bar.set_style(self.spinner_style.clone()); + bar + } + } else { + let bar = self.progress.add(ProgressBar::new_spinner()); + bar.set_style(self.spinner_style.clone()); + bar + }; + bar.set_message(fit_name_to_spaces(&progress.name, MAX_NAME_LENGTH)); + self.bars.insert(progress.name.clone(), bar); + } + "increment" => { + if let Some(bar) = self.bars.get(&progress.name) { + bar.inc(1); + bar.set_message(fit_name_to_spaces(&progress.name, MAX_NAME_LENGTH)); + } + } + "set_total" => { + if let Some(bar) = self.bars.get(&progress.name) { + if let Some(total) = payload.total { + bar.set_length(total); + bar.set_style(self.bar_style.clone()); + } + } + } + "stop" => { + if let Some(bar) = self.bars.remove(&progress.name) { + bar.finish_and_clear(); + } + } + _ => {} + } + } + + fn print_persistent_line(&self, line: String) { + self.progress.suspend(|| { + eprintln!("{line}"); + }); + } + + fn print_persistent_multiline(&self, text: String) { + self.progress.suspend(|| { + for line in text.lines() { + eprintln!("{line}"); + } + }); + } + + fn record_deferred_error(&mut self, message: String) { + let trimmed = message.trim(); + if trimmed.is_empty() { + return; + } + if self + .deferred_errors + .iter() + .any(|existing| existing == trimmed) + { + return; + } + if self.deferred_errors.len() < MAX_DEFERRED_EVAL_ERRORS { + self.deferred_errors.push(trimmed.to_string()); + } + } + + fn print_deferred_error_footnote(&self) { + if self.verbose { + return; + } + if self.deferred_errors.is_empty() && self.suppressed_stderr_lines == 0 { + return; + } + + eprintln!(); + if !self.deferred_errors.is_empty() { + let noun = if self.deferred_errors.len() == 1 { + "error" + } else { + "errors" + }; + eprintln!( + "Encountered {} evaluator {noun}:", + self.deferred_errors.len() + ); + for message in &self.deferred_errors { + eprintln!(" - {message}"); + } + } + if self.suppressed_stderr_lines > 0 { + eprintln!( + "Suppressed {} stderr line(s). Re-run with `bt eval --verbose ...` to inspect details.", + self.suppressed_stderr_lines + ); + } + } +} + +impl Drop for EvalUi { + fn drop(&mut self) { + self.finish(); + } +} + +fn fit_name_to_spaces(name: &str, length: usize) -> String { + let char_count = name.chars().count(); + if char_count < length { + let mut padded = name.to_string(); + padded.push_str(&" ".repeat(length - char_count)); + return padded; + } + if char_count == length { + return name.to_string(); + } + if length <= 3 { + return name.chars().take(length).collect(); + } + if length <= 5 { + let truncated: String = name.chars().take(length - 3).collect(); + return format!("{truncated}..."); + } + + let keep_total = length - 3; + let head_len = keep_total / 2; + let tail_len = keep_total - head_len; + let head: String = name.chars().take(head_len).collect(); + let tail: String = name + .chars() + .rev() + .take(tail_len) + .collect::() + .chars() + .rev() + .collect(); + format!("{head}...{tail}") +} + +fn format_processing_line(evaluators: usize) -> String { + let noun = if evaluators == 1 { + "evaluator" + } else { + "evaluators" + }; + format!("Processing {evaluators} {noun}...") +} + +fn format_start_line(start: &ExperimentStart) -> Option { + let experiment_name = start + .experiment_name + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()); + let experiment_url = start + .experiment_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()); + let arrow = "▶".cyan(); + + match (experiment_name, experiment_url) { + (Some(name), Some(url)) => Some(format!( + "{arrow} Experiment {} is running at {url}", + name.bold() + )), + (Some(name), None) => Some(format!( + "{arrow} Experiment {} is running at locally", + name.bold() + )), + (None, Some(url)) => Some(format!("{arrow} Experiment is running at {url}")), + (None, None) => None, + } +} + +fn format_experiment_summary(summary: &ExperimentSummary) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(comparison) = summary.comparison_experiment_name.as_deref() { + let line = format!( + "{baseline} {baseline_tag} ← {comparison_name} {comparison_tag}", + baseline = comparison, + baseline_tag = "(baseline)".dark_grey(), + comparison_name = summary.experiment_name, + comparison_tag = "(comparison)".dark_grey(), + ); + parts.push(line); + } + + let has_scores = !summary.scores.is_empty(); + let has_metrics = summary + .metrics + .as_ref() + .map(|metrics| !metrics.is_empty()) + .unwrap_or(false); + + if has_scores || has_metrics { + let has_comparison = summary.comparison_experiment_name.is_some(); + let mut rows: Vec> = Vec::new(); + + let header = if has_comparison { + Some(vec![ + header_line("Name"), + header_line("Value"), + header_line("Change"), + header_line("Improvements"), + header_line("Regressions"), + ]) + } else { + None + }; + + let mut score_values: Vec<_> = summary.scores.values().collect(); + score_values.sort_by(|a, b| a.name.cmp(&b.name)); + for score in score_values { + let score_percent = + Line::from(format!("{:.2}%", score.score * 100.0)).alignment(Alignment::Right); + let diff = format_diff_line(score.diff); + let improvements = format_improvements_line(score.improvements); + let regressions = format_regressions_line(score.regressions); + let name = truncate_plain(&score.name, MAX_NAME_LENGTH); + let name = Line::from(vec![ + Span::styled("◯", Style::default().fg(Color::Blue)), + Span::raw(" "), + Span::raw(name), + ]); + if has_comparison { + rows.push(vec![name, score_percent, diff, improvements, regressions]); + } else { + rows.push(vec![name, score_percent]); + } + } + + if let Some(metrics) = &summary.metrics { + let mut metric_values: Vec<_> = metrics.values().collect(); + metric_values.sort_by(|a, b| a.name.cmp(&b.name)); + for metric in metric_values { + let formatted_value = Line::from(format_metric_value(metric.metric, &metric.unit)) + .alignment(Alignment::Right); + let diff = format_diff_line(metric.diff); + let improvements = format_improvements_line(metric.improvements); + let regressions = format_regressions_line(metric.regressions); + let name = truncate_plain(&metric.name, MAX_NAME_LENGTH); + let name = Line::from(vec![ + Span::styled("◯", Style::default().fg(Color::Magenta)), + Span::raw(" "), + Span::raw(name), + ]); + if has_comparison { + rows.push(vec![name, formatted_value, diff, improvements, regressions]); + } else { + rows.push(vec![name, formatted_value]); + } + } + } + + parts.push(render_table_ratatui(header, rows, has_comparison)); + } + + if let Some(url) = &summary.experiment_url { + parts.push(format!("See results at {url}")); + } + + let content = parts.join("\n\n"); + box_with_title("Experiment summary", &content) +} + +fn format_diff_line(diff: Option) -> Line<'static> { + match diff { + Some(value) => { + let sign = if value > 0.0 { "+" } else { "" }; + let percent = format!("{sign}{:.2}%", value * 100.0); + let style = if value > 0.0 { + Style::default().fg(Color::Green) + } else { + Style::default().fg(Color::Red) + }; + Line::from(Span::styled(percent, style)).alignment(Alignment::Right) + } + None => Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) + .alignment(Alignment::Right), + } +} + +fn format_improvements_line(value: i64) -> Line<'static> { + if value > 0 { + Line::from(Span::styled( + value.to_string(), + Style::default() + .fg(Color::Green) + .add_modifier(Modifier::DIM), + )) + .alignment(Alignment::Right) + } else { + Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) + .alignment(Alignment::Right) + } +} + +fn format_regressions_line(value: i64) -> Line<'static> { + if value > 0 { + Line::from(Span::styled( + value.to_string(), + Style::default().fg(Color::Red).add_modifier(Modifier::DIM), + )) + .alignment(Alignment::Right) + } else { + Line::from(Span::styled("-", Style::default().fg(Color::DarkGray))) + .alignment(Alignment::Right) + } +} + +fn format_metric_value(metric: f64, unit: &str) -> String { + let formatted = if metric.fract() == 0.0 { + format!("{metric:.0}") + } else { + format!("{metric:.2}") + }; + if unit == "$" { + format!("{unit}{formatted}") + } else { + format!("{formatted}{unit}") + } +} + +fn render_table_ratatui( + header: Option>>, + rows: Vec>>, + has_comparison: bool, +) -> String { + if rows.is_empty() { + return String::new(); + } + + let columns = if has_comparison { 5 } else { 2 }; + let mut widths = vec![0usize; columns]; + + if let Some(header_row) = &header { + for (idx, line) in header_row.iter().enumerate().take(columns) { + widths[idx] = widths[idx].max(line.width()); + } + } + + for row in &rows { + for (idx, line) in row.iter().enumerate().take(columns) { + widths[idx] = widths[idx].max(line.width()); + } + } + + let column_spacing = 2; + let total_width = widths.iter().sum::() + column_spacing * (columns - 1); + let mut height = rows.len(); + if header.is_some() { + height += 1; + } + let backend = TestBackend::new(total_width as u16, height as u16); + let mut terminal = Terminal::new(backend).expect("failed to create table backend"); + + let table_rows = rows.into_iter().map(|row| { + let cells = row.into_iter().map(Cell::new).collect::>(); + Row::new(cells) + }); + + let mut table = Table::new( + table_rows, + widths.iter().map(|w| Constraint::Length(*w as u16)), + ) + .column_spacing(column_spacing as u16); + + if let Some(header_row) = header { + let header_cells = header_row.into_iter().map(Cell::new).collect::>(); + table = table.header(Row::new(header_cells)); + } + + terminal + .draw(|frame| { + let area = frame.area(); + frame.render_widget(table, area); + }) + .expect("failed to render table"); + + let buffer = terminal.backend().buffer(); + buffer_to_ansi_lines(buffer).join("\n") +} + +fn header_line(text: &str) -> Line<'static> { + Line::from(Span::styled( + text.to_string(), + Style::default() + .fg(Color::DarkGray) + .add_modifier(Modifier::BOLD), + )) +} + +fn truncate_plain(text: &str, max_len: usize) -> String { + if text.chars().count() <= max_len { + return text.to_string(); + } + if max_len <= 3 { + return text.chars().take(max_len).collect(); + } + let truncated: String = text.chars().take(max_len - 3).collect(); + format!("{truncated}...") +} + +fn box_with_title(title: &str, content: &str) -> String { + let lines: Vec<&str> = content.lines().collect(); + let content_width = lines + .iter() + .map(|line| visible_width(line)) + .max() + .unwrap_or(0); + let padding = 1; + let inner_width = content_width + padding * 2; + + let title_plain = format!(" {title} "); + let title_width = visible_width(&title_plain); + let mut top = String::from("╭"); + top.push_str(&title_plain.dark_grey().to_string()); + if inner_width > title_width { + top.push_str(&"─".repeat(inner_width - title_width)); + } + top.push('╮'); + + let mut boxed = vec![top]; + for line in lines { + let line_width = visible_width(line); + let right_padding = inner_width.saturating_sub(line_width + padding); + let mut row = String::from("│"); + row.push_str(&" ".repeat(padding)); + row.push_str(line); + row.push_str(&" ".repeat(right_padding)); + row.push('│'); + boxed.push(row); + } + + let bottom = format!("╰{}╯", "─".repeat(inner_width)); + boxed.push(bottom); + + format!("\n{}", boxed.join("\n")) +} + +fn visible_width(text: &str) -> usize { + let stripped = strip(text.as_bytes()); + let stripped = String::from_utf8_lossy(&stripped); + UnicodeWidthStr::width(stripped.as_ref()) +} + +fn buffer_to_ansi_lines(buffer: &ratatui::buffer::Buffer) -> Vec { + let width = buffer.area.width as usize; + let height = buffer.area.height as usize; + let mut lines = Vec::with_capacity(height); + let mut current_style = Style::reset(); + + for y in 0..height { + let mut line = String::new(); + let mut skip = 0usize; + for x in 0..width { + let cell = &buffer[(x as u16, y as u16)]; + let symbol = cell.symbol(); + let symbol_width = UnicodeWidthStr::width(symbol); + if skip > 0 { + skip -= 1; + continue; + } + + let style = Style { + fg: Some(cell.fg), + bg: Some(cell.bg), + add_modifier: cell.modifier, + ..Style::default() + }; + + if style != current_style { + line.push_str(&style_to_ansi(style)); + current_style = style; + } + + line.push_str(symbol); + skip = symbol_width.saturating_sub(1); + } + line.push_str(&style_to_ansi(Style::reset())); + lines.push(line.trim_end().to_string()); + } + + lines +} + +fn style_to_ansi(style: Style) -> String { + let mut buf = Vec::new(); + let _ = queue!(buf, SetAttribute(Attribute::Reset), ResetColor); + + if let Some(fg) = style.fg { + let _ = queue!(buf, SetForegroundColor(convert_color(fg))); + } + if let Some(bg) = style.bg { + let _ = queue!(buf, SetBackgroundColor(convert_color(bg))); + } + + let mods = style.add_modifier; + if mods.contains(Modifier::BOLD) { + let _ = queue!(buf, SetAttribute(Attribute::Bold)); + } + if mods.contains(Modifier::DIM) { + let _ = queue!(buf, SetAttribute(Attribute::Dim)); + } + if mods.contains(Modifier::ITALIC) { + let _ = queue!(buf, SetAttribute(Attribute::Italic)); + } + if mods.contains(Modifier::UNDERLINED) { + let _ = queue!(buf, SetAttribute(Attribute::Underlined)); + } + if mods.contains(Modifier::REVERSED) { + let _ = queue!(buf, SetAttribute(Attribute::Reverse)); + } + if mods.contains(Modifier::CROSSED_OUT) { + let _ = queue!(buf, SetAttribute(Attribute::CrossedOut)); + } + if mods.contains(Modifier::SLOW_BLINK) { + let _ = queue!(buf, SetAttribute(Attribute::SlowBlink)); + } + if mods.contains(Modifier::RAPID_BLINK) { + let _ = queue!(buf, SetAttribute(Attribute::RapidBlink)); + } + + String::from_utf8_lossy(&buf).to_string() +} + +fn convert_color(color: Color) -> CtColor { + match color { + Color::Reset => CtColor::Reset, + Color::Black => CtColor::Black, + Color::Red => CtColor::Red, + Color::Green => CtColor::Green, + Color::Yellow => CtColor::Yellow, + Color::Blue => CtColor::Blue, + Color::Magenta => CtColor::Magenta, + Color::Cyan => CtColor::Cyan, + Color::Gray => CtColor::Grey, + Color::DarkGray => CtColor::DarkGrey, + Color::LightRed => CtColor::Red, + Color::LightGreen => CtColor::Green, + Color::LightYellow => CtColor::Yellow, + Color::LightBlue => CtColor::Blue, + Color::LightMagenta => CtColor::Magenta, + Color::LightCyan => CtColor::Cyan, + Color::White => CtColor::White, + Color::Indexed(value) => CtColor::AnsiValue(value), + Color::Rgb(r, g, b) => CtColor::Rgb { r, g, b }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn box_with_title_handles_ansi_content_without_panicking() { + let content = "plain line\n\x1b[38;5;196mred text\x1b[0m"; + let boxed = box_with_title("Summary", content); + assert!(boxed.contains("Summary")); + assert!(boxed.contains("plain line")); + assert!(boxed.contains("red text")); + } + + #[test] + fn format_processing_line_handles_pluralization() { + assert_eq!(format_processing_line(1), "Processing 1 evaluator..."); + assert_eq!(format_processing_line(2), "Processing 2 evaluators..."); + } + + #[test] + fn format_start_line_handles_partial_payload() { + let start = ExperimentStart { + experiment_name: Some("my-exp".to_string()), + experiment_url: Some("https://example.dev/exp".to_string()), + ..Default::default() + }; + let line = format_start_line(&start).expect("line should be rendered"); + assert!(line.contains("my-exp")); + assert!(line.contains("https://example.dev/exp")); + + assert!(format_start_line(&ExperimentStart::default()).is_none()); + } + + #[test] + fn fit_name_to_spaces_preserves_suffix_when_truncating() { + let rendered = + fit_name_to_spaces("Topics [experimentName=facets-real-world-30b-f5a78312]", 40); + assert_eq!(rendered.chars().count(), 40); + assert!(rendered.contains("...")); + assert!(rendered.contains("f5a78312]")); + } + + #[test] + fn fit_name_to_spaces_pads_short_names() { + let rendered = fit_name_to_spaces("short", 10); + assert_eq!(rendered, "short "); + } +}