From 8eb180f0042e052c0a831467bce81d95eb058de2 Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Thu, 19 Mar 2026 18:56:47 +0000 Subject: [PATCH 1/3] add --sandbox flag --- scripts/eval-runner.py | 201 ++++++++++- scripts/eval-runner.ts | 252 ++++++++++++- src/eval.rs | 795 ++++++++++++++++++++++++++++++++++++++++- src/experiments/api.rs | 2 + src/functions/mod.rs | 3 +- src/functions/push.rs | 155 +++++++- src/sync.rs | 2 +- tests/functions.rs | 5 +- 8 files changed, 1383 insertions(+), 32 deletions(-) diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index e9d3ce0..d8ac067 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -9,9 +9,10 @@ import re import socket import sys +import time import traceback from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, AsyncIterator, Callable try: from braintrust import init_dataset, invoke, login @@ -79,6 +80,41 @@ def close(self) -> None: self.sock.close() +@dataclass +class PullChannel: + sock: socket.socket + + def send(self, payload: Any) -> None: + self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + async def lines(self) -> AsyncIterator[str]: + buffer = "" + while True: + chunk = await asyncio.to_thread(self.sock.recv, 4096) + if not chunk: + break + buffer += chunk.decode("utf-8") + while True: + newline = buffer.find("\n") + if newline == -1: + break + line = buffer[:newline].strip() + buffer = buffer[newline + 1 :] + if line: + yield line + + trailing = buffer.strip() + if trailing: + yield trailing + + def close(self) -> None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self.sock.close() + + def serialize_sse_event(event: str, data: Any) -> str: if isinstance(data, (dict, list)): data_str = json.dumps(data) @@ -105,6 +141,16 @@ def create_sse_writer() -> SseWriter | None: return None +def create_pull_channel() -> PullChannel | None: + sock_path = os.getenv("BT_EVAL_PULL_SOCK") + if not sock_path: + return None + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return PullChannel(sock) + + def env_flag(name: str) -> bool: value = os.getenv(name) if value is None: @@ -137,7 +183,7 @@ def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: def parse_dev_mode(value: str | None) -> str | None: if value is None or value == "": return None - if value in {"list", "eval"}: + if value in {"list", "eval", "rows"}: return value raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}") @@ -302,6 +348,26 @@ def parse_eval_request(raw: str | None) -> dict[str, Any]: return parsed +def parse_eval_pull_request(raw: str | None) -> dict[str, Any]: + if not raw: + raise ValueError("Missing BT_EVAL_DEV_REQUEST_JSON") + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid BT_EVAL_DEV_REQUEST_JSON: {exc}") from exc + + if not isinstance(parsed, dict): + raise ValueError("BT_EVAL_DEV_REQUEST_JSON must be a JSON object.") + if not isinstance(parsed.get("name"), str) or not parsed["name"]: + raise ValueError("Pull request must include a non-empty evaluator name.") + + parameters = parsed.get("parameters") + if parameters is not None and not isinstance(parameters, dict): + raise ValueError("Pull request parameters must be an object.") + + return parsed + + def resolve_eval_data(data: dict[str, Any]) -> Any: if "data" in data: return data["data"] @@ -324,6 +390,33 @@ def resolve_eval_data(data: dict[str, Any]) -> Any: raise ValueError("Invalid eval data payload.") +async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: + data_result = data + if inspect.isclass(data_result): + data_result = data_result() + if inspect.isfunction(data_result) or inspect.isroutine(data_result): + data_result = data_result() + if inspect.isawaitable(data_result): + data_result = await data_result + + base_experiment_name = None + if isinstance(data_result, BaseExperiment): + base_experiment_name = data_result.name + + return data_result, base_experiment_name + + +def to_async_iterator(value: Any) -> AsyncIterator[Any]: + if inspect.isasyncgen(value): + return value + + async def to_async(it): + for item in it: + yield item + + return to_async(value) + + def make_eval_scorer( score: dict[str, Any], project_id: str | None, @@ -851,6 +944,108 @@ async def run_requested_eval( return True +async def run_dataset_pull( + evaluator_instances: list[EvaluatorInstance], + config: RunnerConfig, +) -> bool: + channel = create_pull_channel() + if channel is None: + raise ValueError("Missing BT_EVAL_PULL_SOCK") + + try: + request = parse_eval_pull_request(config.dev_request_json) + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + channel.close() + return False + + target_name = request["name"] + evaluator_instance = next( + (candidate for candidate in evaluator_instances if candidate.evaluator.eval_name == target_name), + None, + ) + if evaluator_instance is None: + channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) + channel.close() + return False + + evaluator = evaluator_instance.evaluator + try: + raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) + data_iterator = to_async_iterator(raw_data) + iterator = data_iterator.__aiter__() + trial_count = getattr(evaluator, "trial_count", 1) + try: + trial_count = int(trial_count) + except Exception: + trial_count = 1 + if trial_count < 1: + trial_count = 1 + + max_concurrency = getattr(evaluator, "max_concurrency", None) + try: + max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 + except Exception: + max_concurrency = 10 + if max_concurrency < 1: + max_concurrency = 1 + + experiment_name = getattr(evaluator, "experiment_name", None) + if not isinstance(experiment_name, str) or not experiment_name: + experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" + + channel.send( + { + "type": "ready", + "evaluator_name": evaluator.eval_name, + "max_concurrency": max_concurrency, + "experiment_name": experiment_name, + } + ) + + current_datum = None + trial_index = 0 + async for line in channel.lines(): + parsed = json.loads(line) + command_type = parsed.get("type") if isinstance(parsed, dict) else None + if command_type == "close": + break + if command_type != "next": + channel.send( + { + "type": "error", + "message": f"Unsupported pull command '{command_type}'", + } + ) + break + + if current_datum is None: + try: + current_datum = await iterator.__anext__() + trial_index = 0 + except StopAsyncIteration: + channel.send({"type": "eof"}) + continue + + channel.send( + { + "type": "row", + "datum": current_datum, + "trial_index": trial_index, + } + ) + trial_index += 1 + if trial_index >= trial_count: + current_datum = None + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + channel.close() + return False + + channel.close() + return True + + async def run_once( files: list[str], no_send_logs: bool, @@ -872,6 +1067,8 @@ async def run_once( return True if config.dev_mode == "eval": return await run_requested_eval(evaluators, reporters, no_send_logs, sse, config) + if config.dev_mode == "rows": + return await run_dataset_pull(evaluators, config) if config.list_only: for evaluator_instance in evaluators: diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index 2a19c10..fdd7152 100644 --- a/scripts/eval-runner.ts +++ b/scripts/eval-runner.ts @@ -2,11 +2,17 @@ import { createRequire } from "node:module"; import path from "node:path"; import { fileURLToPath, pathToFileURL } from "node:url"; +type EvaluatorDefinition = { + evalName: string; + projectName: string; + data?: unknown; + trialCount?: unknown; + maxConcurrency?: unknown; + experimentName?: unknown; +} & Record; + type EvaluatorEntry = { - evaluator: { - evalName: string; - projectName: string; - } & Record; + evaluator: EvaluatorDefinition; reporter?: unknown; }; @@ -111,12 +117,17 @@ type EvalRequest = { scores?: EvalScoreSpec[]; }; +type EvalPullRequest = { + name: string; + parameters?: Record; +}; + type RunnerConfig = { jsonl: boolean; list: boolean; terminateOnFailure: boolean; filters: EvalFilter[]; - devMode: "list" | "eval" | null; + devMode: "list" | "eval" | "rows" | null; devRequestJson: string | null; }; @@ -243,11 +254,13 @@ function parseSerializedFilters(serialized: string | undefined): EvalFilter[] { } } -function parseDevMode(value: string | undefined): "list" | "eval" | null { +function parseDevMode( + value: string | undefined, +): "list" | "eval" | "rows" | null { if (!value) { return null; } - if (value === "list" || value === "eval") { + if (value === "list" || value === "eval" || value === "rows") { return value; } throw new Error(`Invalid BT_EVAL_DEV_MODE value: ${value}`); @@ -283,6 +296,7 @@ type NetModule = { setNoDelay: (value?: boolean) => void; on: (event: string, listener: (...args: unknown[]) => void) => void; write: (data: string) => void; + [Symbol.asyncIterator]?: () => AsyncIterator; }; }; @@ -766,6 +780,69 @@ function createSseWriter(): SseWriter | null { return { send, close }; } +type PullChannel = { + send: (payload: unknown) => void; + close: () => void; + lines: () => AsyncGenerator; +}; + +function createPullChannel(): PullChannel | null { + const netModule = (() => { + try { + return runtimeRequire("node:net") as NetModule; + } catch { + return null; + } + })(); + const sock = process.env.BT_EVAL_PULL_SOCK; + if (!sock) { + return null; + } + if (!netModule) { + return null; + } + + const socket = netModule.createConnection({ path: sock }); + socket.setNoDelay(true); + + const send = (payload: unknown) => { + if (!socket.writable) { + return; + } + socket.write(`${JSON.stringify(payload)}\n`); + }; + + const close = () => { + socket.end(); + }; + + const lines = async function* () { + let buffer = ""; + for await (const chunk of socket as unknown as AsyncIterable< + Buffer | string + >) { + buffer += typeof chunk === "string" ? chunk : chunk.toString("utf8"); + while (true) { + const newline = buffer.indexOf("\n"); + if (newline === -1) { + break; + } + const line = buffer.slice(0, newline).trim(); + buffer = buffer.slice(newline + 1); + if (line.length > 0) { + yield line; + } + } + } + const trailing = buffer.trim(); + if (trailing.length > 0) { + yield trailing; + } + }; + + return { send, close, lines }; +} + function initRegistry() { globalThis._evals = { functions: [], @@ -1405,6 +1482,24 @@ async function buildEvaluatorDefinitions(evaluators: EvaluatorEntry[]) { return result; } +function parseEvalPullRequest(raw: string | null): EvalPullRequest { + if (!raw) { + throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); + } + const parsed = JSON.parse(raw); + if (!isObject(parsed) || typeof parsed.name !== "string" || parsed.name.length === 0) { + throw new Error("Pull request must include a non-empty evaluator name."); + } + const request = parsed as EvalPullRequest; + if ( + request.parameters !== undefined && + (!isObject(request.parameters) || Array.isArray(request.parameters)) + ) { + throw new Error("Pull request parameters must be an object."); + } + return request; +} + function parseEvalRequest(raw: string | null): EvalRequest { if (!raw) { throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); @@ -1487,6 +1582,48 @@ function resolveEvalData( throw new Error("Invalid eval data payload."); } +function callEvaluatorData( + data: unknown, +): { data: unknown; baseExperiment: string | undefined } { + const dataResult = typeof data === "function" ? (data as () => unknown)() : data; + let baseExperiment: string | undefined = undefined; + if ( + isObject(dataResult) && + Reflect.get(dataResult, "_type") === "BaseExperiment" && + typeof Reflect.get(dataResult, "name") === "string" + ) { + baseExperiment = Reflect.get(dataResult, "name") as string; + } + return { data: dataResult, baseExperiment }; +} + +function toAsyncIterable(value: unknown): AsyncIterable { + if ( + typeof value === "object" && + value !== null && + Symbol.asyncIterator in value && + typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" + ) { + return value as AsyncIterable; + } + if ( + typeof value === "object" && + value !== null && + Symbol.iterator in value && + typeof (value as Iterable)[Symbol.iterator] === "function" + ) { + const iterable = value as Iterable; + return (async function* () { + for (const item of iterable) { + yield item; + } + })(); + } + throw new Error( + "Evaluator data must be an array, iterable, or async iterable", + ); +} + function convertFunctionId( functionId: Record, ): Record { @@ -1661,6 +1798,102 @@ async function runRequestedEval(config: RunnerConfig, runner: EvalRunner) { } } +async function runDatasetPull(config: RunnerConfig, runner: EvalRunner) { + const channel = createPullChannel(); + if (!channel) { + throw new Error("Missing BT_EVAL_PULL_SOCK"); + } + + try { + const request = parseEvalPullRequest(config.devRequestJson); + const entry = getEvaluators().find( + (candidate) => candidate.evaluator.evalName === request.name, + ); + if (!entry) { + channel.send({ + type: "error", + message: `Evaluator '${request.name}' not found`, + }); + return; + } + + const state = runner.getState ? runner.getState() : undefined; + const evaluator = { + ...entry.evaluator, + ...(state !== undefined && state !== null ? { state } : {}), + }; + const { data: rawData } = callEvaluatorData(evaluator.data); + const dataIterable = toAsyncIterable(rawData); + const iterator = dataIterable[Symbol.asyncIterator](); + const trialCountRaw = Number(evaluator.trialCount ?? 1); + const trialCount = + Number.isFinite(trialCountRaw) && trialCountRaw > 0 + ? Math.floor(trialCountRaw) + : 1; + const maxConcurrencyRaw = Number(evaluator.maxConcurrency ?? 10); + const maxConcurrency = + Number.isFinite(maxConcurrencyRaw) && maxConcurrencyRaw > 0 + ? Math.floor(maxConcurrencyRaw) + : 10; + const experimentName = + typeof evaluator.experimentName === "string" && + evaluator.experimentName.length > 0 + ? evaluator.experimentName + : `${entry.evaluator.evalName}-${Date.now()}`; + + channel.send({ + type: "ready", + evaluator_name: entry.evaluator.evalName, + max_concurrency: maxConcurrency, + experiment_name: experimentName, + }); + + let currentDatum: unknown | undefined = undefined; + let trialIndex = 0; + for await (const line of channel.lines()) { + const parsed = JSON.parse(line) as { type?: string }; + if (parsed.type === "close") { + break; + } + if (parsed.type !== "next") { + channel.send({ + type: "error", + message: `Unsupported pull command '${String(parsed.type)}'`, + }); + break; + } + + if (currentDatum === undefined) { + const next = await iterator.next(); + if (next.done) { + channel.send({ type: "eof" }); + continue; + } + currentDatum = next.value; + trialIndex = 0; + } + + channel.send({ + type: "row", + datum: currentDatum, + trial_index: trialIndex, + }); + + trialIndex += 1; + if (trialIndex >= trialCount) { + currentDatum = undefined; + } + } + } catch (err) { + channel.send({ + type: "error", + message: err instanceof Error ? err.message : String(err), + }); + } finally { + channel.close(); + } +} + function extractBtEvalMain(mod: unknown): BtEvalMain | null { if (!mod || typeof mod !== "object") { return null; @@ -2061,6 +2294,11 @@ async function main() { return; } + if (config.devMode === "rows") { + await runDatasetPull(config, runner); + return; + } + if (config.list) { for (const entry of filteredEvaluators) { console.log(entry.evaluator.evalName); diff --git a/src/eval.rs b/src/eval.rs index bb67433..ac792e4 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -15,6 +15,7 @@ use actix_web::http::header::{ }; use actix_web::{guard, web, App, HttpRequest, HttpResponse, HttpServer}; use anyhow::{Context, Result}; +use chrono::{SecondsFormat, Utc}; use clap::{Args, ValueEnum}; use crossterm::queue; use crossterm::style::{ @@ -22,11 +23,13 @@ use crossterm::style::{ Stylize, }; use futures_util::stream; +use futures_util::StreamExt; use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strip_ansi_escapes::strip; +use tokio::io::AsyncWriteExt; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::net::UnixListener; use tokio::process::Command; @@ -41,7 +44,12 @@ use ratatui::widgets::{Cell, Row, Table}; use ratatui::Terminal; use crate::args::BaseArgs; +use crate::auth::login; use crate::auth::resolved_auth_env; +use crate::experiments::api::create_experiment; +use crate::functions::publish_eval_sandbox_functions; +use crate::http::ApiClient; +use crate::source_language::SourceLanguage; use crate::ui::{animations_enabled, is_quiet}; const MAX_NAME_LENGTH: usize = 40; @@ -161,6 +169,61 @@ struct ResolvedDatasetEvalData { _internal_btql: Option, } +#[derive(Debug, Serialize, Deserialize)] +struct EvalPullRequest { + name: String, + #[serde(default)] + parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum EvalPullClientMessage { + Next, + Close, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum EvalPullResponse { + Ready { + evaluator_name: String, + max_concurrency: usize, + experiment_name: String, + }, + Row { + datum: Value, + trial_index: usize, + }, + Eof, + Error { + message: String, + }, +} + +#[derive(Debug)] +struct EvalDataPuller { + child: tokio::process::Child, + writer: tokio::net::unix::OwnedWriteHalf, + reader: BufReader, + _socket_cleanup_guard: SocketCleanupGuard, +} + +#[derive(Debug, Clone)] +struct EvalSandboxPlan { + evaluator_name: String, + function_id: String, + project_id: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct SandboxSummaryRow { + #[serde(default)] + scores: HashMap>, + #[serde(default)] + metrics: HashMap, +} + #[derive(Clone)] struct DevServerState { base: BaseArgs, @@ -194,6 +257,7 @@ const PY_RUNNER_FILE: &str = "eval-runner.py"; const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); +#[derive(Debug)] struct SocketCleanupGuard { path: PathBuf, } @@ -218,6 +282,12 @@ pub enum EvalLanguage { Python, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum EvalSandbox { + Local, + Lambda, +} + #[derive(Debug, Clone, Args)] #[command(after_help = "\ Examples: @@ -245,6 +315,10 @@ pub struct EvalArgs { )] pub language: Option, + /// Execute evals locally or in a remote sandbox. + #[arg(long, env = "BT_EVAL_SANDBOX", value_enum, default_value = "local")] + pub sandbox: EvalSandbox, + /// Run evals locally (do not send logs to Braintrust). #[arg( long, @@ -389,6 +463,31 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { extra_args: args.extra_args, }; + if args.sandbox != EvalSandbox::Local { + if args.dev { + anyhow::bail!("--sandbox is not supported with --dev."); + } + if args.watch { + anyhow::bail!("--sandbox is not supported with --watch."); + } + if args.list { + anyhow::bail!("--sandbox is not supported with --list."); + } + if files.len() != 1 { + anyhow::bail!("`bt eval --sandbox lambda` currently supports exactly one eval file."); + } + return run_eval_files_sandbox( + &base, + args.sandbox, + args.language, + args.runner.as_deref(), + &files, + args.no_send_logs, + &options, + ) + .await; + } + if args.dev { let language = detect_eval_language(&files, args.language)?; let app_url = resolve_app_url(&base); @@ -500,6 +599,667 @@ async fn run_eval_files_watch( } } +async fn run_eval_files_sandbox( + base: &BaseArgs, + sandbox: EvalSandbox, + language_override: Option, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, +) -> Result<()> { + if sandbox != EvalSandbox::Lambda { + anyhow::bail!("unsupported sandbox mode"); + } + if no_send_logs { + anyhow::bail!("--sandbox lambda is not supported with --no-send-logs."); + } + let language = detect_eval_language(files, language_override)?; + let source_language = match language { + EvalLanguage::JavaScript => SourceLanguage::JsLike, + EvalLanguage::Python => SourceLanguage::Python, + }; + + let source_file = PathBuf::from( + files + .first() + .ok_or_else(|| anyhow::anyhow!("missing sandbox source file"))?, + ); + let published = + publish_eval_sandbox_functions(base, &source_file, runner_override, source_language) + .await?; + let evaluator_names = list_sandbox_evaluator_names( + base, + language, + runner_override, + files, + no_send_logs, + options, + ) + .await?; + if evaluator_names.is_empty() { + anyhow::bail!("No evaluators found. Did you call Eval() in the file?"); + } + + let mut plans = Vec::new(); + for evaluator_name in evaluator_names { + let slug = sandbox_slug_from_source(&source_file, &evaluator_name); + let published_entry = published + .iter() + .find(|entry| entry.slug == slug) + .ok_or_else(|| { + anyhow::anyhow!( + "sandbox function '{}' for evaluator '{}' was not published", + slug, + evaluator_name + ) + })?; + plans.push(EvalSandboxPlan { + evaluator_name, + function_id: published_entry.function_id.clone(), + project_id: published_entry.project_id.clone(), + }); + } + + let login_ctx = login(base).await?; + let client = ApiClient::new(&login_ctx)?; + let started_at = Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true); + + for plan in plans { + let mut puller = spawn_eval_data_puller( + base, + language, + runner_override, + files, + no_send_logs, + options, + &EvalPullRequest { + name: plan.evaluator_name.clone(), + parameters: Some(json!({})), + }, + ) + .await?; + let ready = puller.read_message().await?; + let (max_concurrency, experiment_name) = match ready { + EvalPullResponse::Ready { + evaluator_name, + max_concurrency, + experiment_name, + } => { + if evaluator_name != plan.evaluator_name { + anyhow::bail!( + "sandbox runner selected unexpected evaluator '{}', expected '{}'", + evaluator_name, + plan.evaluator_name + ); + } + (max_concurrency.max(1), experiment_name) + } + EvalPullResponse::Error { message } => anyhow::bail!("{message}"), + other => anyhow::bail!("unexpected initial sandbox pull response: {other:?}"), + }; + + let experiment = create_experiment(&client, &plan.project_id, &experiment_name, true) + .await + .with_context(|| { + format!( + "failed to create sandbox parent experiment '{}' for evaluator '{}'", + experiment_name, plan.evaluator_name + ) + })?; + let mut in_flight: tokio::task::JoinSet>> = + tokio::task::JoinSet::new(); + let mut saw_eof = false; + let mut experiment_url: Option = None; + + while !saw_eof || !in_flight.is_empty() { + while !saw_eof && in_flight.len() < max_concurrency { + puller.send_message(&EvalPullClientMessage::Next).await?; + match puller.read_message().await? { + EvalPullResponse::Row { + datum, + trial_index: _trial_index, + } => { + let function_id = plan.function_id.clone(); + let evaluator_name = plan.evaluator_name.clone(); + let project_id = plan.project_id.clone(); + let body = json!({ + "api_version": 1, + "function_id": { "function_id": function_id }, + "name": evaluator_name, + "project_id": project_id, + "scores": [], + "stream": true, + "experiment_name": experiment.name, + "parent": { + "object_type": "experiment", + "object_id": experiment.id, + }, + "data": { "data": [datum] }, + }); + let client_cloned = client.clone(); + let org_name = login_ctx.login.org_name.clone(); + let project_id = plan.project_id.clone(); + in_flight.spawn(async move { + invoke_sandbox_eval(&client_cloned, &org_name, &project_id, body).await + }); + } + EvalPullResponse::Eof => saw_eof = true, + EvalPullResponse::Error { message } => anyhow::bail!("{message}"), + other => anyhow::bail!("unexpected sandbox pull response: {other:?}"), + } + } + + if let Some(joined) = in_flight.join_next().await { + if let Some(start) = joined?? { + if experiment_url.is_none() { + experiment_url = start.experiment_url.clone(); + } + } + } + } + + puller.send_message(&EvalPullClientMessage::Close).await?; + puller.wait().await?; + + let summary = summarize_sandbox_experiment( + &client, + &plan.project_id, + &plan.project_id, + &experiment.name, + &experiment.id, + experiment_url, + &started_at, + ) + .await?; + let rendered = format_experiment_summary(&summary); + println!("{rendered}"); + } + + Ok(()) +} + +impl EvalDataPuller { + async fn send_message(&mut self, message: &EvalPullClientMessage) -> Result<()> { + let mut payload = + serde_json::to_string(message).context("failed to serialize pull request")?; + payload.push('\n'); + self.writer + .write_all(payload.as_bytes()) + .await + .context("failed to write pull request")?; + self.writer + .flush() + .await + .context("failed to flush pull request")?; + Ok(()) + } + + async fn read_message(&mut self) -> Result { + let mut line = String::new(); + let read = self + .reader + .read_line(&mut line) + .await + .context("failed to read sandbox pull response")?; + if read == 0 { + let status = self + .child + .wait() + .await + .context("sandbox pull runner exited unexpectedly")?; + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + serde_json::from_str(line.trim()).context("failed to parse sandbox pull response JSON") + } + + async fn wait(mut self) -> Result<()> { + let status = self + .child + .wait() + .await + .context("sandbox pull runner failed")?; + if !status.success() { + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + Ok(()) + } +} + +fn sandbox_slugify(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + let mut previous_dash = false; + for ch in input.chars() { + let lower = ch.to_ascii_lowercase(); + if lower.is_ascii_alphanumeric() { + out.push(lower); + previous_dash = false; + } else if !previous_dash { + out.push('-'); + previous_dash = true; + } + } + out.trim_matches('-').to_string() +} + +fn sandbox_slug_from_source(source_file: &Path, eval_name: &str) -> String { + let stem = source_file + .file_stem() + .and_then(|value| value.to_str()) + .map(|value| value.strip_suffix(".eval").unwrap_or(value)) + .unwrap_or("eval"); + sandbox_slugify(&format!("{stem}-{eval_name}-sandbox")) +} + +async fn list_sandbox_evaluator_names( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, +) -> Result> { + let output = run_eval_runner_command_to_completion( + base, + language, + runner_override, + files, + no_send_logs, + options, + &[("BT_EVAL_DEV_MODE".to_string(), "list".to_string())], + JsMode::Auto, + ) + .await?; + + let parsed: Value = + serde_json::from_slice(&output.stdout).context("failed to parse sandbox evaluator list")?; + let object = parsed + .as_object() + .ok_or_else(|| anyhow::anyhow!("sandbox evaluator list was not a JSON object"))?; + Ok(object.keys().cloned().collect()) +} + +async fn spawn_eval_data_puller( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + request: &EvalPullRequest, +) -> Result { + let (listener, socket_path, socket_cleanup_guard) = + bind_unix_listener("bt-eval-pull").context("failed to bind sandbox pull socket")?; + let request_json = + serde_json::to_string(request).context("failed to serialize sandbox pull request")?; + let extra_env = vec![ + ("BT_EVAL_DEV_MODE".to_string(), "rows".to_string()), + ("BT_EVAL_DEV_REQUEST_JSON".to_string(), request_json), + ( + "BT_EVAL_PULL_SOCK".to_string(), + socket_path.to_string_lossy().to_string(), + ), + ]; + let child = spawn_eval_support_process( + base, + language, + runner_override, + files, + no_send_logs, + options, + &extra_env, + JsMode::Auto, + ) + .await?; + + let (stream, _) = tokio::time::timeout(Duration::from_secs(30), listener.accept()) + .await + .context("timed out waiting for sandbox pull runner to connect")? + .context("sandbox pull runner failed to connect")?; + let (read_half, write_half) = stream.into_split(); + Ok(EvalDataPuller { + child, + writer: write_half, + reader: BufReader::new(read_half), + _socket_cleanup_guard: socket_cleanup_guard, + }) +} + +async fn spawn_eval_support_process( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + extra_env: &[(String, String)], + js_mode: JsMode, +) -> Result { + let (js_runner, py_runner) = prepare_eval_runners()?; + let force_esm = matches!(js_mode, JsMode::ForceEsm); + let (mut cmd, runner_kind) = match language { + EvalLanguage::Python => ( + build_python_command(runner_override, &py_runner, files)?, + RunnerKind::Other, + ), + EvalLanguage::JavaScript => { + if force_esm { + ( + build_vite_node_fallback_command(&js_runner, files)?, + RunnerKind::ViteNode, + ) + } else { + let plan = build_js_plan(runner_override, &js_runner, files)?; + (plan.cmd, plan.kind) + } + } + }; + if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { + set_node_heap_size_env(&mut cmd); + } + cmd.envs(build_env(base).await?); + for (key, value) in extra_env { + cmd.env(key, value); + } + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if language == EvalLanguage::JavaScript && force_esm { + cmd.env("BT_EVAL_FORCE_ESM", "1"); + } + if !options.extra_args.is_empty() { + let serialized = + serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + cmd.stdout(Stdio::inherit()); + cmd.stderr(Stdio::inherit()); + cmd.spawn().context("failed to start eval support runner") +} + +async fn run_eval_runner_command_to_completion( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + extra_env: &[(String, String)], + js_mode: JsMode, +) -> Result { + let (js_runner, py_runner) = prepare_eval_runners()?; + let force_esm = matches!(js_mode, JsMode::ForceEsm); + let (mut cmd, runner_kind) = match language { + EvalLanguage::Python => ( + build_python_command(runner_override, &py_runner, files)?, + RunnerKind::Other, + ), + EvalLanguage::JavaScript => { + if force_esm { + ( + build_vite_node_fallback_command(&js_runner, files)?, + RunnerKind::ViteNode, + ) + } else { + let plan = build_js_plan(runner_override, &js_runner, files)?; + (plan.cmd, plan.kind) + } + } + }; + if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { + set_node_heap_size_env(&mut cmd); + } + cmd.envs(build_env(base).await?); + for (key, value) in extra_env { + cmd.env(key, value); + } + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = + serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let output = cmd + .output() + .await + .context("failed to run eval support runner")?; + if !output.status.success() { + anyhow::bail!( + "eval support runner exited with status {}: {}", + output.status, + String::from_utf8_lossy(&output.stderr).trim() + ); + } + Ok(output) +} + +async fn invoke_sandbox_eval( + client: &ApiClient, + org_name: &str, + project_id: &str, + body: Value, +) -> Result> { + let response = client + .post_with_headers_raw( + "/function/sandbox", + &body, + &[("x-bt-org-name", org_name), ("x-bt-project-id", project_id)], + ) + .await?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("sandbox invoke failed ({status}): {body}"); + } + + let mut bytes = response.bytes_stream(); + let mut buffer = String::new(); + let mut current_event: Option = None; + let mut data_lines: Vec = Vec::new(); + let mut start: Option = None; + let mut saw_done = false; + + while let Some(chunk) = bytes.next().await { + let chunk = chunk.context("failed to read sandbox SSE response")?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buffer.find('\n') { + let mut line: String = buffer.drain(..=pos).collect(); + if line.ends_with('\n') { + line.pop(); + } + if line.ends_with('\r') { + line.pop(); + } + if line.is_empty() { + if current_event.is_some() || !data_lines.is_empty() { + let event_name = current_event.take().unwrap_or_default(); + let data = data_lines.join("\n"); + data_lines.clear(); + match event_name.as_str() { + "start" => { + if let Ok(parsed) = serde_json::from_str::(&data) { + if start.is_none() { + start = Some(parsed); + } + } + } + "error" => { + if let Ok(payload) = serde_json::from_str::(&data) { + let message = payload + .get("message") + .or_else(|| payload.get("error")) + .and_then(Value::as_str) + .unwrap_or("sandbox eval failed"); + anyhow::bail!("{message}"); + } + anyhow::bail!("{data}"); + } + "done" => { + saw_done = true; + } + _ => {} + } + } + continue; + } + if let Some(value) = line.strip_prefix("event:") { + current_event = Some(value.trim().to_string()); + } else if let Some(value) = line.strip_prefix("data:") { + data_lines.push(value.trim_start().to_string()); + } + } + } + + if !saw_done { + anyhow::bail!("sandbox SSE stream ended before a done event"); + } + Ok(start) +} + +async fn summarize_sandbox_experiment( + client: &ApiClient, + project_name: &str, + _project_id: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + started_at: &str, +) -> Result { + let query = build_sandbox_summary_query(experiment_id, started_at); + let response = client.btql::(&query).await?; + Ok(aggregate_sandbox_summary( + project_name, + experiment_name, + experiment_id, + experiment_url, + &response.data, + )) +} + +fn build_sandbox_summary_query(experiment_id: &str, started_at: &str) -> String { + format!( + "select: scores, metrics | from: experiment('{}') summary | filter: created >= '{}' | limit: 1000", + experiment_id.replace('\'', "''"), + started_at.replace('\'', "''") + ) +} + +fn aggregate_sandbox_summary( + project_name: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + rows: &[SandboxSummaryRow], +) -> ExperimentSummary { + let mut scores: HashMap = HashMap::new(); + let mut metrics: HashMap = HashMap::new(); + for row in rows { + for (name, value) in &row.scores { + if let Some(value) = value { + let entry = scores.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += value; + entry.1 += 1; + } + } + for (name, value) in &row.metrics { + let Some(number) = value.as_f64() else { + continue; + }; + let entry = metrics.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += number; + entry.1 += 1; + } + } + + ExperimentSummary { + project_name: project_name.to_string(), + experiment_name: experiment_name.to_string(), + project_id: None, + experiment_id: Some(experiment_id.to_string()), + project_url: None, + experiment_url, + comparison_experiment_name: None, + scores: scores + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + ScoreSummary { + name, + score: average, + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + metrics: if metrics.is_empty() { + None + } else { + Some( + metrics + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + MetricSummary { + name, + metric: average, + unit: String::new(), + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + ) + }, + } +} + struct EvalPlan<'a> { language: EvalLanguage, files: &'a [String], @@ -2386,20 +3146,20 @@ fn find_binary_in_path(candidates: &[&str]) -> Option { None } -fn build_sse_socket_path() -> Result { +fn build_socket_path(prefix: &str) -> Result { let pid = std::process::id(); let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); let now = SystemTime::now() .duration_since(UNIX_EPOCH) .context("failed to read system time")? .as_nanos(); - Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock"))) + Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock"))) } -fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { +fn bind_unix_listener(prefix: &str) -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { let mut last_bind_err: Option = None; for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS { - let socket_path = build_sse_socket_path()?; + let socket_path = build_socket_path(prefix)?; let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); let _ = std::fs::remove_file(&socket_path); match UnixListener::bind(&socket_path) { @@ -2425,10 +3185,14 @@ fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { ) }); Err(err).context(format!( - "failed to bind SSE unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" + "failed to bind unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" )) } +fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { + bind_unix_listener("bt-eval") +} + fn eval_runner_cache_dir() -> PathBuf { let root = std::env::var_os("XDG_CACHE_HOME") .map(PathBuf::from) @@ -4016,8 +4780,8 @@ mod tests { #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { - let first = build_sse_socket_path().expect("first socket path"); - let second = build_sse_socket_path().expect("second socket path"); + let first = build_socket_path("bt-eval").expect("first socket path"); + let second = build_socket_path("bt-eval").expect("second socket path"); assert_ne!(first, second); } @@ -4193,6 +4957,7 @@ mod tests { "BT_EVAL_TERMINATE_ON_FAILURE", "BT_EVAL_NUM_WORKERS", "BT_EVAL_LIST", + "BT_EVAL_SANDBOX", "BT_EVAL_FILTER", "BT_EVAL_VERBOSE", "BT_EVAL_WATCH", @@ -4207,6 +4972,7 @@ mod tests { set_env_var("BT_EVAL_TERMINATE_ON_FAILURE", "1"); set_env_var("BT_EVAL_NUM_WORKERS", "4"); set_env_var("BT_EVAL_LIST", "yes"); + set_env_var("BT_EVAL_SANDBOX", "lambda"); set_env_var("BT_EVAL_FILTER", "metadata.case=smoke.*,metadata.kind=fast"); set_env_var("BT_EVAL_VERBOSE", "1"); set_env_var("BT_EVAL_WATCH", "on"); @@ -4221,6 +4987,7 @@ mod tests { assert!(parsed.eval.terminate_on_failure); assert_eq!(parsed.eval.num_workers, Some(4)); assert!(parsed.eval.list); + assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda); assert_eq!( parsed.eval.filter, vec![ @@ -4239,4 +5006,18 @@ mod tests { restore_env_var(key, value); } } + + #[test] + fn build_sandbox_summary_query_includes_timestamp_filter() { + let query = build_sandbox_summary_query("exp'123", "2026-03-19T12:00:00.000Z"); + assert!(query.contains("from: experiment('exp''123') summary")); + assert!(query.contains("filter: created >= '2026-03-19T12:00:00.000Z'")); + assert!(query.contains("select: scores, metrics")); + } + + #[test] + fn sandbox_slug_from_source_uses_source_stem_and_eval_name() { + let slug = sandbox_slug_from_source(Path::new("/tmp/My Eval.ts"), "Demo Eval"); + assert_eq!(slug, "my-eval-demo-eval-sandbox"); + } } diff --git a/src/experiments/api.rs b/src/experiments/api.rs index 3507a7e..bce8a00 100644 --- a/src/experiments/api.rs +++ b/src/experiments/api.rs @@ -70,11 +70,13 @@ pub async fn create_experiment( client: &ApiClient, project_id: &str, name: &str, + ensure_new: bool, ) -> Result { let body = serde_json::json!({ "name": name, "project_id": project_id, "org_name": client.org_name(), + "ensure_new": ensure_new, }); client.post("/v1/experiment", &body).await } diff --git a/src/functions/mod.rs b/src/functions/mod.rs index eadc100..934d0f7 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -17,11 +17,12 @@ mod delete; mod invoke; mod list; mod pull; -mod push; +pub(crate) mod push; pub(crate) mod report; mod view; use api::Function; +pub(crate) use push::publish_eval_sandbox_functions; #[derive(Debug, Clone, Copy, ValueEnum)] pub enum FunctionTypeFilter { diff --git a/src/functions/push.rs b/src/functions/push.rs index 0ca6109..89ecd87 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -640,6 +640,145 @@ struct FileSuccess { bundle_id: Option, } +#[derive(Debug, Clone)] +pub(crate) struct PublishedSandboxFunction { + pub slug: String, + pub project_id: String, + pub function_id: String, +} + +pub(crate) async fn publish_eval_sandbox_functions( + base: &BaseArgs, + source_file: &Path, + runner_override: Option<&str>, + language: SourceLanguage, +) -> Result> { + let available_orgs = list_available_orgs(base) + .await + .context("failed to list available orgs")?; + validate_explicit_org_selection(base, &available_orgs)?; + let auth_ctx = resolve_auth_context(base) + .await + .context("failed to resolve auth context")?; + + let args = PushArgs { + files: vec![source_file.to_path_buf()], + file_flag: Vec::new(), + if_exists: super::IfExistsMode::Replace, + terminate_on_failure: true, + create_missing_projects: true, + runner: runner_override.map(ToOwned::to_owned), + language: match language { + SourceLanguage::JsLike => PushLanguage::JavaScript, + SourceLanguage::Python => PushLanguage::Python, + }, + requirements: None, + tsconfig: None, + external_packages: Vec::new(), + yes: true, + }; + + let input_files = args.resolved_files(); + let classified = collect_classified_files(&input_files)?; + let files = classified.files_for_language(language); + if files.is_empty() { + bail!( + "no eligible {} files found for sandbox publish: {}", + language_label(language), + source_file.display() + ); + } + + let mut manifest = + run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) + .map_err(|failure| anyhow!(failure.message))?; + + for file in &mut manifest.files { + file.entries.retain(|entry| match entry { + ManifestEntry::Code(code) => code.function_type.as_deref() == Some("sandbox"), + ManifestEntry::FunctionEvent(_) => false, + }); + } + manifest + .files + .retain(|file| !file.entries.is_empty() || file.python_bundle.is_some()); + + if manifest.files.is_empty() { + bail!("no sandbox evaluators found in {}", source_file.display()); + } + + validate_manifest_paths( + &manifest, + &files, + language, + &classified.allowed_roots, + ) + .map_err(|failure| anyhow!(failure.message))?; + + let preflight = collect_project_preflight(base, &manifest)?; + let mut project_name_cache = + resolve_named_projects(&auth_ctx, &preflight.named_projects, true).await?; + validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await?; + let default_project_id = resolve_default_project_id(&preflight, &project_name_cache)?; + let resolved_targets = resolve_manifest_targets( + &auth_ctx, + default_project_id.as_deref(), + &manifest, + &mut project_name_cache, + true, + ) + .await?; + validate_duplicate_slugs(&resolved_targets.entries)?; + + let mut published = Vec::new(); + for (manifest_file, resolved_file) in + manifest.files.iter().zip(resolved_targets.per_file.iter()) + { + let source_path = PathBuf::from(&manifest_file.source_file); + push_file( + &auth_ctx, + default_project_id.as_deref(), + &manifest.runtime_context, + &source_path, + manifest_file, + &resolved_file.entry_project_ids, + &args, + language, + None, + &classified.allowed_roots, + &mut project_name_cache, + ) + .await + .map_err(|failure| anyhow!(failure.message))?; + + for (entry_index, entry) in manifest_file.entries.iter().enumerate() { + let ManifestEntry::Code(code) = entry else { + continue; + }; + let project_id = resolved_file + .entry_project_ids + .get(entry_index) + .cloned() + .ok_or_else(|| anyhow!("missing resolved project id for sandbox entry"))?; + let function = api::get_function_by_slug(&auth_ctx.client, &project_id, &code.slug) + .await? + .ok_or_else(|| { + anyhow!( + "sandbox function '{}' was not found after publish", + code.slug + ) + })?; + published.push(PublishedSandboxFunction { + slug: code.slug.clone(), + project_id, + function_id: function.id, + }); + } + } + + Ok(published) +} + fn default_code_location(index: usize) -> Value { json!({ "type": "function", @@ -3294,12 +3433,8 @@ mod tests { "scores": [{ "name": "accuracy" }] } }); - let value = build_code_function_data( - &runtime, - sandbox_location.clone(), - "bundle-sandbox-1", - None, - ); + let value = + build_code_function_data(&runtime, sandbox_location.clone(), "bundle-sandbox-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["type"], "bundle"); @@ -3319,12 +3454,8 @@ mod tests { "eval_name": "my-eval", "position": { "type": "task" } }); - let value = build_code_function_data( - &runtime, - experiment_location.clone(), - "bundle-task-1", - None, - ); + let value = + build_code_function_data(&runtime, experiment_location.clone(), "bundle-task-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["location"], experiment_location); diff --git a/src/sync.rs b/src/sync.rs index e62a856..76e60ce 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -3116,7 +3116,7 @@ async fn resolve_push_experiment_target( ); } - let created = create_experiment(client, &project.id, experiment_selector) + let created = create_experiment(client, &project.id, experiment_selector, false) .await .with_context(|| { format!("experiment '{experiment_selector}' not found, and creating it failed") diff --git a/tests/functions.rs b/tests/functions.rs index fa3de7c..419a351 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -2067,7 +2067,9 @@ exit 24 "expected 1 inserted function (sandbox only)" ); - let sandbox_obj = inserted[0].as_object().expect("sandbox should be an object"); + let sandbox_obj = inserted[0] + .as_object() + .expect("sandbox should be an object"); assert_eq!( sandbox_obj.get("slug").and_then(Value::as_str), Some("my-eval-my-eval-sandbox") @@ -2118,7 +2120,6 @@ exit 24 .and_then(Value::as_str), Some("my-eval") ); - } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] From 1df96cc0e0e68a4152b9e18d10a17a6310c022c4 Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Thu, 19 Mar 2026 21:00:48 +0000 Subject: [PATCH 2/3] Add tests --- src/eval.rs | 145 ++++++++++++++++++ src/functions/push.rs | 14 +- tests/eval_fixtures.rs | 329 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 478 insertions(+), 10 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index ac792e4..04037f9 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -4160,6 +4160,46 @@ mod tests { eval: EvalArgs, } + fn base_args() -> BaseArgs { + BaseArgs { + json: false, + quiet: false, + no_color: false, + profile: None, + org_name: None, + project: None, + api_key: None, + prefer_profile: false, + no_input: false, + api_url: None, + app_url: None, + env_file: None, + } + } + + fn make_eval_args(files: Vec) -> EvalArgs { + EvalArgs { + files, + runner: None, + language: None, + sandbox: EvalSandbox::Local, + no_send_logs: false, + jsonl: false, + terminate_on_failure: false, + num_workers: None, + list: false, + filter: Vec::new(), + verbose: false, + watch: false, + extra_args: Vec::new(), + dev: false, + dev_host: "localhost".to_string(), + dev_port: 8300, + dev_org_name: None, + dev_allowed_origin: Vec::new(), + } + } + fn env_test_lock() -> &'static Mutex<()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) @@ -4215,6 +4255,12 @@ mod tests { path } + fn write_eval_file(dir: &Path, name: &str) -> String { + let path = dir.join(name); + fs::write(&path, "export {};").expect("eval file should be written"); + path.to_string_lossy().to_string() + } + #[test] fn join_app_url_normalizes_slashes() { let joined = @@ -5007,6 +5053,105 @@ mod tests { } } + #[test] + fn eval_args_parse_sandbox_flag() { + let parsed = + EvalArgsHarness::try_parse_from(["bt", "--sandbox", "lambda", "sample.eval.ts"]) + .expect("sandbox flag should parse"); + assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda); + assert_eq!(parsed.eval.files, vec!["sample.eval.ts".to_string()]); + } + + #[tokio::test] + async fn sandbox_eval_rejects_dev_mode() { + let dir = make_temp_dir("sandbox-dev"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.dev = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+dev should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --dev.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_watch_mode() { + let dir = make_temp_dir("sandbox-watch"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.watch = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+watch should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --watch.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_list_mode() { + let dir = make_temp_dir("sandbox-list"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.list = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+list should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --list.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_no_send_logs() { + let dir = make_temp_dir("sandbox-local"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.no_send_logs = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+no-send-logs should fail"); + assert!(err + .to_string() + .contains("--sandbox lambda is not supported with --no-send-logs.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_multiple_files() { + let dir = make_temp_dir("sandbox-multi"); + let first = write_eval_file(&dir, "first.eval.ts"); + let second = write_eval_file(&dir, "second.eval.ts"); + let mut args = make_eval_args(vec![first, second]); + args.sandbox = EvalSandbox::Lambda; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+multiple files should fail"); + assert!(err + .to_string() + .contains("`bt eval --sandbox lambda` currently supports exactly one eval file.")); + + let _ = fs::remove_dir_all(&dir); + } + #[test] fn build_sandbox_summary_query_includes_timestamp_filter() { let query = build_sandbox_summary_query("exp'123", "2026-03-19T12:00:00.000Z"); diff --git a/src/functions/push.rs b/src/functions/push.rs index 89ecd87..e3979f7 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -689,9 +689,8 @@ pub(crate) async fn publish_eval_sandbox_functions( ); } - let mut manifest = - run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) - .map_err(|failure| anyhow!(failure.message))?; + let mut manifest = run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) + .map_err(|failure| anyhow!(failure.message))?; for file in &mut manifest.files { file.entries.retain(|entry| match entry { @@ -707,13 +706,8 @@ pub(crate) async fn publish_eval_sandbox_functions( bail!("no sandbox evaluators found in {}", source_file.display()); } - validate_manifest_paths( - &manifest, - &files, - language, - &classified.allowed_roots, - ) - .map_err(|failure| anyhow!(failure.message))?; + validate_manifest_paths(&manifest, &files, language, &classified.allowed_roots) + .map_err(|failure| anyhow!(failure.message))?; let preflight = collect_project_preflight(base, &manifest)?; let mut project_name_cache = diff --git a/tests/eval_fixtures.rs b/tests/eval_fixtures.rs index 84c63a6..1e1d5ba 100644 --- a/tests/eval_fixtures.rs +++ b/tests/eval_fixtures.rs @@ -1,6 +1,10 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fs; +#[cfg(unix)] +use std::io::Write; use std::io::{BufRead, BufReader, Read}; +#[cfg(unix)] +use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::process::{Child, Command, Stdio}; use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; @@ -8,6 +12,8 @@ use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use serde::Deserialize; +#[cfg(unix)] +use serde_json::json; use serde_json::Value; #[derive(Debug, Deserialize, Clone)] @@ -435,11 +441,334 @@ fn eval_runner_list_mode_serializes_parameter_defaults() { ); } +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_js_rows_and_trials() { + let _guard = test_lock(); + if !command_exists("node") { + if required_runtimes().contains("node") { + panic!("node runtime is required but unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_js_rows_and_trials (node not installed)." + ); + return; + } + + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixture_dir = root + .join("tests") + .join("evals") + .join("js") + .join("eval-ts-cjs"); + ensure_dependencies(&fixture_dir); + let runner = local_tsx_path(&fixture_dir).expect("resolve tsx runner"); + let runner_script = root.join("scripts").join("eval-runner.ts"); + let fixture_name = format!("sandbox_rows_{}.eval.ts", unique_test_suffix()); + let fixture_path = fixture_dir.join(&fixture_name); + let fixture_source = r#"import { Eval } from "braintrust"; + +async function* rows() { + yield { input: { case_id: "row-1" }, expected: "alpha" }; + yield { input: { case_id: "row-2" }, expected: "bravo" }; +} + +Eval("sandbox-rows-js", { + data: rows, + task: async (input: { case_id: string }) => + input.case_id === "row-1" ? "alpha" : "bravo", + scores: [ + ({ output, expected }: { output: string; expected?: string }) => ({ + name: "match", + score: output === expected ? 1 : 0, + }), + ], + maxConcurrency: 3, + trialCount: 2, +}); +"#; + fs::write(&fixture_path, fixture_source).expect("write js rows fixture"); + + let socket_path = unique_socket_path("bt-eval-js-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&runner) + .arg(&runner_script) + .arg(&fixture_name) + .current_dir(&fixture_dir) + .env("BT_EVAL_DEV_MODE", "rows") + .env( + "BT_EVAL_DEV_REQUEST_JSON", + json!({ + "name": "sandbox-rows-js", + "parameters": {}, + }) + .to_string(), + ) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn js eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-js"); + assert_eq!(ready["max_concurrency"], 3); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-js-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child.wait_with_output().expect("wait for js rows runner"); + if !output.status.success() { + panic!( + "js rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); + let _ = fs::remove_file(&fixture_path); +} + +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_python_rows_and_trials() { + let _guard = test_lock(); + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixtures_root = root.join("tests").join("evals"); + let fixture_dir = fixtures_root.join("py").join("local_import"); + let python = match ensure_python_env(&fixtures_root.join("py")) { + Some(python) => python, + None => { + if required_runtimes().contains("python") { + panic!("python runtime unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_python_rows_and_trials (python runtime unavailable)." + ); + return; + } + }; + + let runner_script = root.join("scripts").join("eval-runner.py"); + let fixture_name = format!("sandbox_rows_{}.py", unique_test_suffix()); + let fixture_path = fixture_dir.join(&fixture_name); + let fixture_source = r#"from braintrust import Eval + +def rows(): + yield {"input": {"case_id": "row-1"}, "expected": "alpha"} + yield {"input": {"case_id": "row-2"}, "expected": "bravo"} + +def task(input): + return "alpha" if input["case_id"] == "row-1" else "bravo" + +def match(output, expected): + return {"name": "match", "score": 1 if output == expected else 0} + +Eval( + "sandbox-rows-py", + data=rows, + task=task, + scores=[match], + max_concurrency=4, + trial_count=2, +) +"#; + fs::write(&fixture_path, fixture_source).expect("write python rows fixture"); + + let socket_path = unique_socket_path("bt-eval-py-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&python) + .arg(&runner_script) + .arg(&fixture_name) + .current_dir(&fixture_dir) + .env("BT_EVAL_DEV_MODE", "rows") + .env( + "BT_EVAL_DEV_REQUEST_JSON", + json!({ + "name": "sandbox-rows-py", + "parameters": {}, + }) + .to_string(), + ) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn python eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-py"); + assert_eq!(ready["max_concurrency"], 4); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-py-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child + .wait_with_output() + .expect("wait for python rows runner"); + if !output.status.success() { + panic!( + "python rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); + let _ = fs::remove_file(&fixture_path); +} + fn read_fixture_config(path: &Path) -> FixtureConfig { let raw = fs::read_to_string(path).expect("read fixture.json"); serde_json::from_str(&raw).expect("parse fixture.json") } +#[cfg(unix)] +fn unique_test_suffix() -> String { + format!( + "{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before epoch") + .as_nanos() + ) +} + +#[cfg(unix)] +fn unique_socket_path(prefix: &str) -> PathBuf { + std::env::temp_dir().join(format!("{prefix}-{}.sock", unique_test_suffix())) +} + +#[cfg(unix)] +fn accept_pull_stream(listener: &UnixListener, child: &mut Child, timeout: Duration) -> UnixStream { + let started = Instant::now(); + loop { + match listener.accept() { + Ok((stream, _)) => return stream, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {} + Err(err) => panic!("accept pull stream: {err}"), + } + + if let Some(status) = child.try_wait().expect("try_wait runner") { + panic!("runner exited early with status {status}"); + } + + if started.elapsed() > timeout { + panic!("timed out waiting for runner pull socket connection"); + } + + thread::sleep(Duration::from_millis(25)); + } +} + +#[cfg(unix)] +fn read_pull_message(reader: &mut BufReader) -> Value { + let mut line = String::new(); + let read = reader.read_line(&mut line).expect("read pull message"); + assert!(read > 0, "pull channel closed unexpectedly"); + serde_json::from_str(line.trim()).expect("parse pull message json") +} + +#[cfg(unix)] +fn write_pull_message(writer: &mut UnixStream, payload: &Value) { + writer + .write_all(format!("{payload}\n").as_bytes()) + .expect("write pull message"); + writer.flush().expect("flush pull message"); +} + fn collect_deno_eval_diagnostics(dir: &Path, files: &[String]) -> Option { if !command_exists("deno") { return None; From a68e4cad758f4c03db590cc53712804c480ec93a Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Thu, 19 Mar 2026 21:33:56 +0000 Subject: [PATCH 3/3] use unix socket for communication and split data generation to its own runner --- scripts/data-runner.py | 201 ++++++++++++ scripts/data-runner.ts | 201 ++++++++++++ scripts/eval-runner.py | 386 +---------------------- scripts/eval-runner.ts | 665 +++------------------------------------ scripts/runner-common.ts | 447 ++++++++++++++++++++++++++ scripts/runner_common.py | 225 +++++++++++++ src/eval.rs | 312 ++++++++++++------ tests/eval_fixtures.rs | 76 +++-- 8 files changed, 1374 insertions(+), 1139 deletions(-) create mode 100644 scripts/data-runner.py create mode 100644 scripts/data-runner.ts create mode 100644 scripts/runner_common.py diff --git a/scripts/data-runner.py b/scripts/data-runner.py new file mode 100644 index 0000000..35de3d5 --- /dev/null +++ b/scripts/data-runner.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import socket +import sys +import time +from dataclasses import dataclass +from typing import Any + +try: + from braintrust.util import eprint + from runner_common import call_evaluator_data, load_evaluators, to_async_iterator +except Exception as exc: # pragma: no cover - runtime guard + print( + "Unable to import the braintrust package. Please install it in your Python environment.", + file=sys.stderr, + ) + print(str(exc), file=sys.stderr) + sys.exit(1) + + +@dataclass +class PullChannel: + sock: socket.socket + + def send(self, payload: Any) -> None: + self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + async def lines(self): + buffer = "" + while True: + chunk = await asyncio.to_thread(self.sock.recv, 4096) + if not chunk: + break + buffer += chunk.decode("utf-8") + while True: + newline = buffer.find("\n") + if newline == -1: + break + line = buffer[:newline].strip() + buffer = buffer[newline + 1 :] + if line: + yield line + + trailing = buffer.strip() + if trailing: + yield trailing + + def close(self) -> None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self.sock.close() + + +def create_pull_channel() -> PullChannel: + sock_path = os.getenv("BT_EVAL_PULL_SOCK") + if not sock_path: + raise ValueError("Missing BT_EVAL_PULL_SOCK") + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return PullChannel(sock) + + +def parse_start_request(raw: str) -> str: + parsed = json.loads(raw) + if not isinstance(parsed, dict): + raise ValueError("Start request must be a JSON object.") + if parsed.get("type") != "start": + raise ValueError("Expected initial start command.") + name = parsed.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Start request must include a non-empty evaluator name.") + return name + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Stream eval rows over a unix socket for bt.") + parser.add_argument("files", nargs="*", help="Eval files or directories to load.") + return parser + + +async def run(files: list[str]) -> int: + evaluators, _reporters = load_evaluators(files) + channel = create_pull_channel() + + try: + line_iter = channel.lines() + try: + start_line = await anext(line_iter) + except StopAsyncIteration: + return 0 + + try: + target_name = parse_start_request(start_line) + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + return 1 + + evaluator_instance = next( + (candidate for candidate in evaluators if candidate.evaluator.eval_name == target_name), + None, + ) + if evaluator_instance is None: + channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) + return 1 + + evaluator = evaluator_instance.evaluator + raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) + data_iterator = to_async_iterator(raw_data) + iterator = data_iterator.__aiter__() + + trial_count = getattr(evaluator, "trial_count", 1) + try: + trial_count = int(trial_count) + except Exception: + trial_count = 1 + if trial_count < 1: + trial_count = 1 + + max_concurrency = getattr(evaluator, "max_concurrency", None) + try: + max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 + except Exception: + max_concurrency = 10 + if max_concurrency < 1: + max_concurrency = 1 + + experiment_name = getattr(evaluator, "experiment_name", None) + if not isinstance(experiment_name, str) or not experiment_name: + experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" + + channel.send( + { + "type": "ready", + "evaluator_name": evaluator.eval_name, + "max_concurrency": max_concurrency, + "experiment_name": experiment_name, + } + ) + + current_datum = None + trial_index = 0 + async for line in line_iter: + parsed = json.loads(line) + command_type = parsed.get("type") if isinstance(parsed, dict) else None + if command_type == "close": + break + if command_type != "next": + channel.send( + { + "type": "error", + "message": f"Unsupported pull command '{command_type}'", + } + ) + return 1 + + if current_datum is None: + try: + current_datum = await iterator.__anext__() + trial_index = 0 + except StopAsyncIteration: + channel.send({"type": "eof"}) + continue + + channel.send( + { + "type": "row", + "datum": current_datum, + "trial_index": trial_index, + } + ) + trial_index += 1 + if trial_index >= trial_count: + current_datum = None + + return 0 + finally: + channel.close() + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + files = args.files or ["."] + + try: + return asyncio.run(run(files)) + except Exception as exc: + eprint(str(exc)) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/data-runner.ts b/scripts/data-runner.ts new file mode 100644 index 0000000..53cb81e --- /dev/null +++ b/scripts/data-runner.ts @@ -0,0 +1,201 @@ +import net from "node:net"; +import readline from "node:readline"; + +import { + callEvaluatorData, + formatError, + getBraintrustStateGetter, + getEvaluators, + initRegistry, + loadBraintrust, + loadFiles, + normalizeFiles, + propagateInheritedBraintrustState, + toAsyncIterable, +} from "./runner-common"; + +type StartMessage = { + type: "start"; + name: string; +}; + +type ClientMessage = + | StartMessage + | { type: "next" } + | { type: "close" }; + +type ServerMessage = + | { + type: "ready"; + evaluator_name: string; + max_concurrency: number; + experiment_name: string; + } + | { type: "row"; datum: unknown; trial_index: number } + | { type: "eof" } + | { type: "error"; message: string }; + +function writeMessage(socket: net.Socket, message: ServerMessage) { + socket.write(`${JSON.stringify(message)}\n`); +} + +function parseMessage(line: string): ClientMessage { + const parsed = JSON.parse(line) as { type?: unknown; name?: unknown }; + if (parsed.type === "start") { + if (typeof parsed.name !== "string" || parsed.name.length === 0) { + throw new Error("Start request must include a non-empty evaluator name."); + } + return { type: "start", name: parsed.name }; + } + if (parsed.type === "next" || parsed.type === "close") { + return { type: parsed.type }; + } + throw new Error(`Unsupported pull command '${String(parsed.type)}'`); +} + +async function readMessage( + lines: AsyncIterator, +): Promise { + const next = await lines.next(); + if (next.done) { + return null; + } + return parseMessage(next.value); +} + +function applyExtraArgsFromEnv() { + const extraArgs: string[] = process.env.BT_EVAL_EXTRA_ARGS_JSON + ? (JSON.parse(process.env.BT_EVAL_EXTRA_ARGS_JSON) as string[]) + : []; + process.argv = [...process.argv.slice(0, 2), ...extraArgs]; +} + +function toPositiveInteger(value: unknown, fallback: number): number { + const parsed = Number(value); + if (Number.isFinite(parsed) && parsed > 0) { + return Math.floor(parsed); + } + return fallback; +} + +async function main() { + const files = process.argv.slice(2); + if (files.length === 0) { + throw new Error("No eval files provided."); + } + const socketPath = process.env.BT_EVAL_PULL_SOCK; + if (!socketPath) { + throw new Error("Missing BT_EVAL_PULL_SOCK"); + } + + const normalized = normalizeFiles(files); + const braintrust = await loadBraintrust(normalized); + propagateInheritedBraintrustState(braintrust); + initRegistry(); + applyExtraArgsFromEnv(); + await loadFiles(normalized); + + const socket = net.createConnection({ path: socketPath }); + const socketReady = new Promise((resolve, reject) => { + socket.once("connect", resolve); + socket.once("error", reject); + }); + await socketReady; + + const reader = readline.createInterface({ + input: socket, + crlfDelay: Infinity, + }); + const lines = reader[Symbol.asyncIterator](); + + try { + const start = await readMessage(lines); + if (!start) { + return; + } + if (start.type !== "start") { + throw new Error("Expected initial start command."); + } + + const entry = getEvaluators().find( + (candidate) => candidate.evaluator.evalName === start.name, + ); + if (!entry) { + writeMessage(socket, { + type: "error", + message: `Evaluator '${start.name}' not found`, + }); + return; + } + + const getState = getBraintrustStateGetter(braintrust); + const state = getState ? getState() : undefined; + const evaluator = { + ...entry.evaluator, + ...(state !== undefined && state !== null ? { state } : {}), + }; + const { data: rawData } = callEvaluatorData(evaluator.data); + const dataIterable = toAsyncIterable(rawData); + const iterator = dataIterable[Symbol.asyncIterator](); + const trialCount = toPositiveInteger(evaluator.trialCount, 1); + const maxConcurrency = toPositiveInteger(evaluator.maxConcurrency, 10); + const experimentName = + typeof evaluator.experimentName === "string" && + evaluator.experimentName.length > 0 + ? evaluator.experimentName + : `${entry.evaluator.evalName}-${Date.now()}`; + + writeMessage(socket, { + type: "ready", + evaluator_name: entry.evaluator.evalName, + max_concurrency: maxConcurrency, + experiment_name: experimentName, + }); + + let currentDatum: unknown | undefined; + let trialIndex = 0; + while (true) { + const message = await readMessage(lines); + if (!message || message.type === "close") { + return; + } + if (message.type !== "next") { + throw new Error(`Unsupported pull command '${message.type}'`); + } + + if (currentDatum === undefined) { + const next = await iterator.next(); + if (next.done) { + writeMessage(socket, { type: "eof" }); + continue; + } + currentDatum = next.value; + trialIndex = 0; + } + + writeMessage(socket, { + type: "row", + datum: currentDatum, + trial_index: trialIndex, + }); + + trialIndex += 1; + if (trialIndex >= trialCount) { + currentDatum = undefined; + } + } + } catch (err) { + writeMessage(socket, { + type: "error", + message: formatError(err), + }); + } finally { + reader.close(); + socket.end(); + } +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index d8ac067..776aa1c 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -1,26 +1,19 @@ #!/usr/bin/env python3 import argparse import asyncio -import fnmatch -import importlib.util import inspect import json import os -import re import socket import sys -import time import traceback from dataclasses import dataclass -from typing import Any, AsyncIterator, Callable +from typing import Any, Callable try: from braintrust import init_dataset, invoke, login from braintrust.framework import ( BaseExperiment, - EvaluatorInstance, - _evals, - _set_lazy_load, run_evaluator, set_thread_pool_max_workers, ) @@ -28,6 +21,15 @@ from braintrust.parameters import parameters_to_json_schema, validate_parameters from braintrust.util import eprint from braintrust.span_identifier_v4 import parse_parent + from runner_common import ( + EvalFilter, + EvaluatorInstance, + call_evaluator_data, + env_flag, + filter_evaluators, + load_evaluators, + parse_serialized_filters, + ) except Exception as exc: # pragma: no cover - runtime guard print( "Unable to import the braintrust package. Please install it in your Python environment.", @@ -45,14 +47,6 @@ "/venv/", ) _DATASET_TOTAL_CACHE: dict[str, int] = {} - - -@dataclass(frozen=True) -class EvalFilter: - path: list[str] - pattern: re.Pattern[str] - - @dataclass(frozen=True) class RunnerConfig: jsonl: bool @@ -80,41 +74,6 @@ def close(self) -> None: self.sock.close() -@dataclass -class PullChannel: - sock: socket.socket - - def send(self, payload: Any) -> None: - self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) - - async def lines(self) -> AsyncIterator[str]: - buffer = "" - while True: - chunk = await asyncio.to_thread(self.sock.recv, 4096) - if not chunk: - break - buffer += chunk.decode("utf-8") - while True: - newline = buffer.find("\n") - if newline == -1: - break - line = buffer[:newline].strip() - buffer = buffer[newline + 1 :] - if line: - yield line - - trailing = buffer.strip() - if trailing: - yield trailing - - def close(self) -> None: - try: - self.sock.shutdown(socket.SHUT_RDWR) - except OSError: - pass - self.sock.close() - - def serialize_sse_event(event: str, data: Any) -> str: if isinstance(data, (dict, list)): data_str = json.dumps(data) @@ -139,51 +98,10 @@ def create_sse_writer() -> SseWriter | None: return SseWriter(sock) return None - - -def create_pull_channel() -> PullChannel | None: - sock_path = os.getenv("BT_EVAL_PULL_SOCK") - if not sock_path: - return None - - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(sock_path) - return PullChannel(sock) - - -def env_flag(name: str) -> bool: - value = os.getenv(name) - if value is None: - return False - return value.lower() not in {"0", "false", "no", "off", ""} - - -def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: - if not serialized: - return [] - - parsed = json.loads(serialized) - if not isinstance(parsed, list): - raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array") - - filters: list[EvalFilter] = [] - for i, entry in enumerate(parsed): - if not isinstance(entry, dict): - raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}") - key_path = entry.get("path") - pattern = entry.get("pattern") - if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings") - if not isinstance(pattern, str): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string") - filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern))) - return filters - - def parse_dev_mode(value: str | None) -> str | None: if value is None or value == "": return None - if value in {"list", "eval", "rows"}: + if value in {"list", "eval"}: return value raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}") @@ -201,46 +119,6 @@ def read_runner_config() -> RunnerConfig: dev_request_json=os.getenv("BT_EVAL_DEV_REQUEST_JSON"), ) - -def _to_mapping(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_mapping(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_mapping(v) for v in value] - if hasattr(value, "__dict__"): - return { - key: _to_mapping(val) - for key, val in vars(value).items() - if not key.startswith("_") - } - return value - - -def serialize_json_with_plain_string(value: Any) -> str: - if isinstance(value, str): - return value - return json.dumps(value) - - -def evaluate_filter(value: Any, filt: EvalFilter) -> bool: - current = _to_mapping(value) - for part in filt.path: - if not isinstance(current, dict) or part not in current: - return False - current = current[part] - return bool(filt.pattern.search(serialize_json_with_plain_string(current))) - - -def filter_evaluators(evaluators: list[EvaluatorInstance], filters: list[EvalFilter]) -> list[EvaluatorInstance]: - if not filters: - return evaluators - return [ - evaluator - for evaluator in evaluators - if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) - ] - - def snake_to_camel(value: str) -> str: parts = value.split("_") if not parts: @@ -348,26 +226,6 @@ def parse_eval_request(raw: str | None) -> dict[str, Any]: return parsed -def parse_eval_pull_request(raw: str | None) -> dict[str, Any]: - if not raw: - raise ValueError("Missing BT_EVAL_DEV_REQUEST_JSON") - try: - parsed = json.loads(raw) - except json.JSONDecodeError as exc: - raise ValueError(f"Invalid BT_EVAL_DEV_REQUEST_JSON: {exc}") from exc - - if not isinstance(parsed, dict): - raise ValueError("BT_EVAL_DEV_REQUEST_JSON must be a JSON object.") - if not isinstance(parsed.get("name"), str) or not parsed["name"]: - raise ValueError("Pull request must include a non-empty evaluator name.") - - parameters = parsed.get("parameters") - if parameters is not None and not isinstance(parameters, dict): - raise ValueError("Pull request parameters must be an object.") - - return parsed - - def resolve_eval_data(data: dict[str, Any]) -> Any: if "data" in data: return data["data"] @@ -390,33 +248,6 @@ def resolve_eval_data(data: dict[str, Any]) -> Any: raise ValueError("Invalid eval data payload.") -async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: - data_result = data - if inspect.isclass(data_result): - data_result = data_result() - if inspect.isfunction(data_result) or inspect.isroutine(data_result): - data_result = data_result() - if inspect.isawaitable(data_result): - data_result = await data_result - - base_experiment_name = None - if isinstance(data_result, BaseExperiment): - base_experiment_name = data_result.name - - return data_result, base_experiment_name - - -def to_async_iterator(value: Any) -> AsyncIterator[Any]: - if inspect.isasyncgen(value): - return value - - async def to_async(it): - for item in it: - yield item - - return to_async(value) - - def make_eval_scorer( score: dict[str, Any], project_id: str | None, @@ -457,16 +288,6 @@ def build_eval_definitions(evaluator_instances: list[EvaluatorInstance]) -> dict return definitions -def collect_files(input_path: str) -> list[str]: - if os.path.isdir(input_path): - matches: list[str] = [] - for root, _, files in os.walk(input_path): - for filename in files: - matches.append(os.path.join(root, filename)) - return matches - return [input_path] - - def is_watchable_dependency(path_input: str, cwd: str) -> bool: path = os.path.abspath(path_input) normalized = path.replace("\\", "/") @@ -502,86 +323,6 @@ def collect_dependency_files(cwd: str, input_files: list[str]) -> list[str]: return sorted(dependencies) -def resolve_module_info(in_file: str) -> tuple[str, list[str]]: - in_file = os.path.abspath(in_file) - module_dir = os.path.dirname(in_file) - module_name = os.path.splitext(os.path.basename(in_file))[0] - - package_parts: list[str] = [] - current = module_dir - while os.path.isfile(os.path.join(current, "__init__.py")): - package_parts.insert(0, os.path.basename(current)) - current = os.path.dirname(current) - - extra_paths = [module_dir] - if package_parts: - module_name = ".".join(package_parts + [module_name]) - if current not in extra_paths: - extra_paths.append(current) - - return module_name, extra_paths - - -def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: - evaluator_instances: list[EvaluatorInstance] = [] - reporters: dict[str, Any] = {} - cwd = os.getcwd() - if cwd not in sys.path: - sys.path.insert(0, cwd) - - # Add the project root inferred from input files to sys.path so that - # sibling-package imports work when files live outside CWD (e.g. - # sandbox bundles extracted to a temp directory). Walk up from each - # file's directory looking for a register.py (bundle marker) or the - # filesystem root, whichever comes first. - for f in files: - d = os.path.dirname(os.path.abspath(f)) - while d and d != os.path.dirname(d): - if os.path.isfile(os.path.join(d, "register.py")): - if d not in sys.path: - sys.path.insert(0, d) - break - d = os.path.dirname(d) - - unique_files: set[str] = set() - for file_path in files: - for candidate in collect_files(file_path): - unique_files.add(os.path.abspath(candidate)) - - for file_path in sorted(unique_files): - module_name, extra_paths = resolve_module_info(file_path) - with _set_lazy_load(True): - _evals.clear() - try: - for extra_path in reversed(extra_paths): - if extra_path not in sys.path: - sys.path.insert(0, extra_path) - - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None or spec.loader is None: - raise ImportError(f"Unable to load module spec for {file_path}") - - sys.modules.pop(module_name, None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - - evaluator_instances.extend( - [ - instance - for instance in _evals.evaluators.values() - if isinstance(instance, EvaluatorInstance) - ] - ) - for reporter_name, reporter in _evals.reporters.items(): - if reporter_name not in reporters: - reporters[reporter_name] = reporter - finally: - _evals.clear() - - return evaluator_instances, reporters - - def resolve_reporter( reporter: Any, reporters: dict[str, Any], @@ -944,108 +685,6 @@ async def run_requested_eval( return True -async def run_dataset_pull( - evaluator_instances: list[EvaluatorInstance], - config: RunnerConfig, -) -> bool: - channel = create_pull_channel() - if channel is None: - raise ValueError("Missing BT_EVAL_PULL_SOCK") - - try: - request = parse_eval_pull_request(config.dev_request_json) - except Exception as exc: - channel.send({"type": "error", "message": str(exc)}) - channel.close() - return False - - target_name = request["name"] - evaluator_instance = next( - (candidate for candidate in evaluator_instances if candidate.evaluator.eval_name == target_name), - None, - ) - if evaluator_instance is None: - channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) - channel.close() - return False - - evaluator = evaluator_instance.evaluator - try: - raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) - data_iterator = to_async_iterator(raw_data) - iterator = data_iterator.__aiter__() - trial_count = getattr(evaluator, "trial_count", 1) - try: - trial_count = int(trial_count) - except Exception: - trial_count = 1 - if trial_count < 1: - trial_count = 1 - - max_concurrency = getattr(evaluator, "max_concurrency", None) - try: - max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 - except Exception: - max_concurrency = 10 - if max_concurrency < 1: - max_concurrency = 1 - - experiment_name = getattr(evaluator, "experiment_name", None) - if not isinstance(experiment_name, str) or not experiment_name: - experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" - - channel.send( - { - "type": "ready", - "evaluator_name": evaluator.eval_name, - "max_concurrency": max_concurrency, - "experiment_name": experiment_name, - } - ) - - current_datum = None - trial_index = 0 - async for line in channel.lines(): - parsed = json.loads(line) - command_type = parsed.get("type") if isinstance(parsed, dict) else None - if command_type == "close": - break - if command_type != "next": - channel.send( - { - "type": "error", - "message": f"Unsupported pull command '{command_type}'", - } - ) - break - - if current_datum is None: - try: - current_datum = await iterator.__anext__() - trial_index = 0 - except StopAsyncIteration: - channel.send({"type": "eof"}) - continue - - channel.send( - { - "type": "row", - "datum": current_datum, - "trial_index": trial_index, - } - ) - trial_index += 1 - if trial_index >= trial_count: - current_datum = None - except Exception as exc: - channel.send({"type": "error", "message": str(exc)}) - channel.close() - return False - - channel.close() - return True - - async def run_once( files: list[str], no_send_logs: bool, @@ -1067,9 +706,6 @@ async def run_once( return True if config.dev_mode == "eval": return await run_requested_eval(evaluators, reporters, no_send_logs, sse, config) - if config.dev_mode == "rows": - return await run_dataset_pull(evaluators, config) - if config.list_only: for evaluator_instance in evaluators: print(evaluator_instance.evaluator.eval_name) diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index fdd7152..40f4b86 100644 --- a/scripts/eval-runner.ts +++ b/scripts/eval-runner.ts @@ -2,19 +2,27 @@ import { createRequire } from "node:module"; import path from "node:path"; import { fileURLToPath, pathToFileURL } from "node:url"; -type EvaluatorDefinition = { - evalName: string; - projectName: string; - data?: unknown; - trialCount?: unknown; - maxConcurrency?: unknown; - experimentName?: unknown; -} & Record; - -type EvaluatorEntry = { - evaluator: EvaluatorDefinition; - reporter?: unknown; -}; +import { + envFlag, + filterEvaluators, + formatError, + getBraintrustStateGetter, + getEvaluators, + getReporters, + initRegistry, + isObject, + loadBraintrust, + loadBraintrustUtilParseParent, + loadFiles, + normalizeFiles, + parseSerializedFilters, + propagateInheritedBraintrustState, + resolveBraintrustPath, + type EvalFilter, + type EvaluatorEntry, + type GlobalEvals, + type ParseParentFunction, +} from "./runner-common"; type EvalResult = { results: Array<{ error?: unknown }>; @@ -49,23 +57,6 @@ type InitDatasetFunction = ( ) => unknown; type InvokeFunction = (options: Record) => Promise; -type BraintrustModule = { - Eval?: EvalFunction; - login?: LoginFunction; - initDataset?: InitDatasetFunction; - invoke?: InvokeFunction; - _internalGetGlobalState?: () => unknown; - default?: BraintrustModule; -}; - -type GlobalEvals = { - functions: unknown[]; - prompts: unknown[]; - parameters: unknown[]; - evaluators: Record; - reporters: Record; -}; - type BtEvalMain = (context: BtEvalContext) => void | Promise; type BtEvalContext = { @@ -89,16 +80,6 @@ type SseWriter = { close: () => void; }; -type EvalFilter = { - path: string[]; - pattern: RegExp; -}; - -type SerializedEvalFilter = { - path: string[]; - pattern: string; -}; - type EvalScoreSpec = { name: string; function_id: Record; @@ -117,17 +98,12 @@ type EvalRequest = { scores?: EvalScoreSpec[]; }; -type EvalPullRequest = { - name: string; - parameters?: Record; -}; - type RunnerConfig = { jsonl: boolean; list: boolean; terminateOnFailure: boolean; filters: EvalFilter[]; - devMode: "list" | "eval" | "rows" | null; + devMode: "list" | "eval" | null; devRequestJson: string | null; }; @@ -156,8 +132,6 @@ type EvalRunner = { type ParameterContainerSerializer = (parameters: unknown) => unknown; type PromptDefinitionSerializer = (prompt: unknown) => unknown; type ZodSchemaSerializer = (schema: unknown) => Record; -type ParseParentFunction = (parent: unknown) => string | undefined; - type ParameterSerializationHelpers = { sdkSerializeParameters: ParameterContainerSerializer | null; promptDefinitionToPromptData: PromptDefinitionSerializer | null; @@ -173,94 +147,13 @@ declare global { var __inherited_braintrust_state: unknown; } -function isObject(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} - -function isBraintrustModule(value: unknown): value is BraintrustModule { - return isObject(value) && ("Eval" in value || "login" in value); -} - -function normalizeBraintrustModule(value: unknown): BraintrustModule { - if (isBraintrustModule(value)) { - return value; - } - if (isObject(value) && isBraintrustModule(value.default)) { - return value.default; - } - throw new Error("Unable to load braintrust module."); -} - -function normalizeFiles(files: string[]): string[] { - return files.map((file) => path.resolve(process.cwd(), file)); -} - -function envFlag(name: string): boolean { - const value = process.env[name]; - if (!value) { - return false; - } - const normalized = value.toLowerCase(); - return !["0", "false", "no", "off", ""].includes(normalized); -} - -function serializeJSONWithPlainString(value: unknown): string { - if (typeof value === "string") { - return value; - } - return JSON.stringify(value); -} - -function parseSerializedFilters(serialized: string | undefined): EvalFilter[] { - if (!serialized) { - return []; - } - - try { - const parsed = JSON.parse(serialized); - if (!Array.isArray(parsed)) { - throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array."); - } - return parsed.map((value) => { - if (!isObject(value)) { - throw new Error( - "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.", - ); - } - const { path: rawPath, pattern: rawPattern } = - value as SerializedEvalFilter; - if ( - !Array.isArray(rawPath) || - !rawPath.every((part) => typeof part === "string") - ) { - throw new Error( - "BT_EVAL_FILTER_PARSED entry path must be an array of strings.", - ); - } - if (typeof rawPattern !== "string") { - throw new Error( - "BT_EVAL_FILTER_PARSED entry pattern must be a string.", - ); - } - return { - path: rawPath, - pattern: new RegExp(rawPattern), - }; - }); - } catch (err) { - throw new Error( - `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`, - ); - } -} - function parseDevMode( value: string | undefined, -): "list" | "eval" | "rows" | null { +): "list" | "eval" | null { if (!value) { return null; } - if (value === "list" || value === "eval" || value === "rows") { + if (value === "list" || value === "eval") { return value; } throw new Error(`Invalid BT_EVAL_DEV_MODE value: ${value}`); @@ -780,128 +673,6 @@ function createSseWriter(): SseWriter | null { return { send, close }; } -type PullChannel = { - send: (payload: unknown) => void; - close: () => void; - lines: () => AsyncGenerator; -}; - -function createPullChannel(): PullChannel | null { - const netModule = (() => { - try { - return runtimeRequire("node:net") as NetModule; - } catch { - return null; - } - })(); - const sock = process.env.BT_EVAL_PULL_SOCK; - if (!sock) { - return null; - } - if (!netModule) { - return null; - } - - const socket = netModule.createConnection({ path: sock }); - socket.setNoDelay(true); - - const send = (payload: unknown) => { - if (!socket.writable) { - return; - } - socket.write(`${JSON.stringify(payload)}\n`); - }; - - const close = () => { - socket.end(); - }; - - const lines = async function* () { - let buffer = ""; - for await (const chunk of socket as unknown as AsyncIterable< - Buffer | string - >) { - buffer += typeof chunk === "string" ? chunk : chunk.toString("utf8"); - while (true) { - const newline = buffer.indexOf("\n"); - if (newline === -1) { - break; - } - const line = buffer.slice(0, newline).trim(); - buffer = buffer.slice(newline + 1); - if (line.length > 0) { - yield line; - } - } - } - const trailing = buffer.trim(); - if (trailing.length > 0) { - yield trailing; - } - }; - - return { send, close, lines }; -} - -function initRegistry() { - globalThis._evals = { - functions: [], - prompts: [], - parameters: [], - evaluators: {}, - reporters: {}, - }; - globalThis._lazy_load = true; -} - -function ensureBraintrustAvailable() { - resolveBraintrustPath(); -} - -function resolveBraintrustPath(): string { - const files = normalizeFiles(process.argv.slice(2)); - for (const file of files) { - try { - const require = createRequire(pathToFileURL(file).href); - return require.resolve("braintrust"); - } catch { - continue; - } - } - - try { - const require = createRequire(process.cwd() + "/"); - return require.resolve("braintrust"); - } catch { - const message = - "Unable to resolve the `braintrust` package. " + - "Please install it in your project (e.g. `pnpm add braintrust` or `npm install braintrust`)."; - throw new Error(message); - } -} - -async function loadBraintrust() { - const cjsPath = resolveBraintrustPath(); - const cjsUrl = pathToFileURL(cjsPath).href; - - try { - const mod: unknown = await import(cjsUrl); - return normalizeBraintrustModule(mod); - } catch {} - - const esmPath = cjsPath.replace(/\.js$/, ".mjs"); - if (esmPath !== cjsPath && fsMutable.existsSync(esmPath)) { - try { - const mod: unknown = await import(pathToFileURL(esmPath).href); - return normalizeBraintrustModule(mod); - } catch {} - } - - const require = createRequire(cjsUrl); - const mod: unknown = require(cjsPath); - return normalizeBraintrustModule(mod); -} - function extractParameterSerializer( mod: unknown, ): ParameterContainerSerializer | null { @@ -1041,7 +812,7 @@ function loadZodSchemaSerializer( } async function loadParameterSerializationHelpers(): Promise { - const braintrustPath = resolveBraintrustPath(); + const braintrustPath = resolveBraintrustPath(process.argv.slice(2)); const zodToJsonSchema = loadZodSchemaSerializer(braintrustPath); try { const mod: unknown = await import(pathToFileURL(braintrustPath).href); @@ -1059,169 +830,6 @@ async function loadParameterSerializationHelpers(): Promise unknown) | null { - if (!isObject(mod)) { - return null; - } - const candidate = Reflect.get(mod, "_internalGetGlobalState"); - if (typeof candidate === "function") { - return candidate as () => unknown; - } - const defaultExport = Reflect.get(mod, "default"); - if (isObject(defaultExport)) { - const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState"); - if (typeof fromDefault === "function") { - return fromDefault as () => unknown; - } - } - return null; -} - -function loadBraintrustUtilParseParent(): ParseParentFunction | null { - const braintrustPath = resolveBraintrustPath(); - const requireFromBraintrust = createRequire( - pathToFileURL(braintrustPath).href, - ); - try { - const utilMod: unknown = requireFromBraintrust("braintrust/util"); - return extractParseParent(utilMod); - } catch { - return null; - } -} - -function propagateInheritedBraintrustState(braintrust: BraintrustModule) { - const getter = (braintrust as Record) - ._internalGetGlobalState; - if (typeof getter !== "function") { - return; - } - const state = getter(); - if (state !== undefined && state !== null) { - globalThis.__inherited_braintrust_state = state; - } -} - -async function loadFiles(files: string[]): Promise { - const modules: unknown[] = []; - // Internal CLI-controlled flag for ESM retry; not user-facing config. - const forceEsm = envFlag("BT_EVAL_FORCE_ESM"); - // vite-node installs transform hooks that handle TypeScript (including - // extension-less imports) and CJS named-export interop natively. A failed - // require() corrupts Node's module cache and causes the subsequent import() - // to hit the "module imported again after being required" bug, so we skip - // require() for .ts/.tsx files entirely when running under vite-node. - const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node"; - for (const file of files) { - const fileUrl = pathToFileURL(file).href; - const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx"); - const preferRequire = - !forceEsm && - !(isViteNode && isTypeScript) && - (isTypeScript || file.endsWith(".cjs")); - - if (preferRequire) { - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (esmErr) { - throw new Error( - `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`, - ); - } - } - } - - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (err) { - if (!shouldTryRequire(file, err)) { - throw err; - } - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - throw new Error( - `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`, - ); - } - } - } - return modules; -} - -function shouldTryRequire(file: string, err: unknown): boolean { - if (envFlag("BT_EVAL_FORCE_ESM")) { - return false; - } - if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") { - return false; - } - if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) { - return true; - } - if ( - (file.endsWith(".ts") || file.endsWith(".tsx")) && - isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION") - ) { - return true; - } - if (!(err instanceof Error)) { - return false; - } - const message = err.message || ""; - return ( - message.includes("require is not defined") || - message.includes("exports is not defined") || - message.includes("module is not defined") || - message.includes("Cannot use import statement outside a module") - ); -} - -function isNodeErrorCode(err: unknown, code: string): boolean { - if (!isObject(err) || !("code" in err)) { - return false; - } - return typeof err.code === "string" && err.code === code; -} - -function formatError(err: unknown): string { - if (err instanceof Error) { - return err.message; - } - return String(err); -} - function createEvalProgressReporter( sse: SseWriter | null, evaluatorName: string, @@ -1293,22 +901,6 @@ function sendConsole( sse.send("console", { stream, message }); } -function getEvaluators(): EvaluatorEntry[] { - const evals = globalThis._evals; - if (!evals || !evals.evaluators) { - return []; - } - return Object.values(evals.evaluators) as EvaluatorEntry[]; -} - -function getReporters(): Record { - const evals = globalThis._evals; - if (!evals || !evals.reporters) { - return {}; - } - return evals.reporters as Record; -} - function resolveReporter( reporter: unknown, reporters: Record, @@ -1336,34 +928,6 @@ function resolveReporter( ); } -function evaluateFilter( - object: Record, - filter: EvalFilter, -): boolean { - const key = filter.path.reduce((acc, part) => { - if (!isObject(acc)) { - return undefined; - } - return acc[part]; - }, object); - if (key === undefined) { - return false; - } - return filter.pattern.test(serializeJSONWithPlainString(key)); -} - -function filterEvaluators( - evaluators: EvaluatorEntry[], - filters: EvalFilter[], -): EvaluatorEntry[] { - if (filters.length === 0) { - return evaluators; - } - return evaluators.filter((entry) => - filters.every((filter) => evaluateFilter(entry.evaluator, filter)), - ); -} - function extractScoreName(score: unknown, idx: number): string { if (typeof score === "function" && typeof score.name === "string") { return score.name || `scorer_${idx}`; @@ -1482,24 +1046,6 @@ async function buildEvaluatorDefinitions(evaluators: EvaluatorEntry[]) { return result; } -function parseEvalPullRequest(raw: string | null): EvalPullRequest { - if (!raw) { - throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); - } - const parsed = JSON.parse(raw); - if (!isObject(parsed) || typeof parsed.name !== "string" || parsed.name.length === 0) { - throw new Error("Pull request must include a non-empty evaluator name."); - } - const request = parsed as EvalPullRequest; - if ( - request.parameters !== undefined && - (!isObject(request.parameters) || Array.isArray(request.parameters)) - ) { - throw new Error("Pull request parameters must be an object."); - } - return request; -} - function parseEvalRequest(raw: string | null): EvalRequest { if (!raw) { throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); @@ -1582,48 +1128,6 @@ function resolveEvalData( throw new Error("Invalid eval data payload."); } -function callEvaluatorData( - data: unknown, -): { data: unknown; baseExperiment: string | undefined } { - const dataResult = typeof data === "function" ? (data as () => unknown)() : data; - let baseExperiment: string | undefined = undefined; - if ( - isObject(dataResult) && - Reflect.get(dataResult, "_type") === "BaseExperiment" && - typeof Reflect.get(dataResult, "name") === "string" - ) { - baseExperiment = Reflect.get(dataResult, "name") as string; - } - return { data: dataResult, baseExperiment }; -} - -function toAsyncIterable(value: unknown): AsyncIterable { - if ( - typeof value === "object" && - value !== null && - Symbol.asyncIterator in value && - typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" - ) { - return value as AsyncIterable; - } - if ( - typeof value === "object" && - value !== null && - Symbol.iterator in value && - typeof (value as Iterable)[Symbol.iterator] === "function" - ) { - const iterable = value as Iterable; - return (async function* () { - for (const item of iterable) { - yield item; - } - })(); - } - throw new Error( - "Evaluator data must be an array, iterable, or async iterable", - ); -} - function convertFunctionId( functionId: Record, ): Record { @@ -1798,102 +1302,6 @@ async function runRequestedEval(config: RunnerConfig, runner: EvalRunner) { } } -async function runDatasetPull(config: RunnerConfig, runner: EvalRunner) { - const channel = createPullChannel(); - if (!channel) { - throw new Error("Missing BT_EVAL_PULL_SOCK"); - } - - try { - const request = parseEvalPullRequest(config.devRequestJson); - const entry = getEvaluators().find( - (candidate) => candidate.evaluator.evalName === request.name, - ); - if (!entry) { - channel.send({ - type: "error", - message: `Evaluator '${request.name}' not found`, - }); - return; - } - - const state = runner.getState ? runner.getState() : undefined; - const evaluator = { - ...entry.evaluator, - ...(state !== undefined && state !== null ? { state } : {}), - }; - const { data: rawData } = callEvaluatorData(evaluator.data); - const dataIterable = toAsyncIterable(rawData); - const iterator = dataIterable[Symbol.asyncIterator](); - const trialCountRaw = Number(evaluator.trialCount ?? 1); - const trialCount = - Number.isFinite(trialCountRaw) && trialCountRaw > 0 - ? Math.floor(trialCountRaw) - : 1; - const maxConcurrencyRaw = Number(evaluator.maxConcurrency ?? 10); - const maxConcurrency = - Number.isFinite(maxConcurrencyRaw) && maxConcurrencyRaw > 0 - ? Math.floor(maxConcurrencyRaw) - : 10; - const experimentName = - typeof evaluator.experimentName === "string" && - evaluator.experimentName.length > 0 - ? evaluator.experimentName - : `${entry.evaluator.evalName}-${Date.now()}`; - - channel.send({ - type: "ready", - evaluator_name: entry.evaluator.evalName, - max_concurrency: maxConcurrency, - experiment_name: experimentName, - }); - - let currentDatum: unknown | undefined = undefined; - let trialIndex = 0; - for await (const line of channel.lines()) { - const parsed = JSON.parse(line) as { type?: string }; - if (parsed.type === "close") { - break; - } - if (parsed.type !== "next") { - channel.send({ - type: "error", - message: `Unsupported pull command '${String(parsed.type)}'`, - }); - break; - } - - if (currentDatum === undefined) { - const next = await iterator.next(); - if (next.done) { - channel.send({ type: "eof" }); - continue; - } - currentDatum = next.value; - trialIndex = 0; - } - - channel.send({ - type: "row", - datum: currentDatum, - trial_index: trialIndex, - }); - - trialIndex += 1; - if (trialIndex >= trialCount) { - currentDatum = undefined; - } - } - } catch (err) { - channel.send({ - type: "error", - message: err instanceof Error ? err.message : String(err), - }); - } finally { - channel.close(); - } -} - function extractBtEvalMain(mod: unknown): BtEvalMain | null { if (!mod || typeof mod !== "object") { return null; @@ -2064,9 +1472,12 @@ function mergeProgress( }; } -async function createEvalRunner(config: RunnerConfig): Promise { - const braintrust = await loadBraintrust(); - const Eval = braintrust.Eval; +async function createEvalRunner( + config: RunnerConfig, + files: string[], +): Promise { + const braintrust = await loadBraintrust(files); + const Eval = braintrust.Eval as EvalFunction | undefined; if (typeof Eval !== "function") { throw new Error("Unable to load Eval() from braintrust package."); } @@ -2076,8 +1487,8 @@ async function createEvalRunner(config: RunnerConfig): Promise { const sse = createSseWriter(); const noSendLogs = shouldDisableSendLogs(); - const parseParent = loadBraintrustUtilParseParent(); - const getState = extractGlobalStateGetter(braintrust); + const parseParent = loadBraintrustUtilParseParent(files); + const getState = getBraintrustStateGetter(braintrust); const makeEvalOptions = ( evaluatorName: string, @@ -2229,8 +1640,7 @@ async function main() { maybeRecordDependency(file); } collectStaticLocalDependencies(normalized); - ensureBraintrustAvailable(); - const braintrust = await loadBraintrust(); + const braintrust = await loadBraintrust(normalized); propagateInheritedBraintrustState(braintrust); initRegistry(); // Replace process.argv with [runtime, script, ...extraArgs] so that user @@ -2243,7 +1653,7 @@ async function main() { const modules = await loadFiles(normalized); const btEvalMains = collectBtEvalMains(modules); - const runner = await createEvalRunner(config); + const runner = await createEvalRunner(config, normalized); if (!runner.noSendLogs && typeof runner.login === "function") { try { await runner.login({}); @@ -2294,11 +1704,6 @@ async function main() { return; } - if (config.devMode === "rows") { - await runDatasetPull(config, runner); - return; - } - if (config.list) { for (const entry of filteredEvaluators) { console.log(entry.evaluator.evalName); diff --git a/scripts/runner-common.ts b/scripts/runner-common.ts index 8a0dd63..a98c612 100644 --- a/scripts/runner-common.ts +++ b/scripts/runner-common.ts @@ -1,3 +1,8 @@ +import { createRequire } from "node:module"; +import fs from "node:fs"; +import path from "node:path"; +import { pathToFileURL } from "node:url"; + export type JsonPrimitive = string | number | boolean | null; export type JsonArray = JsonValue[]; export type JsonObject = { [key: string]: JsonValue }; @@ -13,6 +18,56 @@ export type ProjectRef = { name?: string; }; +export type EvaluatorDefinition = { + evalName: string; + projectName: string; + data?: unknown; + trialCount?: unknown; + maxConcurrency?: unknown; + experimentName?: unknown; +} & Record; + +export type EvaluatorEntry = { + evaluator: EvaluatorDefinition; + reporter?: unknown; +}; + +export type BraintrustModule = { + Eval?: (...args: unknown[]) => unknown; + login?: (...args: unknown[]) => Promise; + initDataset?: (...args: unknown[]) => unknown; + invoke?: (...args: unknown[]) => Promise; + _internalGetGlobalState?: () => unknown; + default?: BraintrustModule; +}; + +export type GlobalEvals = { + functions: unknown[]; + prompts: unknown[]; + parameters: unknown[]; + evaluators: Record; + reporters: Record; +}; + +export type EvalFilter = { + path: string[]; + pattern: RegExp; +}; + +export type SerializedEvalFilter = { + path: string[]; + pattern: string; +}; + +declare global { + // eslint-disable-next-line no-var + var _evals: GlobalEvals | undefined; + // eslint-disable-next-line no-var + var _lazy_load: boolean | undefined; + // eslint-disable-next-line no-var + var __inherited_braintrust_state: unknown; +} + export function asProjectSelector( project: ProjectRef | undefined, ): ProjectSelector { @@ -81,3 +136,395 @@ export function toJsonValue(input: JsonValue): JsonValue { return input; } + +export function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +export function normalizeFiles(files: string[]): string[] { + return files.map((file) => path.resolve(process.cwd(), file)); +} + +export function envFlag(name: string): boolean { + const value = process.env[name]; + if (!value) { + return false; + } + const normalized = value.toLowerCase(); + return !["0", "false", "no", "off", ""].includes(normalized); +} + +export function serializeJSONWithPlainString(value: unknown): string { + if (typeof value === "string") { + return value; + } + return JSON.stringify(value); +} + +export function parseSerializedFilters( + serialized: string | undefined, +): EvalFilter[] { + if (!serialized) { + return []; + } + + try { + const parsed = JSON.parse(serialized); + if (!Array.isArray(parsed)) { + throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array."); + } + return parsed.map((value) => { + if (!isObject(value)) { + throw new Error( + "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.", + ); + } + const { path: rawPath, pattern: rawPattern } = + value as SerializedEvalFilter; + if ( + !Array.isArray(rawPath) || + !rawPath.every((part) => typeof part === "string") + ) { + throw new Error( + "BT_EVAL_FILTER_PARSED entry path must be an array of strings.", + ); + } + if (typeof rawPattern !== "string") { + throw new Error( + "BT_EVAL_FILTER_PARSED entry pattern must be a string.", + ); + } + return { + path: rawPath, + pattern: new RegExp(rawPattern), + }; + }); + } catch (err) { + throw new Error( + `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`, + ); + } +} + +export function formatError(err: unknown): string { + if (err instanceof Error) { + return err.message; + } + return String(err); +} + +export function initRegistry() { + globalThis._evals = { + functions: [], + prompts: [], + parameters: [], + evaluators: {}, + reporters: {}, + }; + globalThis._lazy_load = true; +} + +function isBraintrustModule(value: unknown): value is BraintrustModule { + return isObject(value) && ("Eval" in value || "login" in value); +} + +function normalizeBraintrustModule(value: unknown): BraintrustModule { + if (isBraintrustModule(value)) { + return value; + } + if (isObject(value) && isBraintrustModule(value.default)) { + return value.default; + } + throw new Error("Unable to load braintrust module."); +} + +export function resolveBraintrustPath(files: string[]): string { + const normalizedFiles = normalizeFiles(files); + for (const file of normalizedFiles) { + try { + const require = createRequire(pathToFileURL(file).href); + return require.resolve("braintrust"); + } catch { + continue; + } + } + + try { + const require = createRequire(process.cwd() + "/"); + return require.resolve("braintrust"); + } catch { + const message = + "Unable to resolve the `braintrust` package. " + + "Please install it in your project (e.g. `pnpm add braintrust` or `npm install braintrust`)."; + throw new Error(message); + } +} + +export async function loadBraintrust( + files: string[], +): Promise { + const cjsPath = resolveBraintrustPath(files); + const cjsUrl = pathToFileURL(cjsPath).href; + + try { + const mod: unknown = await import(cjsUrl); + return normalizeBraintrustModule(mod); + } catch {} + + const esmPath = cjsPath.replace(/\.js$/, ".mjs"); + if (esmPath !== cjsPath && fs.existsSync(esmPath)) { + try { + const mod: unknown = await import(pathToFileURL(esmPath).href); + return normalizeBraintrustModule(mod); + } catch {} + } + + const require = createRequire(cjsUrl); + const mod: unknown = require(cjsPath); + return normalizeBraintrustModule(mod); +} + +export type ParseParentFunction = (parent: unknown) => string | undefined; + +function extractParseParent(mod: unknown): ParseParentFunction | null { + if (!isObject(mod)) { + return null; + } + const candidate = Reflect.get(mod, "parseParent"); + if (typeof candidate === "function") { + return candidate as ParseParentFunction; + } + const defaultExport = Reflect.get(mod, "default"); + if (isObject(defaultExport)) { + const fromDefault = Reflect.get(defaultExport, "parseParent"); + if (typeof fromDefault === "function") { + return fromDefault as ParseParentFunction; + } + } + return null; +} + +export function loadBraintrustUtilParseParent( + files: string[], +): ParseParentFunction | null { + const braintrustPath = resolveBraintrustPath(files); + const requireFromBraintrust = createRequire( + pathToFileURL(braintrustPath).href, + ); + try { + const utilMod: unknown = requireFromBraintrust("braintrust/util"); + return extractParseParent(utilMod); + } catch { + return null; + } +} + +function extractGlobalStateGetter(mod: unknown): (() => unknown) | null { + if (!isObject(mod)) { + return null; + } + const candidate = Reflect.get(mod, "_internalGetGlobalState"); + if (typeof candidate === "function") { + return candidate as () => unknown; + } + const defaultExport = Reflect.get(mod, "default"); + if (isObject(defaultExport)) { + const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState"); + if (typeof fromDefault === "function") { + return fromDefault as () => unknown; + } + } + return null; +} + +export function getBraintrustStateGetter( + braintrust: BraintrustModule, +): (() => unknown) | null { + return extractGlobalStateGetter(braintrust); +} + +export function propagateInheritedBraintrustState(braintrust: BraintrustModule) { + const getter = getBraintrustStateGetter(braintrust); + if (!getter) { + return; + } + const state = getter(); + if (state !== undefined && state !== null) { + globalThis.__inherited_braintrust_state = state; + } +} + +export async function loadFiles(files: string[]): Promise { + const modules: unknown[] = []; + const forceEsm = envFlag("BT_EVAL_FORCE_ESM"); + const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node"; + for (const file of files) { + const fileUrl = pathToFileURL(file).href; + const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx"); + const preferRequire = + !forceEsm && + !(isViteNode && isTypeScript) && + (isTypeScript || file.endsWith(".cjs")); + + if (preferRequire) { + try { + const require = createRequire(fileUrl); + const mod = require(file); + modules.push(mod); + continue; + } catch (requireErr) { + try { + const mod = await import(fileUrl); + modules.push(mod); + continue; + } catch (esmErr) { + throw new Error( + `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`, + ); + } + } + } + + try { + const mod = await import(fileUrl); + modules.push(mod); + continue; + } catch (err) { + if (!shouldTryRequire(file, err)) { + throw err; + } + try { + const require = createRequire(fileUrl); + const mod = require(file); + modules.push(mod); + continue; + } catch (requireErr) { + throw new Error( + `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`, + ); + } + } + } + return modules; +} + +function shouldTryRequire(file: string, err: unknown): boolean { + if (envFlag("BT_EVAL_FORCE_ESM")) { + return false; + } + if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") { + return false; + } + if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) { + return true; + } + if ( + (file.endsWith(".ts") || file.endsWith(".tsx")) && + isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION") + ) { + return true; + } + if (!(err instanceof Error)) { + return false; + } + const message = err.message || ""; + return ( + message.includes("require is not defined") || + message.includes("exports is not defined") || + message.includes("module is not defined") || + message.includes("Cannot use import statement outside a module") + ); +} + +function isNodeErrorCode(err: unknown, code: string): boolean { + if (!isObject(err) || !("code" in err)) { + return false; + } + return typeof err.code === "string" && err.code === code; +} + +export function getEvaluators(): EvaluatorEntry[] { + const evals = globalThis._evals; + if (!evals || !evals.evaluators) { + return []; + } + return Object.values(evals.evaluators) as EvaluatorEntry[]; +} + +export function getReporters(): Record { + const evals = globalThis._evals; + if (!evals || !evals.reporters) { + return {}; + } + return evals.reporters as Record; +} + +export function evaluateFilter( + object: Record, + filter: EvalFilter, +): boolean { + const key = filter.path.reduce((acc, part) => { + if (!isObject(acc)) { + return undefined; + } + return acc[part]; + }, object); + if (key === undefined) { + return false; + } + return filter.pattern.test(serializeJSONWithPlainString(key)); +} + +export function filterEvaluators( + evaluators: EvaluatorEntry[], + filters: EvalFilter[], +): EvaluatorEntry[] { + if (filters.length === 0) { + return evaluators; + } + return evaluators.filter((entry) => + filters.every((filter) => evaluateFilter(entry.evaluator, filter)), + ); +} + +export function callEvaluatorData( + data: unknown, +): { data: unknown; baseExperiment: string | undefined } { + const dataResult = typeof data === "function" ? (data as () => unknown)() : data; + let baseExperiment: string | undefined = undefined; + if ( + isObject(dataResult) && + Reflect.get(dataResult, "_type") === "BaseExperiment" && + typeof Reflect.get(dataResult, "name") === "string" + ) { + baseExperiment = Reflect.get(dataResult, "name") as string; + } + return { data: dataResult, baseExperiment }; +} + +export function toAsyncIterable(value: unknown): AsyncIterable { + if ( + typeof value === "object" && + value !== null && + Symbol.asyncIterator in value && + typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" + ) { + return value as AsyncIterable; + } + if ( + typeof value === "object" && + value !== null && + Symbol.iterator in value && + typeof (value as Iterable)[Symbol.iterator] === "function" + ) { + const iterable = value as Iterable; + return (async function* () { + for (const item of iterable) { + yield item; + } + })(); + } + throw new Error( + "Evaluator data must be an array, iterable, or async iterable", + ); +} diff --git a/scripts/runner_common.py b/scripts/runner_common.py new file mode 100644 index 0000000..f079a8a --- /dev/null +++ b/scripts/runner_common.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import asyncio +import importlib.util +import inspect +import json +import os +import re +import sys +from dataclasses import dataclass +from typing import Any, AsyncIterator + +try: + from braintrust.framework import ( + BaseExperiment, + EvaluatorInstance, + _evals, + _set_lazy_load, + ) + from braintrust.logger import Dataset +except Exception: + raise + + +@dataclass(frozen=True) +class EvalFilter: + path: list[str] + pattern: re.Pattern[str] + + +def env_flag(name: str) -> bool: + value = os.getenv(name) + if value is None: + return False + return value.lower() not in {"0", "false", "no", "off", ""} + + +def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: + if not serialized: + return [] + + parsed = json.loads(serialized) + if not isinstance(parsed, list): + raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array") + + filters: list[EvalFilter] = [] + for i, entry in enumerate(parsed): + if not isinstance(entry, dict): + raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}") + key_path = entry.get("path") + pattern = entry.get("pattern") + if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path): + raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings") + if not isinstance(pattern, str): + raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string") + filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern))) + return filters + + +def _to_mapping(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_mapping(v) for k, v in value.items()} + if isinstance(value, list): + return [_to_mapping(v) for v in value] + if hasattr(value, "__dict__"): + return { + key: _to_mapping(val) + for key, val in vars(value).items() + if not key.startswith("_") + } + return value + + +def serialize_json_with_plain_string(value: Any) -> str: + if isinstance(value, str): + return value + return json.dumps(value) + + +def evaluate_filter(value: Any, filt: EvalFilter) -> bool: + current = _to_mapping(value) + for part in filt.path: + if not isinstance(current, dict) or part not in current: + return False + current = current[part] + return bool(filt.pattern.search(serialize_json_with_plain_string(current))) + + +def filter_evaluators( + evaluators: list[EvaluatorInstance], filters: list[EvalFilter] +) -> list[EvaluatorInstance]: + if not filters: + return evaluators + return [ + evaluator + for evaluator in evaluators + if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) + ] + + +async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: + data_result = data + if inspect.isclass(data_result): + data_result = data_result() + if inspect.isfunction(data_result) or inspect.isroutine(data_result): + data_result = data_result() + if inspect.isawaitable(data_result): + data_result = await data_result + + base_experiment_name = None + if isinstance(data_result, BaseExperiment): + base_experiment_name = data_result.name + + return data_result, base_experiment_name + + +def to_async_iterator(value: Any) -> AsyncIterator[Any]: + if inspect.isasyncgen(value): + return value + + async def to_async(it): + for item in it: + yield item + + return to_async(value) + + +def collect_files(input_path: str) -> list[str]: + if os.path.isdir(input_path): + matches: list[str] = [] + for root, _, files in os.walk(input_path): + for filename in files: + matches.append(os.path.join(root, filename)) + return matches + return [input_path] + + +def resolve_module_info(in_file: str) -> tuple[str, list[str]]: + in_file = os.path.abspath(in_file) + module_dir = os.path.dirname(in_file) + module_name = os.path.splitext(os.path.basename(in_file))[0] + + package_parts: list[str] = [] + current = module_dir + while os.path.isfile(os.path.join(current, "__init__.py")): + package_parts.insert(0, os.path.basename(current)) + current = os.path.dirname(current) + + extra_paths = [module_dir] + if package_parts: + module_name = ".".join(package_parts + [module_name]) + if current not in extra_paths: + extra_paths.append(current) + + return module_name, extra_paths + + +def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: + evaluator_instances: list[EvaluatorInstance] = [] + reporters: dict[str, Any] = {} + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.insert(0, cwd) + + for f in files: + d = os.path.dirname(os.path.abspath(f)) + while d and d != os.path.dirname(d): + if os.path.isfile(os.path.join(d, "register.py")): + if d not in sys.path: + sys.path.insert(0, d) + break + d = os.path.dirname(d) + + unique_files: set[str] = set() + for file_path in files: + for candidate in collect_files(file_path): + unique_files.add(os.path.abspath(candidate)) + + for file_path in sorted(unique_files): + module_name, extra_paths = resolve_module_info(file_path) + with _set_lazy_load(True): + _evals.clear() + try: + for extra_path in reversed(extra_paths): + if extra_path not in sys.path: + sys.path.insert(0, extra_path) + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module spec for {file_path}") + + sys.modules.pop(module_name, None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + evaluator_instances.extend( + [ + instance + for instance in _evals.evaluators.values() + if isinstance(instance, EvaluatorInstance) + ] + ) + for reporter_name, reporter in _evals.reporters.items(): + if reporter_name not in reporters: + reporters[reporter_name] = reporter + finally: + _evals.clear() + + return evaluator_instances, reporters + + +__all__ = [ + "BaseExperiment", + "Dataset", + "EvalFilter", + "EvaluatorInstance", + "call_evaluator_data", + "env_flag", + "filter_evaluators", + "load_evaluators", + "parse_serialized_filters", + "serialize_json_with_plain_string", + "to_async_iterator", +] diff --git a/src/eval.rs b/src/eval.rs index 04037f9..8c88d03 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -179,6 +179,7 @@ struct EvalPullRequest { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum EvalPullClientMessage { + Start { name: String }, Next, Close, } @@ -253,9 +254,17 @@ struct RunnerFilter { } const JS_RUNNER_FILE: &str = "eval-runner.ts"; +const JS_DATA_RUNNER_FILE: &str = "data-runner.ts"; +const JS_RUNNER_COMMON_FILE: &str = "runner-common.ts"; const PY_RUNNER_FILE: &str = "eval-runner.py"; +const PY_DATA_RUNNER_FILE: &str = "data-runner.py"; +const PY_RUNNER_COMMON_FILE: &str = "runner_common.py"; const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); +const JS_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.ts"); +const JS_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner-common.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); +const PY_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.py"); +const PY_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner_common.py"); #[derive(Debug)] struct SocketCleanupGuard { @@ -890,110 +899,131 @@ async fn spawn_eval_data_puller( ) -> Result { let (listener, socket_path, socket_cleanup_guard) = bind_unix_listener("bt-eval-pull").context("failed to bind sandbox pull socket")?; - let request_json = - serde_json::to_string(request).context("failed to serialize sandbox pull request")?; - let extra_env = vec![ - ("BT_EVAL_DEV_MODE".to_string(), "rows".to_string()), - ("BT_EVAL_DEV_REQUEST_JSON".to_string(), request_json), - ( - "BT_EVAL_PULL_SOCK".to_string(), - socket_path.to_string_lossy().to_string(), - ), - ]; - let child = spawn_eval_support_process( - base, - language, - runner_override, - files, - no_send_logs, - options, - &extra_env, - JsMode::Auto, - ) - .await?; + let child = match language { + EvalLanguage::JavaScript => { + let js_runner = prepare_js_data_runner()?; + let mut extra_env = vec![( + "BT_EVAL_PULL_SOCK".to_string(), + socket_path.to_string_lossy().to_string(), + )]; + let mut plan = build_js_plan_with_entrypoint( + runner_override, + &js_runner, + files, + JS_DATA_RUNNER_FILE, + JS_DATA_RUNNER_SOURCE, + )?; + if should_set_node_heap_size(plan.kind) { + set_node_heap_size_env(&mut plan.cmd); + } + plan.cmd.envs(build_env(base).await?); + for (key, value) in extra_env.drain(..) { + plan.cmd.env(key, value); + } + if no_send_logs { + plan.cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + plan.cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + plan.cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + plan.cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + plan.cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + plan.cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + plan.cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + plan.cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let runner_name = match plan.kind { + RunnerKind::Tsx => "tsx", + RunnerKind::ViteNode => "vite-node", + RunnerKind::Deno => "deno", + RunnerKind::Bun => "bun", + RunnerKind::Other => "other", + }; + plan.cmd.env("BT_EVAL_RUNNER_KIND", runner_name); + plan.cmd + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + plan.cmd + .spawn() + .context("failed to spawn sandbox pull runner")? + } + EvalLanguage::Python => { + let py_runner = prepare_py_data_runner()?; + let mut cmd = build_python_command(runner_override, &py_runner, files)?; + cmd.envs(build_env(base).await?); + cmd.env( + "BT_EVAL_PULL_SOCK", + socket_path.to_string_lossy().to_string(), + ); + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + cmd.stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + cmd.spawn().context("failed to spawn sandbox pull runner")? + } + }; let (stream, _) = tokio::time::timeout(Duration::from_secs(30), listener.accept()) .await .context("timed out waiting for sandbox pull runner to connect")? .context("sandbox pull runner failed to connect")?; let (read_half, write_half) = stream.into_split(); - Ok(EvalDataPuller { + let mut puller = EvalDataPuller { child, writer: write_half, reader: BufReader::new(read_half), _socket_cleanup_guard: socket_cleanup_guard, - }) -} - -async fn spawn_eval_support_process( - base: &BaseArgs, - language: EvalLanguage, - runner_override: Option<&str>, - files: &[String], - no_send_logs: bool, - options: &EvalRunOptions, - extra_env: &[(String, String)], - js_mode: JsMode, -) -> Result { - let (js_runner, py_runner) = prepare_eval_runners()?; - let force_esm = matches!(js_mode, JsMode::ForceEsm); - let (mut cmd, runner_kind) = match language { - EvalLanguage::Python => ( - build_python_command(runner_override, &py_runner, files)?, - RunnerKind::Other, - ), - EvalLanguage::JavaScript => { - if force_esm { - ( - build_vite_node_fallback_command(&js_runner, files)?, - RunnerKind::ViteNode, - ) - } else { - let plan = build_js_plan(runner_override, &js_runner, files)?; - (plan.cmd, plan.kind) - } - } }; - if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { - set_node_heap_size_env(&mut cmd); - } - cmd.envs(build_env(base).await?); - for (key, value) in extra_env { - cmd.env(key, value); - } - if no_send_logs { - cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); - cmd.env("BT_EVAL_LOCAL", "1"); - } - if options.jsonl { - cmd.env("BT_EVAL_JSONL", "1"); - } - if options.terminate_on_failure { - cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); - } - if options.list { - cmd.env("BT_EVAL_LIST", "1"); - } - if let Some(num_workers) = options.num_workers { - cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); - } - if !options.filter.is_empty() { - let parsed = parse_eval_filter_expressions(&options.filter)?; - let serialized = - serde_json::to_string(&parsed).context("failed to serialize eval filters")?; - cmd.env("BT_EVAL_FILTER_PARSED", serialized); - } - if language == EvalLanguage::JavaScript && force_esm { - cmd.env("BT_EVAL_FORCE_ESM", "1"); - } - if !options.extra_args.is_empty() { - let serialized = - serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; - cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + if matches!(language, EvalLanguage::JavaScript | EvalLanguage::Python) { + puller + .send_message(&EvalPullClientMessage::Start { + name: request.name.clone(), + }) + .await?; } - cmd.stdout(Stdio::inherit()); - cmd.stderr(Stdio::inherit()); - cmd.spawn().context("failed to start eval support runner") + Ok(puller) } async fn run_eval_runner_command_to_completion( @@ -2858,18 +2888,40 @@ fn build_js_plan( runner_override: Option<&str>, runner: &Path, files: &[String], +) -> Result { + build_js_plan_with_entrypoint( + runner_override, + runner, + files, + JS_RUNNER_FILE, + JS_RUNNER_SOURCE, + ) +} + +fn build_js_plan_with_entrypoint( + runner_override: Option<&str>, + runner: &Path, + files: &[String], + embedded_file_name: &str, + embedded_source: &str, ) -> Result { if let Some(explicit) = runner_override { let resolved_runner = resolve_js_runner_command(explicit, files); if is_deno_runner(explicit) || is_deno_runner_path(resolved_runner.as_ref()) { - let runner_script = prepare_js_runner_in_cwd()?; + let runner_script = + prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(resolved_runner.as_os_str(), &runner_script, files), kind: RunnerKind::Deno, }); } let kind = runner_kind_for_bin(resolved_runner.as_ref()); - let runner_script = select_js_runner_entrypoint(runner, resolved_runner.as_ref())?; + let runner_script = select_js_runner_entrypoint_with_source( + runner, + resolved_runner.as_ref(), + embedded_file_name, + embedded_source, + )?; let mut command = Command::new(resolved_runner); command.arg(runner_script).args(files); return Ok(JsRunnerPlan { cmd: command, kind }); @@ -2877,14 +2929,20 @@ fn build_js_plan( if let Some(auto_runner) = find_js_runner_binary(files) { if is_deno_runner_path(&auto_runner) { - let runner_script = prepare_js_runner_in_cwd()?; + let runner_script = + prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(auto_runner.as_os_str(), &runner_script, files), kind: RunnerKind::Deno, }); } let kind = runner_kind_for_bin(auto_runner.as_ref()); - let runner_script = select_js_runner_entrypoint(runner, auto_runner.as_ref())?; + let runner_script = select_js_runner_entrypoint_with_source( + runner, + auto_runner.as_ref(), + embedded_file_name, + embedded_source, + )?; let mut command = Command::new(auto_runner); command.arg(runner_script).args(files); return Ok(JsRunnerPlan { cmd: command, kind }); @@ -3039,14 +3097,19 @@ fn is_deno_runner_path(runner: &Path) -> bool { .unwrap_or(false) } -fn select_js_runner_entrypoint(default_runner: &Path, runner_command: &Path) -> Result { +fn select_js_runner_entrypoint_with_source( + default_runner: &Path, + runner_command: &Path, + embedded_file_name: &str, + embedded_source: &str, +) -> Result { if is_ts_node_runner(runner_command) { - return prepare_js_runner_in_cwd(); + return prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source); } Ok(default_runner.to_path_buf()) } -fn prepare_js_runner_in_cwd() -> Result { +fn prepare_js_embedded_runner_in_cwd(file_name: &str, source: &str) -> Result { let cwd = std::env::current_dir().context("failed to resolve current working directory")?; let cache_dir = cwd .join(".bt") @@ -3058,7 +3121,8 @@ fn prepare_js_runner_in_cwd() -> Result { cache_dir.display() ) })?; - materialize_runner_script(&cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE) + materialize_runner_script(&cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(&cache_dir, file_name, source) } fn runner_bin_name(runner_command: &Path) -> Option { @@ -3208,6 +3272,14 @@ fn prepare_eval_runners() -> Result<(PathBuf, PathBuf)> { prepare_eval_runners_in_dir(&eval_runner_cache_dir()) } +fn prepare_js_data_runner() -> Result { + prepare_js_data_runner_in_dir(&eval_runner_cache_dir()) +} + +fn prepare_py_data_runner() -> Result { + prepare_py_data_runner_in_dir(&eval_runner_cache_dir()) +} + fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { std::fs::create_dir_all(cache_dir).with_context(|| { format!( @@ -3216,11 +3288,35 @@ fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { ) })?; + materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?; let js_runner = materialize_runner_script(cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE)?; let py_runner = materialize_runner_script(cache_dir, PY_RUNNER_FILE, PY_RUNNER_SOURCE)?; Ok((js_runner, py_runner)) } +fn prepare_js_data_runner_in_dir(cache_dir: &Path) -> Result { + std::fs::create_dir_all(cache_dir).with_context(|| { + format!( + "failed to create eval runner cache dir {}", + cache_dir.display() + ) + })?; + materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, JS_DATA_RUNNER_FILE, JS_DATA_RUNNER_SOURCE) +} + +fn prepare_py_data_runner_in_dir(cache_dir: &Path) -> Result { + std::fs::create_dir_all(cache_dir).with_context(|| { + format!( + "failed to create eval runner cache dir {}", + cache_dir.display() + ) + })?; + materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, PY_DATA_RUNNER_FILE, PY_DATA_RUNNER_SOURCE) +} + fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Result { let path = cache_dir.join(file_name); let current = std::fs::read_to_string(&path).ok(); @@ -4324,11 +4420,27 @@ mod tests { let dir = make_temp_dir("embedded"); let (js_runner, py_runner) = prepare_eval_runners_in_dir(&dir).expect("embedded runners should be materialized"); + let js_data_runner = prepare_js_data_runner_in_dir(&dir) + .expect("embedded data runner should be materialized"); + let py_data_runner = prepare_py_data_runner_in_dir(&dir) + .expect("embedded py data runner should be materialized"); let js = fs::read_to_string(js_runner).expect("js runner should be readable"); + let js_data = + fs::read_to_string(js_data_runner).expect("js data runner should be readable"); + let js_common = fs::read_to_string(dir.join(JS_RUNNER_COMMON_FILE)) + .expect("js common should be readable"); let py = fs::read_to_string(py_runner).expect("python runner should be readable"); + let py_data = + fs::read_to_string(py_data_runner).expect("python data runner should be readable"); + let py_common = fs::read_to_string(dir.join(PY_RUNNER_COMMON_FILE)) + .expect("python common should be readable"); assert_eq!(js, JS_RUNNER_SOURCE); + assert_eq!(js_data, JS_DATA_RUNNER_SOURCE); + assert_eq!(js_common, JS_RUNNER_COMMON_SOURCE); assert_eq!(py, PY_RUNNER_SOURCE); + assert_eq!(py_data, PY_DATA_RUNNER_SOURCE); + assert_eq!(py_common, PY_RUNNER_COMMON_SOURCE); let _ = fs::remove_dir_all(&dir); } diff --git a/tests/eval_fixtures.rs b/tests/eval_fixtures.rs index 1e1d5ba..b64485f 100644 --- a/tests/eval_fixtures.rs +++ b/tests/eval_fixtures.rs @@ -4,6 +4,8 @@ use std::fs; use std::io::Write; use std::io::{BufRead, BufReader, Read}; #[cfg(unix)] +use std::os::unix::fs::symlink; +#[cfg(unix)] use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::process::{Child, Command, Stdio}; @@ -15,6 +17,8 @@ use serde::Deserialize; #[cfg(unix)] use serde_json::json; use serde_json::Value; +#[cfg(unix)] +use tempfile::tempdir; #[derive(Debug, Deserialize, Clone)] struct FixtureConfig { @@ -463,9 +467,20 @@ fn eval_runner_rows_mode_streams_js_rows_and_trials() { .join("eval-ts-cjs"); ensure_dependencies(&fixture_dir); let runner = local_tsx_path(&fixture_dir).expect("resolve tsx runner"); - let runner_script = root.join("scripts").join("eval-runner.ts"); - let fixture_name = format!("sandbox_rows_{}.eval.ts", unique_test_suffix()); - let fixture_path = fixture_dir.join(&fixture_name); + let runner_script = root.join("scripts").join("data-runner.ts"); + let temp_fixture_dir = tempdir().expect("create js rows tempdir"); + fs::copy( + fixture_dir.join("package.json"), + temp_fixture_dir.path().join("package.json"), + ) + .expect("copy js fixture package.json"); + symlink( + fixture_dir.join("node_modules"), + temp_fixture_dir.path().join("node_modules"), + ) + .expect("symlink js fixture node_modules"); + let fixture_name = "sandbox_rows.eval.ts"; + let fixture_path = temp_fixture_dir.path().join(fixture_name); let fixture_source = r#"import { Eval } from "braintrust"; async function* rows() { @@ -475,10 +490,9 @@ async function* rows() { Eval("sandbox-rows-js", { data: rows, - task: async (input: { case_id: string }) => - input.case_id === "row-1" ? "alpha" : "bravo", + task: async (input) => (input.case_id === "row-1" ? "alpha" : "bravo"), scores: [ - ({ output, expected }: { output: string; expected?: string }) => ({ + ({ output, expected }) => ({ name: "match", score: output === expected ? 1 : 0, }), @@ -497,17 +511,8 @@ Eval("sandbox-rows-js", { let mut child = Command::new(&runner) .arg(&runner_script) - .arg(&fixture_name) - .current_dir(&fixture_dir) - .env("BT_EVAL_DEV_MODE", "rows") - .env( - "BT_EVAL_DEV_REQUEST_JSON", - json!({ - "name": "sandbox-rows-js", - "parameters": {}, - }) - .to_string(), - ) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) .env("BT_EVAL_PULL_SOCK", &socket_path) .env("BT_EVAL_LOCAL", "1") .env("BT_EVAL_NO_SEND_LOGS", "1") @@ -521,6 +526,13 @@ Eval("sandbox-rows-js", { let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); let mut writer = stream; + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-js", + }), + ); let ready = read_pull_message(&mut reader); assert_eq!(ready["type"], "ready"); assert_eq!(ready["evaluator_name"], "sandbox-rows-js"); @@ -572,7 +584,6 @@ Eval("sandbox-rows-js", { } let _ = fs::remove_file(&socket_path); - let _ = fs::remove_file(&fixture_path); } #[cfg(unix)] @@ -581,7 +592,6 @@ fn eval_runner_rows_mode_streams_python_rows_and_trials() { let _guard = test_lock(); let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let fixtures_root = root.join("tests").join("evals"); - let fixture_dir = fixtures_root.join("py").join("local_import"); let python = match ensure_python_env(&fixtures_root.join("py")) { Some(python) => python, None => { @@ -595,9 +605,10 @@ fn eval_runner_rows_mode_streams_python_rows_and_trials() { } }; - let runner_script = root.join("scripts").join("eval-runner.py"); - let fixture_name = format!("sandbox_rows_{}.py", unique_test_suffix()); - let fixture_path = fixture_dir.join(&fixture_name); + let runner_script = root.join("scripts").join("data-runner.py"); + let temp_fixture_dir = tempdir().expect("create python rows tempdir"); + let fixture_name = "sandbox_rows.py"; + let fixture_path = temp_fixture_dir.path().join(fixture_name); let fixture_source = r#"from braintrust import Eval def rows(): @@ -629,17 +640,8 @@ Eval( let mut child = Command::new(&python) .arg(&runner_script) - .arg(&fixture_name) - .current_dir(&fixture_dir) - .env("BT_EVAL_DEV_MODE", "rows") - .env( - "BT_EVAL_DEV_REQUEST_JSON", - json!({ - "name": "sandbox-rows-py", - "parameters": {}, - }) - .to_string(), - ) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) .env("BT_EVAL_PULL_SOCK", &socket_path) .env("BT_EVAL_LOCAL", "1") .env("BT_EVAL_NO_SEND_LOGS", "1") @@ -653,6 +655,13 @@ Eval( let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); let mut writer = stream; + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-py", + }), + ); let ready = read_pull_message(&mut reader); assert_eq!(ready["type"], "ready"); assert_eq!(ready["evaluator_name"], "sandbox-rows-py"); @@ -706,7 +715,6 @@ Eval( } let _ = fs::remove_file(&socket_path); - let _ = fs::remove_file(&fixture_path); } fn read_fixture_config(path: &Path) -> FixtureConfig {