diff --git a/scripts/data-runner.py b/scripts/data-runner.py new file mode 100644 index 0000000..35de3d5 --- /dev/null +++ b/scripts/data-runner.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import socket +import sys +import time +from dataclasses import dataclass +from typing import Any + +try: + from braintrust.util import eprint + from runner_common import call_evaluator_data, load_evaluators, to_async_iterator +except Exception as exc: # pragma: no cover - runtime guard + print( + "Unable to import the braintrust package. Please install it in your Python environment.", + file=sys.stderr, + ) + print(str(exc), file=sys.stderr) + sys.exit(1) + + +@dataclass +class PullChannel: + sock: socket.socket + + def send(self, payload: Any) -> None: + self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + async def lines(self): + buffer = "" + while True: + chunk = await asyncio.to_thread(self.sock.recv, 4096) + if not chunk: + break + buffer += chunk.decode("utf-8") + while True: + newline = buffer.find("\n") + if newline == -1: + break + line = buffer[:newline].strip() + buffer = buffer[newline + 1 :] + if line: + yield line + + trailing = buffer.strip() + if trailing: + yield trailing + + def close(self) -> None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self.sock.close() + + +def create_pull_channel() -> PullChannel: + sock_path = os.getenv("BT_EVAL_PULL_SOCK") + if not sock_path: + raise ValueError("Missing BT_EVAL_PULL_SOCK") + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return PullChannel(sock) + + +def parse_start_request(raw: str) -> str: + parsed = json.loads(raw) + if not isinstance(parsed, dict): + raise ValueError("Start request must be a JSON object.") + if parsed.get("type") != "start": + raise ValueError("Expected initial start command.") + name = parsed.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Start request must include a non-empty evaluator name.") + return name + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Stream eval rows over a unix socket for bt.") + parser.add_argument("files", nargs="*", help="Eval files or directories to load.") + return parser + + +async def run(files: list[str]) -> int: + evaluators, _reporters = load_evaluators(files) + channel = create_pull_channel() + + try: + line_iter = channel.lines() + try: + start_line = await anext(line_iter) + except StopAsyncIteration: + return 0 + + try: + target_name = parse_start_request(start_line) + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + return 1 + + evaluator_instance = next( + (candidate for candidate in evaluators if candidate.evaluator.eval_name == target_name), + None, + ) + if evaluator_instance is None: + channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) + return 1 + + evaluator = evaluator_instance.evaluator + raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) + data_iterator = to_async_iterator(raw_data) + iterator = data_iterator.__aiter__() + + trial_count = getattr(evaluator, "trial_count", 1) + try: + trial_count = int(trial_count) + except Exception: + trial_count = 1 + if trial_count < 1: + trial_count = 1 + + max_concurrency = getattr(evaluator, "max_concurrency", None) + try: + max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 + except Exception: + max_concurrency = 10 + if max_concurrency < 1: + max_concurrency = 1 + + experiment_name = getattr(evaluator, "experiment_name", None) + if not isinstance(experiment_name, str) or not experiment_name: + experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" + + channel.send( + { + "type": "ready", + "evaluator_name": evaluator.eval_name, + "max_concurrency": max_concurrency, + "experiment_name": experiment_name, + } + ) + + current_datum = None + trial_index = 0 + async for line in line_iter: + parsed = json.loads(line) + command_type = parsed.get("type") if isinstance(parsed, dict) else None + if command_type == "close": + break + if command_type != "next": + channel.send( + { + "type": "error", + "message": f"Unsupported pull command '{command_type}'", + } + ) + return 1 + + if current_datum is None: + try: + current_datum = await iterator.__anext__() + trial_index = 0 + except StopAsyncIteration: + channel.send({"type": "eof"}) + continue + + channel.send( + { + "type": "row", + "datum": current_datum, + "trial_index": trial_index, + } + ) + trial_index += 1 + if trial_index >= trial_count: + current_datum = None + + return 0 + finally: + channel.close() + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + files = args.files or ["."] + + try: + return asyncio.run(run(files)) + except Exception as exc: + eprint(str(exc)) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/data-runner.ts b/scripts/data-runner.ts new file mode 100644 index 0000000..53cb81e --- /dev/null +++ b/scripts/data-runner.ts @@ -0,0 +1,201 @@ +import net from "node:net"; +import readline from "node:readline"; + +import { + callEvaluatorData, + formatError, + getBraintrustStateGetter, + getEvaluators, + initRegistry, + loadBraintrust, + loadFiles, + normalizeFiles, + propagateInheritedBraintrustState, + toAsyncIterable, +} from "./runner-common"; + +type StartMessage = { + type: "start"; + name: string; +}; + +type ClientMessage = + | StartMessage + | { type: "next" } + | { type: "close" }; + +type ServerMessage = + | { + type: "ready"; + evaluator_name: string; + max_concurrency: number; + experiment_name: string; + } + | { type: "row"; datum: unknown; trial_index: number } + | { type: "eof" } + | { type: "error"; message: string }; + +function writeMessage(socket: net.Socket, message: ServerMessage) { + socket.write(`${JSON.stringify(message)}\n`); +} + +function parseMessage(line: string): ClientMessage { + const parsed = JSON.parse(line) as { type?: unknown; name?: unknown }; + if (parsed.type === "start") { + if (typeof parsed.name !== "string" || parsed.name.length === 0) { + throw new Error("Start request must include a non-empty evaluator name."); + } + return { type: "start", name: parsed.name }; + } + if (parsed.type === "next" || parsed.type === "close") { + return { type: parsed.type }; + } + throw new Error(`Unsupported pull command '${String(parsed.type)}'`); +} + +async function readMessage( + lines: AsyncIterator, +): Promise { + const next = await lines.next(); + if (next.done) { + return null; + } + return parseMessage(next.value); +} + +function applyExtraArgsFromEnv() { + const extraArgs: string[] = process.env.BT_EVAL_EXTRA_ARGS_JSON + ? (JSON.parse(process.env.BT_EVAL_EXTRA_ARGS_JSON) as string[]) + : []; + process.argv = [...process.argv.slice(0, 2), ...extraArgs]; +} + +function toPositiveInteger(value: unknown, fallback: number): number { + const parsed = Number(value); + if (Number.isFinite(parsed) && parsed > 0) { + return Math.floor(parsed); + } + return fallback; +} + +async function main() { + const files = process.argv.slice(2); + if (files.length === 0) { + throw new Error("No eval files provided."); + } + const socketPath = process.env.BT_EVAL_PULL_SOCK; + if (!socketPath) { + throw new Error("Missing BT_EVAL_PULL_SOCK"); + } + + const normalized = normalizeFiles(files); + const braintrust = await loadBraintrust(normalized); + propagateInheritedBraintrustState(braintrust); + initRegistry(); + applyExtraArgsFromEnv(); + await loadFiles(normalized); + + const socket = net.createConnection({ path: socketPath }); + const socketReady = new Promise((resolve, reject) => { + socket.once("connect", resolve); + socket.once("error", reject); + }); + await socketReady; + + const reader = readline.createInterface({ + input: socket, + crlfDelay: Infinity, + }); + const lines = reader[Symbol.asyncIterator](); + + try { + const start = await readMessage(lines); + if (!start) { + return; + } + if (start.type !== "start") { + throw new Error("Expected initial start command."); + } + + const entry = getEvaluators().find( + (candidate) => candidate.evaluator.evalName === start.name, + ); + if (!entry) { + writeMessage(socket, { + type: "error", + message: `Evaluator '${start.name}' not found`, + }); + return; + } + + const getState = getBraintrustStateGetter(braintrust); + const state = getState ? getState() : undefined; + const evaluator = { + ...entry.evaluator, + ...(state !== undefined && state !== null ? { state } : {}), + }; + const { data: rawData } = callEvaluatorData(evaluator.data); + const dataIterable = toAsyncIterable(rawData); + const iterator = dataIterable[Symbol.asyncIterator](); + const trialCount = toPositiveInteger(evaluator.trialCount, 1); + const maxConcurrency = toPositiveInteger(evaluator.maxConcurrency, 10); + const experimentName = + typeof evaluator.experimentName === "string" && + evaluator.experimentName.length > 0 + ? evaluator.experimentName + : `${entry.evaluator.evalName}-${Date.now()}`; + + writeMessage(socket, { + type: "ready", + evaluator_name: entry.evaluator.evalName, + max_concurrency: maxConcurrency, + experiment_name: experimentName, + }); + + let currentDatum: unknown | undefined; + let trialIndex = 0; + while (true) { + const message = await readMessage(lines); + if (!message || message.type === "close") { + return; + } + if (message.type !== "next") { + throw new Error(`Unsupported pull command '${message.type}'`); + } + + if (currentDatum === undefined) { + const next = await iterator.next(); + if (next.done) { + writeMessage(socket, { type: "eof" }); + continue; + } + currentDatum = next.value; + trialIndex = 0; + } + + writeMessage(socket, { + type: "row", + datum: currentDatum, + trial_index: trialIndex, + }); + + trialIndex += 1; + if (trialIndex >= trialCount) { + currentDatum = undefined; + } + } + } catch (err) { + writeMessage(socket, { + type: "error", + message: formatError(err), + }); + } finally { + reader.close(); + socket.end(); + } +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index e9d3ce0..776aa1c 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -1,12 +1,9 @@ #!/usr/bin/env python3 import argparse import asyncio -import fnmatch -import importlib.util import inspect import json import os -import re import socket import sys import traceback @@ -17,9 +14,6 @@ from braintrust import init_dataset, invoke, login from braintrust.framework import ( BaseExperiment, - EvaluatorInstance, - _evals, - _set_lazy_load, run_evaluator, set_thread_pool_max_workers, ) @@ -27,6 +21,15 @@ from braintrust.parameters import parameters_to_json_schema, validate_parameters from braintrust.util import eprint from braintrust.span_identifier_v4 import parse_parent + from runner_common import ( + EvalFilter, + EvaluatorInstance, + call_evaluator_data, + env_flag, + filter_evaluators, + load_evaluators, + parse_serialized_filters, + ) except Exception as exc: # pragma: no cover - runtime guard print( "Unable to import the braintrust package. Please install it in your Python environment.", @@ -44,14 +47,6 @@ "/venv/", ) _DATASET_TOTAL_CACHE: dict[str, int] = {} - - -@dataclass(frozen=True) -class EvalFilter: - path: list[str] - pattern: re.Pattern[str] - - @dataclass(frozen=True) class RunnerConfig: jsonl: bool @@ -103,37 +98,6 @@ def create_sse_writer() -> SseWriter | None: return SseWriter(sock) return None - - -def env_flag(name: str) -> bool: - value = os.getenv(name) - if value is None: - return False - return value.lower() not in {"0", "false", "no", "off", ""} - - -def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: - if not serialized: - return [] - - parsed = json.loads(serialized) - if not isinstance(parsed, list): - raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array") - - filters: list[EvalFilter] = [] - for i, entry in enumerate(parsed): - if not isinstance(entry, dict): - raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}") - key_path = entry.get("path") - pattern = entry.get("pattern") - if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings") - if not isinstance(pattern, str): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string") - filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern))) - return filters - - def parse_dev_mode(value: str | None) -> str | None: if value is None or value == "": return None @@ -155,46 +119,6 @@ def read_runner_config() -> RunnerConfig: dev_request_json=os.getenv("BT_EVAL_DEV_REQUEST_JSON"), ) - -def _to_mapping(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_mapping(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_mapping(v) for v in value] - if hasattr(value, "__dict__"): - return { - key: _to_mapping(val) - for key, val in vars(value).items() - if not key.startswith("_") - } - return value - - -def serialize_json_with_plain_string(value: Any) -> str: - if isinstance(value, str): - return value - return json.dumps(value) - - -def evaluate_filter(value: Any, filt: EvalFilter) -> bool: - current = _to_mapping(value) - for part in filt.path: - if not isinstance(current, dict) or part not in current: - return False - current = current[part] - return bool(filt.pattern.search(serialize_json_with_plain_string(current))) - - -def filter_evaluators(evaluators: list[EvaluatorInstance], filters: list[EvalFilter]) -> list[EvaluatorInstance]: - if not filters: - return evaluators - return [ - evaluator - for evaluator in evaluators - if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) - ] - - def snake_to_camel(value: str) -> str: parts = value.split("_") if not parts: @@ -364,16 +288,6 @@ def build_eval_definitions(evaluator_instances: list[EvaluatorInstance]) -> dict return definitions -def collect_files(input_path: str) -> list[str]: - if os.path.isdir(input_path): - matches: list[str] = [] - for root, _, files in os.walk(input_path): - for filename in files: - matches.append(os.path.join(root, filename)) - return matches - return [input_path] - - def is_watchable_dependency(path_input: str, cwd: str) -> bool: path = os.path.abspath(path_input) normalized = path.replace("\\", "/") @@ -409,86 +323,6 @@ def collect_dependency_files(cwd: str, input_files: list[str]) -> list[str]: return sorted(dependencies) -def resolve_module_info(in_file: str) -> tuple[str, list[str]]: - in_file = os.path.abspath(in_file) - module_dir = os.path.dirname(in_file) - module_name = os.path.splitext(os.path.basename(in_file))[0] - - package_parts: list[str] = [] - current = module_dir - while os.path.isfile(os.path.join(current, "__init__.py")): - package_parts.insert(0, os.path.basename(current)) - current = os.path.dirname(current) - - extra_paths = [module_dir] - if package_parts: - module_name = ".".join(package_parts + [module_name]) - if current not in extra_paths: - extra_paths.append(current) - - return module_name, extra_paths - - -def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: - evaluator_instances: list[EvaluatorInstance] = [] - reporters: dict[str, Any] = {} - cwd = os.getcwd() - if cwd not in sys.path: - sys.path.insert(0, cwd) - - # Add the project root inferred from input files to sys.path so that - # sibling-package imports work when files live outside CWD (e.g. - # sandbox bundles extracted to a temp directory). Walk up from each - # file's directory looking for a register.py (bundle marker) or the - # filesystem root, whichever comes first. - for f in files: - d = os.path.dirname(os.path.abspath(f)) - while d and d != os.path.dirname(d): - if os.path.isfile(os.path.join(d, "register.py")): - if d not in sys.path: - sys.path.insert(0, d) - break - d = os.path.dirname(d) - - unique_files: set[str] = set() - for file_path in files: - for candidate in collect_files(file_path): - unique_files.add(os.path.abspath(candidate)) - - for file_path in sorted(unique_files): - module_name, extra_paths = resolve_module_info(file_path) - with _set_lazy_load(True): - _evals.clear() - try: - for extra_path in reversed(extra_paths): - if extra_path not in sys.path: - sys.path.insert(0, extra_path) - - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None or spec.loader is None: - raise ImportError(f"Unable to load module spec for {file_path}") - - sys.modules.pop(module_name, None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - - evaluator_instances.extend( - [ - instance - for instance in _evals.evaluators.values() - if isinstance(instance, EvaluatorInstance) - ] - ) - for reporter_name, reporter in _evals.reporters.items(): - if reporter_name not in reporters: - reporters[reporter_name] = reporter - finally: - _evals.clear() - - return evaluator_instances, reporters - - def resolve_reporter( reporter: Any, reporters: dict[str, Any], @@ -872,7 +706,6 @@ async def run_once( return True if config.dev_mode == "eval": return await run_requested_eval(evaluators, reporters, no_send_logs, sse, config) - if config.list_only: for evaluator_instance in evaluators: print(evaluator_instance.evaluator.eval_name) diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index 2a19c10..40f4b86 100644 --- a/scripts/eval-runner.ts +++ b/scripts/eval-runner.ts @@ -2,13 +2,27 @@ import { createRequire } from "node:module"; import path from "node:path"; import { fileURLToPath, pathToFileURL } from "node:url"; -type EvaluatorEntry = { - evaluator: { - evalName: string; - projectName: string; - } & Record; - reporter?: unknown; -}; +import { + envFlag, + filterEvaluators, + formatError, + getBraintrustStateGetter, + getEvaluators, + getReporters, + initRegistry, + isObject, + loadBraintrust, + loadBraintrustUtilParseParent, + loadFiles, + normalizeFiles, + parseSerializedFilters, + propagateInheritedBraintrustState, + resolveBraintrustPath, + type EvalFilter, + type EvaluatorEntry, + type GlobalEvals, + type ParseParentFunction, +} from "./runner-common"; type EvalResult = { results: Array<{ error?: unknown }>; @@ -43,23 +57,6 @@ type InitDatasetFunction = ( ) => unknown; type InvokeFunction = (options: Record) => Promise; -type BraintrustModule = { - Eval?: EvalFunction; - login?: LoginFunction; - initDataset?: InitDatasetFunction; - invoke?: InvokeFunction; - _internalGetGlobalState?: () => unknown; - default?: BraintrustModule; -}; - -type GlobalEvals = { - functions: unknown[]; - prompts: unknown[]; - parameters: unknown[]; - evaluators: Record; - reporters: Record; -}; - type BtEvalMain = (context: BtEvalContext) => void | Promise; type BtEvalContext = { @@ -83,16 +80,6 @@ type SseWriter = { close: () => void; }; -type EvalFilter = { - path: string[]; - pattern: RegExp; -}; - -type SerializedEvalFilter = { - path: string[]; - pattern: string; -}; - type EvalScoreSpec = { name: string; function_id: Record; @@ -145,8 +132,6 @@ type EvalRunner = { type ParameterContainerSerializer = (parameters: unknown) => unknown; type PromptDefinitionSerializer = (prompt: unknown) => unknown; type ZodSchemaSerializer = (schema: unknown) => Record; -type ParseParentFunction = (parent: unknown) => string | undefined; - type ParameterSerializationHelpers = { sdkSerializeParameters: ParameterContainerSerializer | null; promptDefinitionToPromptData: PromptDefinitionSerializer | null; @@ -162,88 +147,9 @@ declare global { var __inherited_braintrust_state: unknown; } -function isObject(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} - -function isBraintrustModule(value: unknown): value is BraintrustModule { - return isObject(value) && ("Eval" in value || "login" in value); -} - -function normalizeBraintrustModule(value: unknown): BraintrustModule { - if (isBraintrustModule(value)) { - return value; - } - if (isObject(value) && isBraintrustModule(value.default)) { - return value.default; - } - throw new Error("Unable to load braintrust module."); -} - -function normalizeFiles(files: string[]): string[] { - return files.map((file) => path.resolve(process.cwd(), file)); -} - -function envFlag(name: string): boolean { - const value = process.env[name]; - if (!value) { - return false; - } - const normalized = value.toLowerCase(); - return !["0", "false", "no", "off", ""].includes(normalized); -} - -function serializeJSONWithPlainString(value: unknown): string { - if (typeof value === "string") { - return value; - } - return JSON.stringify(value); -} - -function parseSerializedFilters(serialized: string | undefined): EvalFilter[] { - if (!serialized) { - return []; - } - - try { - const parsed = JSON.parse(serialized); - if (!Array.isArray(parsed)) { - throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array."); - } - return parsed.map((value) => { - if (!isObject(value)) { - throw new Error( - "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.", - ); - } - const { path: rawPath, pattern: rawPattern } = - value as SerializedEvalFilter; - if ( - !Array.isArray(rawPath) || - !rawPath.every((part) => typeof part === "string") - ) { - throw new Error( - "BT_EVAL_FILTER_PARSED entry path must be an array of strings.", - ); - } - if (typeof rawPattern !== "string") { - throw new Error( - "BT_EVAL_FILTER_PARSED entry pattern must be a string.", - ); - } - return { - path: rawPath, - pattern: new RegExp(rawPattern), - }; - }); - } catch (err) { - throw new Error( - `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`, - ); - } -} - -function parseDevMode(value: string | undefined): "list" | "eval" | null { +function parseDevMode( + value: string | undefined, +): "list" | "eval" | null { if (!value) { return null; } @@ -283,6 +189,7 @@ type NetModule = { setNoDelay: (value?: boolean) => void; on: (event: string, listener: (...args: unknown[]) => void) => void; write: (data: string) => void; + [Symbol.asyncIterator]?: () => AsyncIterator; }; }; @@ -766,65 +673,6 @@ function createSseWriter(): SseWriter | null { return { send, close }; } -function initRegistry() { - globalThis._evals = { - functions: [], - prompts: [], - parameters: [], - evaluators: {}, - reporters: {}, - }; - globalThis._lazy_load = true; -} - -function ensureBraintrustAvailable() { - resolveBraintrustPath(); -} - -function resolveBraintrustPath(): string { - const files = normalizeFiles(process.argv.slice(2)); - for (const file of files) { - try { - const require = createRequire(pathToFileURL(file).href); - return require.resolve("braintrust"); - } catch { - continue; - } - } - - try { - const require = createRequire(process.cwd() + "/"); - return require.resolve("braintrust"); - } catch { - const message = - "Unable to resolve the `braintrust` package. " + - "Please install it in your project (e.g. `pnpm add braintrust` or `npm install braintrust`)."; - throw new Error(message); - } -} - -async function loadBraintrust() { - const cjsPath = resolveBraintrustPath(); - const cjsUrl = pathToFileURL(cjsPath).href; - - try { - const mod: unknown = await import(cjsUrl); - return normalizeBraintrustModule(mod); - } catch {} - - const esmPath = cjsPath.replace(/\.js$/, ".mjs"); - if (esmPath !== cjsPath && fsMutable.existsSync(esmPath)) { - try { - const mod: unknown = await import(pathToFileURL(esmPath).href); - return normalizeBraintrustModule(mod); - } catch {} - } - - const require = createRequire(cjsUrl); - const mod: unknown = require(cjsPath); - return normalizeBraintrustModule(mod); -} - function extractParameterSerializer( mod: unknown, ): ParameterContainerSerializer | null { @@ -964,7 +812,7 @@ function loadZodSchemaSerializer( } async function loadParameterSerializationHelpers(): Promise { - const braintrustPath = resolveBraintrustPath(); + const braintrustPath = resolveBraintrustPath(process.argv.slice(2)); const zodToJsonSchema = loadZodSchemaSerializer(braintrustPath); try { const mod: unknown = await import(pathToFileURL(braintrustPath).href); @@ -982,169 +830,6 @@ async function loadParameterSerializationHelpers(): Promise unknown) | null { - if (!isObject(mod)) { - return null; - } - const candidate = Reflect.get(mod, "_internalGetGlobalState"); - if (typeof candidate === "function") { - return candidate as () => unknown; - } - const defaultExport = Reflect.get(mod, "default"); - if (isObject(defaultExport)) { - const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState"); - if (typeof fromDefault === "function") { - return fromDefault as () => unknown; - } - } - return null; -} - -function loadBraintrustUtilParseParent(): ParseParentFunction | null { - const braintrustPath = resolveBraintrustPath(); - const requireFromBraintrust = createRequire( - pathToFileURL(braintrustPath).href, - ); - try { - const utilMod: unknown = requireFromBraintrust("braintrust/util"); - return extractParseParent(utilMod); - } catch { - return null; - } -} - -function propagateInheritedBraintrustState(braintrust: BraintrustModule) { - const getter = (braintrust as Record) - ._internalGetGlobalState; - if (typeof getter !== "function") { - return; - } - const state = getter(); - if (state !== undefined && state !== null) { - globalThis.__inherited_braintrust_state = state; - } -} - -async function loadFiles(files: string[]): Promise { - const modules: unknown[] = []; - // Internal CLI-controlled flag for ESM retry; not user-facing config. - const forceEsm = envFlag("BT_EVAL_FORCE_ESM"); - // vite-node installs transform hooks that handle TypeScript (including - // extension-less imports) and CJS named-export interop natively. A failed - // require() corrupts Node's module cache and causes the subsequent import() - // to hit the "module imported again after being required" bug, so we skip - // require() for .ts/.tsx files entirely when running under vite-node. - const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node"; - for (const file of files) { - const fileUrl = pathToFileURL(file).href; - const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx"); - const preferRequire = - !forceEsm && - !(isViteNode && isTypeScript) && - (isTypeScript || file.endsWith(".cjs")); - - if (preferRequire) { - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (esmErr) { - throw new Error( - `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`, - ); - } - } - } - - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (err) { - if (!shouldTryRequire(file, err)) { - throw err; - } - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - throw new Error( - `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`, - ); - } - } - } - return modules; -} - -function shouldTryRequire(file: string, err: unknown): boolean { - if (envFlag("BT_EVAL_FORCE_ESM")) { - return false; - } - if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") { - return false; - } - if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) { - return true; - } - if ( - (file.endsWith(".ts") || file.endsWith(".tsx")) && - isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION") - ) { - return true; - } - if (!(err instanceof Error)) { - return false; - } - const message = err.message || ""; - return ( - message.includes("require is not defined") || - message.includes("exports is not defined") || - message.includes("module is not defined") || - message.includes("Cannot use import statement outside a module") - ); -} - -function isNodeErrorCode(err: unknown, code: string): boolean { - if (!isObject(err) || !("code" in err)) { - return false; - } - return typeof err.code === "string" && err.code === code; -} - -function formatError(err: unknown): string { - if (err instanceof Error) { - return err.message; - } - return String(err); -} - function createEvalProgressReporter( sse: SseWriter | null, evaluatorName: string, @@ -1216,22 +901,6 @@ function sendConsole( sse.send("console", { stream, message }); } -function getEvaluators(): EvaluatorEntry[] { - const evals = globalThis._evals; - if (!evals || !evals.evaluators) { - return []; - } - return Object.values(evals.evaluators) as EvaluatorEntry[]; -} - -function getReporters(): Record { - const evals = globalThis._evals; - if (!evals || !evals.reporters) { - return {}; - } - return evals.reporters as Record; -} - function resolveReporter( reporter: unknown, reporters: Record, @@ -1259,34 +928,6 @@ function resolveReporter( ); } -function evaluateFilter( - object: Record, - filter: EvalFilter, -): boolean { - const key = filter.path.reduce((acc, part) => { - if (!isObject(acc)) { - return undefined; - } - return acc[part]; - }, object); - if (key === undefined) { - return false; - } - return filter.pattern.test(serializeJSONWithPlainString(key)); -} - -function filterEvaluators( - evaluators: EvaluatorEntry[], - filters: EvalFilter[], -): EvaluatorEntry[] { - if (filters.length === 0) { - return evaluators; - } - return evaluators.filter((entry) => - filters.every((filter) => evaluateFilter(entry.evaluator, filter)), - ); -} - function extractScoreName(score: unknown, idx: number): string { if (typeof score === "function" && typeof score.name === "string") { return score.name || `scorer_${idx}`; @@ -1831,9 +1472,12 @@ function mergeProgress( }; } -async function createEvalRunner(config: RunnerConfig): Promise { - const braintrust = await loadBraintrust(); - const Eval = braintrust.Eval; +async function createEvalRunner( + config: RunnerConfig, + files: string[], +): Promise { + const braintrust = await loadBraintrust(files); + const Eval = braintrust.Eval as EvalFunction | undefined; if (typeof Eval !== "function") { throw new Error("Unable to load Eval() from braintrust package."); } @@ -1843,8 +1487,8 @@ async function createEvalRunner(config: RunnerConfig): Promise { const sse = createSseWriter(); const noSendLogs = shouldDisableSendLogs(); - const parseParent = loadBraintrustUtilParseParent(); - const getState = extractGlobalStateGetter(braintrust); + const parseParent = loadBraintrustUtilParseParent(files); + const getState = getBraintrustStateGetter(braintrust); const makeEvalOptions = ( evaluatorName: string, @@ -1996,8 +1640,7 @@ async function main() { maybeRecordDependency(file); } collectStaticLocalDependencies(normalized); - ensureBraintrustAvailable(); - const braintrust = await loadBraintrust(); + const braintrust = await loadBraintrust(normalized); propagateInheritedBraintrustState(braintrust); initRegistry(); // Replace process.argv with [runtime, script, ...extraArgs] so that user @@ -2010,7 +1653,7 @@ async function main() { const modules = await loadFiles(normalized); const btEvalMains = collectBtEvalMains(modules); - const runner = await createEvalRunner(config); + const runner = await createEvalRunner(config, normalized); if (!runner.noSendLogs && typeof runner.login === "function") { try { await runner.login({}); diff --git a/scripts/runner-common.ts b/scripts/runner-common.ts index 8a0dd63..a98c612 100644 --- a/scripts/runner-common.ts +++ b/scripts/runner-common.ts @@ -1,3 +1,8 @@ +import { createRequire } from "node:module"; +import fs from "node:fs"; +import path from "node:path"; +import { pathToFileURL } from "node:url"; + export type JsonPrimitive = string | number | boolean | null; export type JsonArray = JsonValue[]; export type JsonObject = { [key: string]: JsonValue }; @@ -13,6 +18,56 @@ export type ProjectRef = { name?: string; }; +export type EvaluatorDefinition = { + evalName: string; + projectName: string; + data?: unknown; + trialCount?: unknown; + maxConcurrency?: unknown; + experimentName?: unknown; +} & Record; + +export type EvaluatorEntry = { + evaluator: EvaluatorDefinition; + reporter?: unknown; +}; + +export type BraintrustModule = { + Eval?: (...args: unknown[]) => unknown; + login?: (...args: unknown[]) => Promise; + initDataset?: (...args: unknown[]) => unknown; + invoke?: (...args: unknown[]) => Promise; + _internalGetGlobalState?: () => unknown; + default?: BraintrustModule; +}; + +export type GlobalEvals = { + functions: unknown[]; + prompts: unknown[]; + parameters: unknown[]; + evaluators: Record; + reporters: Record; +}; + +export type EvalFilter = { + path: string[]; + pattern: RegExp; +}; + +export type SerializedEvalFilter = { + path: string[]; + pattern: string; +}; + +declare global { + // eslint-disable-next-line no-var + var _evals: GlobalEvals | undefined; + // eslint-disable-next-line no-var + var _lazy_load: boolean | undefined; + // eslint-disable-next-line no-var + var __inherited_braintrust_state: unknown; +} + export function asProjectSelector( project: ProjectRef | undefined, ): ProjectSelector { @@ -81,3 +136,395 @@ export function toJsonValue(input: JsonValue): JsonValue { return input; } + +export function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +export function normalizeFiles(files: string[]): string[] { + return files.map((file) => path.resolve(process.cwd(), file)); +} + +export function envFlag(name: string): boolean { + const value = process.env[name]; + if (!value) { + return false; + } + const normalized = value.toLowerCase(); + return !["0", "false", "no", "off", ""].includes(normalized); +} + +export function serializeJSONWithPlainString(value: unknown): string { + if (typeof value === "string") { + return value; + } + return JSON.stringify(value); +} + +export function parseSerializedFilters( + serialized: string | undefined, +): EvalFilter[] { + if (!serialized) { + return []; + } + + try { + const parsed = JSON.parse(serialized); + if (!Array.isArray(parsed)) { + throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array."); + } + return parsed.map((value) => { + if (!isObject(value)) { + throw new Error( + "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.", + ); + } + const { path: rawPath, pattern: rawPattern } = + value as SerializedEvalFilter; + if ( + !Array.isArray(rawPath) || + !rawPath.every((part) => typeof part === "string") + ) { + throw new Error( + "BT_EVAL_FILTER_PARSED entry path must be an array of strings.", + ); + } + if (typeof rawPattern !== "string") { + throw new Error( + "BT_EVAL_FILTER_PARSED entry pattern must be a string.", + ); + } + return { + path: rawPath, + pattern: new RegExp(rawPattern), + }; + }); + } catch (err) { + throw new Error( + `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`, + ); + } +} + +export function formatError(err: unknown): string { + if (err instanceof Error) { + return err.message; + } + return String(err); +} + +export function initRegistry() { + globalThis._evals = { + functions: [], + prompts: [], + parameters: [], + evaluators: {}, + reporters: {}, + }; + globalThis._lazy_load = true; +} + +function isBraintrustModule(value: unknown): value is BraintrustModule { + return isObject(value) && ("Eval" in value || "login" in value); +} + +function normalizeBraintrustModule(value: unknown): BraintrustModule { + if (isBraintrustModule(value)) { + return value; + } + if (isObject(value) && isBraintrustModule(value.default)) { + return value.default; + } + throw new Error("Unable to load braintrust module."); +} + +export function resolveBraintrustPath(files: string[]): string { + const normalizedFiles = normalizeFiles(files); + for (const file of normalizedFiles) { + try { + const require = createRequire(pathToFileURL(file).href); + return require.resolve("braintrust"); + } catch { + continue; + } + } + + try { + const require = createRequire(process.cwd() + "/"); + return require.resolve("braintrust"); + } catch { + const message = + "Unable to resolve the `braintrust` package. " + + "Please install it in your project (e.g. `pnpm add braintrust` or `npm install braintrust`)."; + throw new Error(message); + } +} + +export async function loadBraintrust( + files: string[], +): Promise { + const cjsPath = resolveBraintrustPath(files); + const cjsUrl = pathToFileURL(cjsPath).href; + + try { + const mod: unknown = await import(cjsUrl); + return normalizeBraintrustModule(mod); + } catch {} + + const esmPath = cjsPath.replace(/\.js$/, ".mjs"); + if (esmPath !== cjsPath && fs.existsSync(esmPath)) { + try { + const mod: unknown = await import(pathToFileURL(esmPath).href); + return normalizeBraintrustModule(mod); + } catch {} + } + + const require = createRequire(cjsUrl); + const mod: unknown = require(cjsPath); + return normalizeBraintrustModule(mod); +} + +export type ParseParentFunction = (parent: unknown) => string | undefined; + +function extractParseParent(mod: unknown): ParseParentFunction | null { + if (!isObject(mod)) { + return null; + } + const candidate = Reflect.get(mod, "parseParent"); + if (typeof candidate === "function") { + return candidate as ParseParentFunction; + } + const defaultExport = Reflect.get(mod, "default"); + if (isObject(defaultExport)) { + const fromDefault = Reflect.get(defaultExport, "parseParent"); + if (typeof fromDefault === "function") { + return fromDefault as ParseParentFunction; + } + } + return null; +} + +export function loadBraintrustUtilParseParent( + files: string[], +): ParseParentFunction | null { + const braintrustPath = resolveBraintrustPath(files); + const requireFromBraintrust = createRequire( + pathToFileURL(braintrustPath).href, + ); + try { + const utilMod: unknown = requireFromBraintrust("braintrust/util"); + return extractParseParent(utilMod); + } catch { + return null; + } +} + +function extractGlobalStateGetter(mod: unknown): (() => unknown) | null { + if (!isObject(mod)) { + return null; + } + const candidate = Reflect.get(mod, "_internalGetGlobalState"); + if (typeof candidate === "function") { + return candidate as () => unknown; + } + const defaultExport = Reflect.get(mod, "default"); + if (isObject(defaultExport)) { + const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState"); + if (typeof fromDefault === "function") { + return fromDefault as () => unknown; + } + } + return null; +} + +export function getBraintrustStateGetter( + braintrust: BraintrustModule, +): (() => unknown) | null { + return extractGlobalStateGetter(braintrust); +} + +export function propagateInheritedBraintrustState(braintrust: BraintrustModule) { + const getter = getBraintrustStateGetter(braintrust); + if (!getter) { + return; + } + const state = getter(); + if (state !== undefined && state !== null) { + globalThis.__inherited_braintrust_state = state; + } +} + +export async function loadFiles(files: string[]): Promise { + const modules: unknown[] = []; + const forceEsm = envFlag("BT_EVAL_FORCE_ESM"); + const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node"; + for (const file of files) { + const fileUrl = pathToFileURL(file).href; + const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx"); + const preferRequire = + !forceEsm && + !(isViteNode && isTypeScript) && + (isTypeScript || file.endsWith(".cjs")); + + if (preferRequire) { + try { + const require = createRequire(fileUrl); + const mod = require(file); + modules.push(mod); + continue; + } catch (requireErr) { + try { + const mod = await import(fileUrl); + modules.push(mod); + continue; + } catch (esmErr) { + throw new Error( + `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`, + ); + } + } + } + + try { + const mod = await import(fileUrl); + modules.push(mod); + continue; + } catch (err) { + if (!shouldTryRequire(file, err)) { + throw err; + } + try { + const require = createRequire(fileUrl); + const mod = require(file); + modules.push(mod); + continue; + } catch (requireErr) { + throw new Error( + `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`, + ); + } + } + } + return modules; +} + +function shouldTryRequire(file: string, err: unknown): boolean { + if (envFlag("BT_EVAL_FORCE_ESM")) { + return false; + } + if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") { + return false; + } + if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) { + return true; + } + if ( + (file.endsWith(".ts") || file.endsWith(".tsx")) && + isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION") + ) { + return true; + } + if (!(err instanceof Error)) { + return false; + } + const message = err.message || ""; + return ( + message.includes("require is not defined") || + message.includes("exports is not defined") || + message.includes("module is not defined") || + message.includes("Cannot use import statement outside a module") + ); +} + +function isNodeErrorCode(err: unknown, code: string): boolean { + if (!isObject(err) || !("code" in err)) { + return false; + } + return typeof err.code === "string" && err.code === code; +} + +export function getEvaluators(): EvaluatorEntry[] { + const evals = globalThis._evals; + if (!evals || !evals.evaluators) { + return []; + } + return Object.values(evals.evaluators) as EvaluatorEntry[]; +} + +export function getReporters(): Record { + const evals = globalThis._evals; + if (!evals || !evals.reporters) { + return {}; + } + return evals.reporters as Record; +} + +export function evaluateFilter( + object: Record, + filter: EvalFilter, +): boolean { + const key = filter.path.reduce((acc, part) => { + if (!isObject(acc)) { + return undefined; + } + return acc[part]; + }, object); + if (key === undefined) { + return false; + } + return filter.pattern.test(serializeJSONWithPlainString(key)); +} + +export function filterEvaluators( + evaluators: EvaluatorEntry[], + filters: EvalFilter[], +): EvaluatorEntry[] { + if (filters.length === 0) { + return evaluators; + } + return evaluators.filter((entry) => + filters.every((filter) => evaluateFilter(entry.evaluator, filter)), + ); +} + +export function callEvaluatorData( + data: unknown, +): { data: unknown; baseExperiment: string | undefined } { + const dataResult = typeof data === "function" ? (data as () => unknown)() : data; + let baseExperiment: string | undefined = undefined; + if ( + isObject(dataResult) && + Reflect.get(dataResult, "_type") === "BaseExperiment" && + typeof Reflect.get(dataResult, "name") === "string" + ) { + baseExperiment = Reflect.get(dataResult, "name") as string; + } + return { data: dataResult, baseExperiment }; +} + +export function toAsyncIterable(value: unknown): AsyncIterable { + if ( + typeof value === "object" && + value !== null && + Symbol.asyncIterator in value && + typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" + ) { + return value as AsyncIterable; + } + if ( + typeof value === "object" && + value !== null && + Symbol.iterator in value && + typeof (value as Iterable)[Symbol.iterator] === "function" + ) { + const iterable = value as Iterable; + return (async function* () { + for (const item of iterable) { + yield item; + } + })(); + } + throw new Error( + "Evaluator data must be an array, iterable, or async iterable", + ); +} diff --git a/scripts/runner_common.py b/scripts/runner_common.py new file mode 100644 index 0000000..f079a8a --- /dev/null +++ b/scripts/runner_common.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import asyncio +import importlib.util +import inspect +import json +import os +import re +import sys +from dataclasses import dataclass +from typing import Any, AsyncIterator + +try: + from braintrust.framework import ( + BaseExperiment, + EvaluatorInstance, + _evals, + _set_lazy_load, + ) + from braintrust.logger import Dataset +except Exception: + raise + + +@dataclass(frozen=True) +class EvalFilter: + path: list[str] + pattern: re.Pattern[str] + + +def env_flag(name: str) -> bool: + value = os.getenv(name) + if value is None: + return False + return value.lower() not in {"0", "false", "no", "off", ""} + + +def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: + if not serialized: + return [] + + parsed = json.loads(serialized) + if not isinstance(parsed, list): + raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array") + + filters: list[EvalFilter] = [] + for i, entry in enumerate(parsed): + if not isinstance(entry, dict): + raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}") + key_path = entry.get("path") + pattern = entry.get("pattern") + if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path): + raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings") + if not isinstance(pattern, str): + raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string") + filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern))) + return filters + + +def _to_mapping(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_mapping(v) for k, v in value.items()} + if isinstance(value, list): + return [_to_mapping(v) for v in value] + if hasattr(value, "__dict__"): + return { + key: _to_mapping(val) + for key, val in vars(value).items() + if not key.startswith("_") + } + return value + + +def serialize_json_with_plain_string(value: Any) -> str: + if isinstance(value, str): + return value + return json.dumps(value) + + +def evaluate_filter(value: Any, filt: EvalFilter) -> bool: + current = _to_mapping(value) + for part in filt.path: + if not isinstance(current, dict) or part not in current: + return False + current = current[part] + return bool(filt.pattern.search(serialize_json_with_plain_string(current))) + + +def filter_evaluators( + evaluators: list[EvaluatorInstance], filters: list[EvalFilter] +) -> list[EvaluatorInstance]: + if not filters: + return evaluators + return [ + evaluator + for evaluator in evaluators + if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) + ] + + +async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: + data_result = data + if inspect.isclass(data_result): + data_result = data_result() + if inspect.isfunction(data_result) or inspect.isroutine(data_result): + data_result = data_result() + if inspect.isawaitable(data_result): + data_result = await data_result + + base_experiment_name = None + if isinstance(data_result, BaseExperiment): + base_experiment_name = data_result.name + + return data_result, base_experiment_name + + +def to_async_iterator(value: Any) -> AsyncIterator[Any]: + if inspect.isasyncgen(value): + return value + + async def to_async(it): + for item in it: + yield item + + return to_async(value) + + +def collect_files(input_path: str) -> list[str]: + if os.path.isdir(input_path): + matches: list[str] = [] + for root, _, files in os.walk(input_path): + for filename in files: + matches.append(os.path.join(root, filename)) + return matches + return [input_path] + + +def resolve_module_info(in_file: str) -> tuple[str, list[str]]: + in_file = os.path.abspath(in_file) + module_dir = os.path.dirname(in_file) + module_name = os.path.splitext(os.path.basename(in_file))[0] + + package_parts: list[str] = [] + current = module_dir + while os.path.isfile(os.path.join(current, "__init__.py")): + package_parts.insert(0, os.path.basename(current)) + current = os.path.dirname(current) + + extra_paths = [module_dir] + if package_parts: + module_name = ".".join(package_parts + [module_name]) + if current not in extra_paths: + extra_paths.append(current) + + return module_name, extra_paths + + +def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: + evaluator_instances: list[EvaluatorInstance] = [] + reporters: dict[str, Any] = {} + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.insert(0, cwd) + + for f in files: + d = os.path.dirname(os.path.abspath(f)) + while d and d != os.path.dirname(d): + if os.path.isfile(os.path.join(d, "register.py")): + if d not in sys.path: + sys.path.insert(0, d) + break + d = os.path.dirname(d) + + unique_files: set[str] = set() + for file_path in files: + for candidate in collect_files(file_path): + unique_files.add(os.path.abspath(candidate)) + + for file_path in sorted(unique_files): + module_name, extra_paths = resolve_module_info(file_path) + with _set_lazy_load(True): + _evals.clear() + try: + for extra_path in reversed(extra_paths): + if extra_path not in sys.path: + sys.path.insert(0, extra_path) + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module spec for {file_path}") + + sys.modules.pop(module_name, None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + evaluator_instances.extend( + [ + instance + for instance in _evals.evaluators.values() + if isinstance(instance, EvaluatorInstance) + ] + ) + for reporter_name, reporter in _evals.reporters.items(): + if reporter_name not in reporters: + reporters[reporter_name] = reporter + finally: + _evals.clear() + + return evaluator_instances, reporters + + +__all__ = [ + "BaseExperiment", + "Dataset", + "EvalFilter", + "EvaluatorInstance", + "call_evaluator_data", + "env_flag", + "filter_evaluators", + "load_evaluators", + "parse_serialized_filters", + "serialize_json_with_plain_string", + "to_async_iterator", +] diff --git a/src/eval.rs b/src/eval.rs index bb67433..8c88d03 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -15,6 +15,7 @@ use actix_web::http::header::{ }; use actix_web::{guard, web, App, HttpRequest, HttpResponse, HttpServer}; use anyhow::{Context, Result}; +use chrono::{SecondsFormat, Utc}; use clap::{Args, ValueEnum}; use crossterm::queue; use crossterm::style::{ @@ -22,11 +23,13 @@ use crossterm::style::{ Stylize, }; use futures_util::stream; +use futures_util::StreamExt; use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strip_ansi_escapes::strip; +use tokio::io::AsyncWriteExt; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::net::UnixListener; use tokio::process::Command; @@ -41,7 +44,12 @@ use ratatui::widgets::{Cell, Row, Table}; use ratatui::Terminal; use crate::args::BaseArgs; +use crate::auth::login; use crate::auth::resolved_auth_env; +use crate::experiments::api::create_experiment; +use crate::functions::publish_eval_sandbox_functions; +use crate::http::ApiClient; +use crate::source_language::SourceLanguage; use crate::ui::{animations_enabled, is_quiet}; const MAX_NAME_LENGTH: usize = 40; @@ -161,6 +169,62 @@ struct ResolvedDatasetEvalData { _internal_btql: Option, } +#[derive(Debug, Serialize, Deserialize)] +struct EvalPullRequest { + name: String, + #[serde(default)] + parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum EvalPullClientMessage { + Start { name: String }, + Next, + Close, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum EvalPullResponse { + Ready { + evaluator_name: String, + max_concurrency: usize, + experiment_name: String, + }, + Row { + datum: Value, + trial_index: usize, + }, + Eof, + Error { + message: String, + }, +} + +#[derive(Debug)] +struct EvalDataPuller { + child: tokio::process::Child, + writer: tokio::net::unix::OwnedWriteHalf, + reader: BufReader, + _socket_cleanup_guard: SocketCleanupGuard, +} + +#[derive(Debug, Clone)] +struct EvalSandboxPlan { + evaluator_name: String, + function_id: String, + project_id: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct SandboxSummaryRow { + #[serde(default)] + scores: HashMap>, + #[serde(default)] + metrics: HashMap, +} + #[derive(Clone)] struct DevServerState { base: BaseArgs, @@ -190,10 +254,19 @@ struct RunnerFilter { } const JS_RUNNER_FILE: &str = "eval-runner.ts"; +const JS_DATA_RUNNER_FILE: &str = "data-runner.ts"; +const JS_RUNNER_COMMON_FILE: &str = "runner-common.ts"; const PY_RUNNER_FILE: &str = "eval-runner.py"; +const PY_DATA_RUNNER_FILE: &str = "data-runner.py"; +const PY_RUNNER_COMMON_FILE: &str = "runner_common.py"; const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); +const JS_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.ts"); +const JS_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner-common.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); +const PY_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.py"); +const PY_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner_common.py"); +#[derive(Debug)] struct SocketCleanupGuard { path: PathBuf, } @@ -218,6 +291,12 @@ pub enum EvalLanguage { Python, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum EvalSandbox { + Local, + Lambda, +} + #[derive(Debug, Clone, Args)] #[command(after_help = "\ Examples: @@ -245,6 +324,10 @@ pub struct EvalArgs { )] pub language: Option, + /// Execute evals locally or in a remote sandbox. + #[arg(long, env = "BT_EVAL_SANDBOX", value_enum, default_value = "local")] + pub sandbox: EvalSandbox, + /// Run evals locally (do not send logs to Braintrust). #[arg( long, @@ -389,6 +472,31 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { extra_args: args.extra_args, }; + if args.sandbox != EvalSandbox::Local { + if args.dev { + anyhow::bail!("--sandbox is not supported with --dev."); + } + if args.watch { + anyhow::bail!("--sandbox is not supported with --watch."); + } + if args.list { + anyhow::bail!("--sandbox is not supported with --list."); + } + if files.len() != 1 { + anyhow::bail!("`bt eval --sandbox lambda` currently supports exactly one eval file."); + } + return run_eval_files_sandbox( + &base, + args.sandbox, + args.language, + args.runner.as_deref(), + &files, + args.no_send_logs, + &options, + ) + .await; + } + if args.dev { let language = detect_eval_language(&files, args.language)?; let app_url = resolve_app_url(&base); @@ -500,6 +608,688 @@ async fn run_eval_files_watch( } } +async fn run_eval_files_sandbox( + base: &BaseArgs, + sandbox: EvalSandbox, + language_override: Option, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, +) -> Result<()> { + if sandbox != EvalSandbox::Lambda { + anyhow::bail!("unsupported sandbox mode"); + } + if no_send_logs { + anyhow::bail!("--sandbox lambda is not supported with --no-send-logs."); + } + let language = detect_eval_language(files, language_override)?; + let source_language = match language { + EvalLanguage::JavaScript => SourceLanguage::JsLike, + EvalLanguage::Python => SourceLanguage::Python, + }; + + let source_file = PathBuf::from( + files + .first() + .ok_or_else(|| anyhow::anyhow!("missing sandbox source file"))?, + ); + let published = + publish_eval_sandbox_functions(base, &source_file, runner_override, source_language) + .await?; + let evaluator_names = list_sandbox_evaluator_names( + base, + language, + runner_override, + files, + no_send_logs, + options, + ) + .await?; + if evaluator_names.is_empty() { + anyhow::bail!("No evaluators found. Did you call Eval() in the file?"); + } + + let mut plans = Vec::new(); + for evaluator_name in evaluator_names { + let slug = sandbox_slug_from_source(&source_file, &evaluator_name); + let published_entry = published + .iter() + .find(|entry| entry.slug == slug) + .ok_or_else(|| { + anyhow::anyhow!( + "sandbox function '{}' for evaluator '{}' was not published", + slug, + evaluator_name + ) + })?; + plans.push(EvalSandboxPlan { + evaluator_name, + function_id: published_entry.function_id.clone(), + project_id: published_entry.project_id.clone(), + }); + } + + let login_ctx = login(base).await?; + let client = ApiClient::new(&login_ctx)?; + let started_at = Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true); + + for plan in plans { + let mut puller = spawn_eval_data_puller( + base, + language, + runner_override, + files, + no_send_logs, + options, + &EvalPullRequest { + name: plan.evaluator_name.clone(), + parameters: Some(json!({})), + }, + ) + .await?; + let ready = puller.read_message().await?; + let (max_concurrency, experiment_name) = match ready { + EvalPullResponse::Ready { + evaluator_name, + max_concurrency, + experiment_name, + } => { + if evaluator_name != plan.evaluator_name { + anyhow::bail!( + "sandbox runner selected unexpected evaluator '{}', expected '{}'", + evaluator_name, + plan.evaluator_name + ); + } + (max_concurrency.max(1), experiment_name) + } + EvalPullResponse::Error { message } => anyhow::bail!("{message}"), + other => anyhow::bail!("unexpected initial sandbox pull response: {other:?}"), + }; + + let experiment = create_experiment(&client, &plan.project_id, &experiment_name, true) + .await + .with_context(|| { + format!( + "failed to create sandbox parent experiment '{}' for evaluator '{}'", + experiment_name, plan.evaluator_name + ) + })?; + let mut in_flight: tokio::task::JoinSet>> = + tokio::task::JoinSet::new(); + let mut saw_eof = false; + let mut experiment_url: Option = None; + + while !saw_eof || !in_flight.is_empty() { + while !saw_eof && in_flight.len() < max_concurrency { + puller.send_message(&EvalPullClientMessage::Next).await?; + match puller.read_message().await? { + EvalPullResponse::Row { + datum, + trial_index: _trial_index, + } => { + let function_id = plan.function_id.clone(); + let evaluator_name = plan.evaluator_name.clone(); + let project_id = plan.project_id.clone(); + let body = json!({ + "api_version": 1, + "function_id": { "function_id": function_id }, + "name": evaluator_name, + "project_id": project_id, + "scores": [], + "stream": true, + "experiment_name": experiment.name, + "parent": { + "object_type": "experiment", + "object_id": experiment.id, + }, + "data": { "data": [datum] }, + }); + let client_cloned = client.clone(); + let org_name = login_ctx.login.org_name.clone(); + let project_id = plan.project_id.clone(); + in_flight.spawn(async move { + invoke_sandbox_eval(&client_cloned, &org_name, &project_id, body).await + }); + } + EvalPullResponse::Eof => saw_eof = true, + EvalPullResponse::Error { message } => anyhow::bail!("{message}"), + other => anyhow::bail!("unexpected sandbox pull response: {other:?}"), + } + } + + if let Some(joined) = in_flight.join_next().await { + if let Some(start) = joined?? { + if experiment_url.is_none() { + experiment_url = start.experiment_url.clone(); + } + } + } + } + + puller.send_message(&EvalPullClientMessage::Close).await?; + puller.wait().await?; + + let summary = summarize_sandbox_experiment( + &client, + &plan.project_id, + &plan.project_id, + &experiment.name, + &experiment.id, + experiment_url, + &started_at, + ) + .await?; + let rendered = format_experiment_summary(&summary); + println!("{rendered}"); + } + + Ok(()) +} + +impl EvalDataPuller { + async fn send_message(&mut self, message: &EvalPullClientMessage) -> Result<()> { + let mut payload = + serde_json::to_string(message).context("failed to serialize pull request")?; + payload.push('\n'); + self.writer + .write_all(payload.as_bytes()) + .await + .context("failed to write pull request")?; + self.writer + .flush() + .await + .context("failed to flush pull request")?; + Ok(()) + } + + async fn read_message(&mut self) -> Result { + let mut line = String::new(); + let read = self + .reader + .read_line(&mut line) + .await + .context("failed to read sandbox pull response")?; + if read == 0 { + let status = self + .child + .wait() + .await + .context("sandbox pull runner exited unexpectedly")?; + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + serde_json::from_str(line.trim()).context("failed to parse sandbox pull response JSON") + } + + async fn wait(mut self) -> Result<()> { + let status = self + .child + .wait() + .await + .context("sandbox pull runner failed")?; + if !status.success() { + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + Ok(()) + } +} + +fn sandbox_slugify(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + let mut previous_dash = false; + for ch in input.chars() { + let lower = ch.to_ascii_lowercase(); + if lower.is_ascii_alphanumeric() { + out.push(lower); + previous_dash = false; + } else if !previous_dash { + out.push('-'); + previous_dash = true; + } + } + out.trim_matches('-').to_string() +} + +fn sandbox_slug_from_source(source_file: &Path, eval_name: &str) -> String { + let stem = source_file + .file_stem() + .and_then(|value| value.to_str()) + .map(|value| value.strip_suffix(".eval").unwrap_or(value)) + .unwrap_or("eval"); + sandbox_slugify(&format!("{stem}-{eval_name}-sandbox")) +} + +async fn list_sandbox_evaluator_names( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, +) -> Result> { + let output = run_eval_runner_command_to_completion( + base, + language, + runner_override, + files, + no_send_logs, + options, + &[("BT_EVAL_DEV_MODE".to_string(), "list".to_string())], + JsMode::Auto, + ) + .await?; + + let parsed: Value = + serde_json::from_slice(&output.stdout).context("failed to parse sandbox evaluator list")?; + let object = parsed + .as_object() + .ok_or_else(|| anyhow::anyhow!("sandbox evaluator list was not a JSON object"))?; + Ok(object.keys().cloned().collect()) +} + +async fn spawn_eval_data_puller( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + request: &EvalPullRequest, +) -> Result { + let (listener, socket_path, socket_cleanup_guard) = + bind_unix_listener("bt-eval-pull").context("failed to bind sandbox pull socket")?; + let child = match language { + EvalLanguage::JavaScript => { + let js_runner = prepare_js_data_runner()?; + let mut extra_env = vec![( + "BT_EVAL_PULL_SOCK".to_string(), + socket_path.to_string_lossy().to_string(), + )]; + let mut plan = build_js_plan_with_entrypoint( + runner_override, + &js_runner, + files, + JS_DATA_RUNNER_FILE, + JS_DATA_RUNNER_SOURCE, + )?; + if should_set_node_heap_size(plan.kind) { + set_node_heap_size_env(&mut plan.cmd); + } + plan.cmd.envs(build_env(base).await?); + for (key, value) in extra_env.drain(..) { + plan.cmd.env(key, value); + } + if no_send_logs { + plan.cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + plan.cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + plan.cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + plan.cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + plan.cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + plan.cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + plan.cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + plan.cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let runner_name = match plan.kind { + RunnerKind::Tsx => "tsx", + RunnerKind::ViteNode => "vite-node", + RunnerKind::Deno => "deno", + RunnerKind::Bun => "bun", + RunnerKind::Other => "other", + }; + plan.cmd.env("BT_EVAL_RUNNER_KIND", runner_name); + plan.cmd + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + plan.cmd + .spawn() + .context("failed to spawn sandbox pull runner")? + } + EvalLanguage::Python => { + let py_runner = prepare_py_data_runner()?; + let mut cmd = build_python_command(runner_override, &py_runner, files)?; + cmd.envs(build_env(base).await?); + cmd.env( + "BT_EVAL_PULL_SOCK", + socket_path.to_string_lossy().to_string(), + ); + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + cmd.stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + cmd.spawn().context("failed to spawn sandbox pull runner")? + } + }; + + let (stream, _) = tokio::time::timeout(Duration::from_secs(30), listener.accept()) + .await + .context("timed out waiting for sandbox pull runner to connect")? + .context("sandbox pull runner failed to connect")?; + let (read_half, write_half) = stream.into_split(); + let mut puller = EvalDataPuller { + child, + writer: write_half, + reader: BufReader::new(read_half), + _socket_cleanup_guard: socket_cleanup_guard, + }; + if matches!(language, EvalLanguage::JavaScript | EvalLanguage::Python) { + puller + .send_message(&EvalPullClientMessage::Start { + name: request.name.clone(), + }) + .await?; + } + Ok(puller) +} + +async fn run_eval_runner_command_to_completion( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + extra_env: &[(String, String)], + js_mode: JsMode, +) -> Result { + let (js_runner, py_runner) = prepare_eval_runners()?; + let force_esm = matches!(js_mode, JsMode::ForceEsm); + let (mut cmd, runner_kind) = match language { + EvalLanguage::Python => ( + build_python_command(runner_override, &py_runner, files)?, + RunnerKind::Other, + ), + EvalLanguage::JavaScript => { + if force_esm { + ( + build_vite_node_fallback_command(&js_runner, files)?, + RunnerKind::ViteNode, + ) + } else { + let plan = build_js_plan(runner_override, &js_runner, files)?; + (plan.cmd, plan.kind) + } + } + }; + if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { + set_node_heap_size_env(&mut cmd); + } + cmd.envs(build_env(base).await?); + for (key, value) in extra_env { + cmd.env(key, value); + } + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = + serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let output = cmd + .output() + .await + .context("failed to run eval support runner")?; + if !output.status.success() { + anyhow::bail!( + "eval support runner exited with status {}: {}", + output.status, + String::from_utf8_lossy(&output.stderr).trim() + ); + } + Ok(output) +} + +async fn invoke_sandbox_eval( + client: &ApiClient, + org_name: &str, + project_id: &str, + body: Value, +) -> Result> { + let response = client + .post_with_headers_raw( + "/function/sandbox", + &body, + &[("x-bt-org-name", org_name), ("x-bt-project-id", project_id)], + ) + .await?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("sandbox invoke failed ({status}): {body}"); + } + + let mut bytes = response.bytes_stream(); + let mut buffer = String::new(); + let mut current_event: Option = None; + let mut data_lines: Vec = Vec::new(); + let mut start: Option = None; + let mut saw_done = false; + + while let Some(chunk) = bytes.next().await { + let chunk = chunk.context("failed to read sandbox SSE response")?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buffer.find('\n') { + let mut line: String = buffer.drain(..=pos).collect(); + if line.ends_with('\n') { + line.pop(); + } + if line.ends_with('\r') { + line.pop(); + } + if line.is_empty() { + if current_event.is_some() || !data_lines.is_empty() { + let event_name = current_event.take().unwrap_or_default(); + let data = data_lines.join("\n"); + data_lines.clear(); + match event_name.as_str() { + "start" => { + if let Ok(parsed) = serde_json::from_str::(&data) { + if start.is_none() { + start = Some(parsed); + } + } + } + "error" => { + if let Ok(payload) = serde_json::from_str::(&data) { + let message = payload + .get("message") + .or_else(|| payload.get("error")) + .and_then(Value::as_str) + .unwrap_or("sandbox eval failed"); + anyhow::bail!("{message}"); + } + anyhow::bail!("{data}"); + } + "done" => { + saw_done = true; + } + _ => {} + } + } + continue; + } + if let Some(value) = line.strip_prefix("event:") { + current_event = Some(value.trim().to_string()); + } else if let Some(value) = line.strip_prefix("data:") { + data_lines.push(value.trim_start().to_string()); + } + } + } + + if !saw_done { + anyhow::bail!("sandbox SSE stream ended before a done event"); + } + Ok(start) +} + +async fn summarize_sandbox_experiment( + client: &ApiClient, + project_name: &str, + _project_id: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + started_at: &str, +) -> Result { + let query = build_sandbox_summary_query(experiment_id, started_at); + let response = client.btql::(&query).await?; + Ok(aggregate_sandbox_summary( + project_name, + experiment_name, + experiment_id, + experiment_url, + &response.data, + )) +} + +fn build_sandbox_summary_query(experiment_id: &str, started_at: &str) -> String { + format!( + "select: scores, metrics | from: experiment('{}') summary | filter: created >= '{}' | limit: 1000", + experiment_id.replace('\'', "''"), + started_at.replace('\'', "''") + ) +} + +fn aggregate_sandbox_summary( + project_name: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + rows: &[SandboxSummaryRow], +) -> ExperimentSummary { + let mut scores: HashMap = HashMap::new(); + let mut metrics: HashMap = HashMap::new(); + for row in rows { + for (name, value) in &row.scores { + if let Some(value) = value { + let entry = scores.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += value; + entry.1 += 1; + } + } + for (name, value) in &row.metrics { + let Some(number) = value.as_f64() else { + continue; + }; + let entry = metrics.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += number; + entry.1 += 1; + } + } + + ExperimentSummary { + project_name: project_name.to_string(), + experiment_name: experiment_name.to_string(), + project_id: None, + experiment_id: Some(experiment_id.to_string()), + project_url: None, + experiment_url, + comparison_experiment_name: None, + scores: scores + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + ScoreSummary { + name, + score: average, + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + metrics: if metrics.is_empty() { + None + } else { + Some( + metrics + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + MetricSummary { + name, + metric: average, + unit: String::new(), + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + ) + }, + } +} + struct EvalPlan<'a> { language: EvalLanguage, files: &'a [String], @@ -2098,18 +2888,40 @@ fn build_js_plan( runner_override: Option<&str>, runner: &Path, files: &[String], +) -> Result { + build_js_plan_with_entrypoint( + runner_override, + runner, + files, + JS_RUNNER_FILE, + JS_RUNNER_SOURCE, + ) +} + +fn build_js_plan_with_entrypoint( + runner_override: Option<&str>, + runner: &Path, + files: &[String], + embedded_file_name: &str, + embedded_source: &str, ) -> Result { if let Some(explicit) = runner_override { let resolved_runner = resolve_js_runner_command(explicit, files); if is_deno_runner(explicit) || is_deno_runner_path(resolved_runner.as_ref()) { - let runner_script = prepare_js_runner_in_cwd()?; + let runner_script = + prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(resolved_runner.as_os_str(), &runner_script, files), kind: RunnerKind::Deno, }); } let kind = runner_kind_for_bin(resolved_runner.as_ref()); - let runner_script = select_js_runner_entrypoint(runner, resolved_runner.as_ref())?; + let runner_script = select_js_runner_entrypoint_with_source( + runner, + resolved_runner.as_ref(), + embedded_file_name, + embedded_source, + )?; let mut command = Command::new(resolved_runner); command.arg(runner_script).args(files); return Ok(JsRunnerPlan { cmd: command, kind }); @@ -2117,14 +2929,20 @@ fn build_js_plan( if let Some(auto_runner) = find_js_runner_binary(files) { if is_deno_runner_path(&auto_runner) { - let runner_script = prepare_js_runner_in_cwd()?; + let runner_script = + prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(auto_runner.as_os_str(), &runner_script, files), kind: RunnerKind::Deno, }); } let kind = runner_kind_for_bin(auto_runner.as_ref()); - let runner_script = select_js_runner_entrypoint(runner, auto_runner.as_ref())?; + let runner_script = select_js_runner_entrypoint_with_source( + runner, + auto_runner.as_ref(), + embedded_file_name, + embedded_source, + )?; let mut command = Command::new(auto_runner); command.arg(runner_script).args(files); return Ok(JsRunnerPlan { cmd: command, kind }); @@ -2279,14 +3097,19 @@ fn is_deno_runner_path(runner: &Path) -> bool { .unwrap_or(false) } -fn select_js_runner_entrypoint(default_runner: &Path, runner_command: &Path) -> Result { +fn select_js_runner_entrypoint_with_source( + default_runner: &Path, + runner_command: &Path, + embedded_file_name: &str, + embedded_source: &str, +) -> Result { if is_ts_node_runner(runner_command) { - return prepare_js_runner_in_cwd(); + return prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source); } Ok(default_runner.to_path_buf()) } -fn prepare_js_runner_in_cwd() -> Result { +fn prepare_js_embedded_runner_in_cwd(file_name: &str, source: &str) -> Result { let cwd = std::env::current_dir().context("failed to resolve current working directory")?; let cache_dir = cwd .join(".bt") @@ -2298,7 +3121,8 @@ fn prepare_js_runner_in_cwd() -> Result { cache_dir.display() ) })?; - materialize_runner_script(&cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE) + materialize_runner_script(&cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(&cache_dir, file_name, source) } fn runner_bin_name(runner_command: &Path) -> Option { @@ -2386,20 +3210,20 @@ fn find_binary_in_path(candidates: &[&str]) -> Option { None } -fn build_sse_socket_path() -> Result { +fn build_socket_path(prefix: &str) -> Result { let pid = std::process::id(); let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); let now = SystemTime::now() .duration_since(UNIX_EPOCH) .context("failed to read system time")? .as_nanos(); - Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock"))) + Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock"))) } -fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { +fn bind_unix_listener(prefix: &str) -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { let mut last_bind_err: Option = None; for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS { - let socket_path = build_sse_socket_path()?; + let socket_path = build_socket_path(prefix)?; let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); let _ = std::fs::remove_file(&socket_path); match UnixListener::bind(&socket_path) { @@ -2425,10 +3249,14 @@ fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { ) }); Err(err).context(format!( - "failed to bind SSE unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" + "failed to bind unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" )) } +fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { + bind_unix_listener("bt-eval") +} + fn eval_runner_cache_dir() -> PathBuf { let root = std::env::var_os("XDG_CACHE_HOME") .map(PathBuf::from) @@ -2444,6 +3272,14 @@ fn prepare_eval_runners() -> Result<(PathBuf, PathBuf)> { prepare_eval_runners_in_dir(&eval_runner_cache_dir()) } +fn prepare_js_data_runner() -> Result { + prepare_js_data_runner_in_dir(&eval_runner_cache_dir()) +} + +fn prepare_py_data_runner() -> Result { + prepare_py_data_runner_in_dir(&eval_runner_cache_dir()) +} + fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { std::fs::create_dir_all(cache_dir).with_context(|| { format!( @@ -2452,11 +3288,35 @@ fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { ) })?; + materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?; let js_runner = materialize_runner_script(cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE)?; let py_runner = materialize_runner_script(cache_dir, PY_RUNNER_FILE, PY_RUNNER_SOURCE)?; Ok((js_runner, py_runner)) } +fn prepare_js_data_runner_in_dir(cache_dir: &Path) -> Result { + std::fs::create_dir_all(cache_dir).with_context(|| { + format!( + "failed to create eval runner cache dir {}", + cache_dir.display() + ) + })?; + materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, JS_DATA_RUNNER_FILE, JS_DATA_RUNNER_SOURCE) +} + +fn prepare_py_data_runner_in_dir(cache_dir: &Path) -> Result { + std::fs::create_dir_all(cache_dir).with_context(|| { + format!( + "failed to create eval runner cache dir {}", + cache_dir.display() + ) + })?; + materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?; + materialize_runner_script(cache_dir, PY_DATA_RUNNER_FILE, PY_DATA_RUNNER_SOURCE) +} + fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Result { let path = cache_dir.join(file_name); let current = std::fs::read_to_string(&path).ok(); @@ -3396,6 +4256,46 @@ mod tests { eval: EvalArgs, } + fn base_args() -> BaseArgs { + BaseArgs { + json: false, + quiet: false, + no_color: false, + profile: None, + org_name: None, + project: None, + api_key: None, + prefer_profile: false, + no_input: false, + api_url: None, + app_url: None, + env_file: None, + } + } + + fn make_eval_args(files: Vec) -> EvalArgs { + EvalArgs { + files, + runner: None, + language: None, + sandbox: EvalSandbox::Local, + no_send_logs: false, + jsonl: false, + terminate_on_failure: false, + num_workers: None, + list: false, + filter: Vec::new(), + verbose: false, + watch: false, + extra_args: Vec::new(), + dev: false, + dev_host: "localhost".to_string(), + dev_port: 8300, + dev_org_name: None, + dev_allowed_origin: Vec::new(), + } + } + fn env_test_lock() -> &'static Mutex<()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) @@ -3451,6 +4351,12 @@ mod tests { path } + fn write_eval_file(dir: &Path, name: &str) -> String { + let path = dir.join(name); + fs::write(&path, "export {};").expect("eval file should be written"); + path.to_string_lossy().to_string() + } + #[test] fn join_app_url_normalizes_slashes() { let joined = @@ -3514,11 +4420,27 @@ mod tests { let dir = make_temp_dir("embedded"); let (js_runner, py_runner) = prepare_eval_runners_in_dir(&dir).expect("embedded runners should be materialized"); + let js_data_runner = prepare_js_data_runner_in_dir(&dir) + .expect("embedded data runner should be materialized"); + let py_data_runner = prepare_py_data_runner_in_dir(&dir) + .expect("embedded py data runner should be materialized"); let js = fs::read_to_string(js_runner).expect("js runner should be readable"); + let js_data = + fs::read_to_string(js_data_runner).expect("js data runner should be readable"); + let js_common = fs::read_to_string(dir.join(JS_RUNNER_COMMON_FILE)) + .expect("js common should be readable"); let py = fs::read_to_string(py_runner).expect("python runner should be readable"); + let py_data = + fs::read_to_string(py_data_runner).expect("python data runner should be readable"); + let py_common = fs::read_to_string(dir.join(PY_RUNNER_COMMON_FILE)) + .expect("python common should be readable"); assert_eq!(js, JS_RUNNER_SOURCE); + assert_eq!(js_data, JS_DATA_RUNNER_SOURCE); + assert_eq!(js_common, JS_RUNNER_COMMON_SOURCE); assert_eq!(py, PY_RUNNER_SOURCE); + assert_eq!(py_data, PY_DATA_RUNNER_SOURCE); + assert_eq!(py_common, PY_RUNNER_COMMON_SOURCE); let _ = fs::remove_dir_all(&dir); } @@ -4016,8 +4938,8 @@ mod tests { #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { - let first = build_sse_socket_path().expect("first socket path"); - let second = build_sse_socket_path().expect("second socket path"); + let first = build_socket_path("bt-eval").expect("first socket path"); + let second = build_socket_path("bt-eval").expect("second socket path"); assert_ne!(first, second); } @@ -4193,6 +5115,7 @@ mod tests { "BT_EVAL_TERMINATE_ON_FAILURE", "BT_EVAL_NUM_WORKERS", "BT_EVAL_LIST", + "BT_EVAL_SANDBOX", "BT_EVAL_FILTER", "BT_EVAL_VERBOSE", "BT_EVAL_WATCH", @@ -4207,6 +5130,7 @@ mod tests { set_env_var("BT_EVAL_TERMINATE_ON_FAILURE", "1"); set_env_var("BT_EVAL_NUM_WORKERS", "4"); set_env_var("BT_EVAL_LIST", "yes"); + set_env_var("BT_EVAL_SANDBOX", "lambda"); set_env_var("BT_EVAL_FILTER", "metadata.case=smoke.*,metadata.kind=fast"); set_env_var("BT_EVAL_VERBOSE", "1"); set_env_var("BT_EVAL_WATCH", "on"); @@ -4221,6 +5145,7 @@ mod tests { assert!(parsed.eval.terminate_on_failure); assert_eq!(parsed.eval.num_workers, Some(4)); assert!(parsed.eval.list); + assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda); assert_eq!( parsed.eval.filter, vec![ @@ -4239,4 +5164,117 @@ mod tests { restore_env_var(key, value); } } + + #[test] + fn eval_args_parse_sandbox_flag() { + let parsed = + EvalArgsHarness::try_parse_from(["bt", "--sandbox", "lambda", "sample.eval.ts"]) + .expect("sandbox flag should parse"); + assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda); + assert_eq!(parsed.eval.files, vec!["sample.eval.ts".to_string()]); + } + + #[tokio::test] + async fn sandbox_eval_rejects_dev_mode() { + let dir = make_temp_dir("sandbox-dev"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.dev = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+dev should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --dev.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_watch_mode() { + let dir = make_temp_dir("sandbox-watch"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.watch = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+watch should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --watch.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_list_mode() { + let dir = make_temp_dir("sandbox-list"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.list = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+list should fail"); + assert!(err + .to_string() + .contains("--sandbox is not supported with --list.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_no_send_logs() { + let dir = make_temp_dir("sandbox-local"); + let file = write_eval_file(&dir, "sample.eval.ts"); + let mut args = make_eval_args(vec![file]); + args.sandbox = EvalSandbox::Lambda; + args.no_send_logs = true; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+no-send-logs should fail"); + assert!(err + .to_string() + .contains("--sandbox lambda is not supported with --no-send-logs.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn sandbox_eval_rejects_multiple_files() { + let dir = make_temp_dir("sandbox-multi"); + let first = write_eval_file(&dir, "first.eval.ts"); + let second = write_eval_file(&dir, "second.eval.ts"); + let mut args = make_eval_args(vec![first, second]); + args.sandbox = EvalSandbox::Lambda; + + let err = run(base_args(), args) + .await + .expect_err("sandbox+multiple files should fail"); + assert!(err + .to_string() + .contains("`bt eval --sandbox lambda` currently supports exactly one eval file.")); + + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn build_sandbox_summary_query_includes_timestamp_filter() { + let query = build_sandbox_summary_query("exp'123", "2026-03-19T12:00:00.000Z"); + assert!(query.contains("from: experiment('exp''123') summary")); + assert!(query.contains("filter: created >= '2026-03-19T12:00:00.000Z'")); + assert!(query.contains("select: scores, metrics")); + } + + #[test] + fn sandbox_slug_from_source_uses_source_stem_and_eval_name() { + let slug = sandbox_slug_from_source(Path::new("/tmp/My Eval.ts"), "Demo Eval"); + assert_eq!(slug, "my-eval-demo-eval-sandbox"); + } } diff --git a/src/experiments/api.rs b/src/experiments/api.rs index 3507a7e..bce8a00 100644 --- a/src/experiments/api.rs +++ b/src/experiments/api.rs @@ -70,11 +70,13 @@ pub async fn create_experiment( client: &ApiClient, project_id: &str, name: &str, + ensure_new: bool, ) -> Result { let body = serde_json::json!({ "name": name, "project_id": project_id, "org_name": client.org_name(), + "ensure_new": ensure_new, }); client.post("/v1/experiment", &body).await } diff --git a/src/functions/mod.rs b/src/functions/mod.rs index eadc100..934d0f7 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -17,11 +17,12 @@ mod delete; mod invoke; mod list; mod pull; -mod push; +pub(crate) mod push; pub(crate) mod report; mod view; use api::Function; +pub(crate) use push::publish_eval_sandbox_functions; #[derive(Debug, Clone, Copy, ValueEnum)] pub enum FunctionTypeFilter { diff --git a/src/functions/push.rs b/src/functions/push.rs index 0ca6109..e3979f7 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -640,6 +640,139 @@ struct FileSuccess { bundle_id: Option, } +#[derive(Debug, Clone)] +pub(crate) struct PublishedSandboxFunction { + pub slug: String, + pub project_id: String, + pub function_id: String, +} + +pub(crate) async fn publish_eval_sandbox_functions( + base: &BaseArgs, + source_file: &Path, + runner_override: Option<&str>, + language: SourceLanguage, +) -> Result> { + let available_orgs = list_available_orgs(base) + .await + .context("failed to list available orgs")?; + validate_explicit_org_selection(base, &available_orgs)?; + let auth_ctx = resolve_auth_context(base) + .await + .context("failed to resolve auth context")?; + + let args = PushArgs { + files: vec![source_file.to_path_buf()], + file_flag: Vec::new(), + if_exists: super::IfExistsMode::Replace, + terminate_on_failure: true, + create_missing_projects: true, + runner: runner_override.map(ToOwned::to_owned), + language: match language { + SourceLanguage::JsLike => PushLanguage::JavaScript, + SourceLanguage::Python => PushLanguage::Python, + }, + requirements: None, + tsconfig: None, + external_packages: Vec::new(), + yes: true, + }; + + let input_files = args.resolved_files(); + let classified = collect_classified_files(&input_files)?; + let files = classified.files_for_language(language); + if files.is_empty() { + bail!( + "no eligible {} files found for sandbox publish: {}", + language_label(language), + source_file.display() + ); + } + + let mut manifest = run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) + .map_err(|failure| anyhow!(failure.message))?; + + for file in &mut manifest.files { + file.entries.retain(|entry| match entry { + ManifestEntry::Code(code) => code.function_type.as_deref() == Some("sandbox"), + ManifestEntry::FunctionEvent(_) => false, + }); + } + manifest + .files + .retain(|file| !file.entries.is_empty() || file.python_bundle.is_some()); + + if manifest.files.is_empty() { + bail!("no sandbox evaluators found in {}", source_file.display()); + } + + validate_manifest_paths(&manifest, &files, language, &classified.allowed_roots) + .map_err(|failure| anyhow!(failure.message))?; + + let preflight = collect_project_preflight(base, &manifest)?; + let mut project_name_cache = + resolve_named_projects(&auth_ctx, &preflight.named_projects, true).await?; + validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await?; + let default_project_id = resolve_default_project_id(&preflight, &project_name_cache)?; + let resolved_targets = resolve_manifest_targets( + &auth_ctx, + default_project_id.as_deref(), + &manifest, + &mut project_name_cache, + true, + ) + .await?; + validate_duplicate_slugs(&resolved_targets.entries)?; + + let mut published = Vec::new(); + for (manifest_file, resolved_file) in + manifest.files.iter().zip(resolved_targets.per_file.iter()) + { + let source_path = PathBuf::from(&manifest_file.source_file); + push_file( + &auth_ctx, + default_project_id.as_deref(), + &manifest.runtime_context, + &source_path, + manifest_file, + &resolved_file.entry_project_ids, + &args, + language, + None, + &classified.allowed_roots, + &mut project_name_cache, + ) + .await + .map_err(|failure| anyhow!(failure.message))?; + + for (entry_index, entry) in manifest_file.entries.iter().enumerate() { + let ManifestEntry::Code(code) = entry else { + continue; + }; + let project_id = resolved_file + .entry_project_ids + .get(entry_index) + .cloned() + .ok_or_else(|| anyhow!("missing resolved project id for sandbox entry"))?; + let function = api::get_function_by_slug(&auth_ctx.client, &project_id, &code.slug) + .await? + .ok_or_else(|| { + anyhow!( + "sandbox function '{}' was not found after publish", + code.slug + ) + })?; + published.push(PublishedSandboxFunction { + slug: code.slug.clone(), + project_id, + function_id: function.id, + }); + } + } + + Ok(published) +} + fn default_code_location(index: usize) -> Value { json!({ "type": "function", @@ -3294,12 +3427,8 @@ mod tests { "scores": [{ "name": "accuracy" }] } }); - let value = build_code_function_data( - &runtime, - sandbox_location.clone(), - "bundle-sandbox-1", - None, - ); + let value = + build_code_function_data(&runtime, sandbox_location.clone(), "bundle-sandbox-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["type"], "bundle"); @@ -3319,12 +3448,8 @@ mod tests { "eval_name": "my-eval", "position": { "type": "task" } }); - let value = build_code_function_data( - &runtime, - experiment_location.clone(), - "bundle-task-1", - None, - ); + let value = + build_code_function_data(&runtime, experiment_location.clone(), "bundle-task-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["location"], experiment_location); diff --git a/src/sync.rs b/src/sync.rs index e62a856..76e60ce 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -3116,7 +3116,7 @@ async fn resolve_push_experiment_target( ); } - let created = create_experiment(client, &project.id, experiment_selector) + let created = create_experiment(client, &project.id, experiment_selector, false) .await .with_context(|| { format!("experiment '{experiment_selector}' not found, and creating it failed") diff --git a/tests/eval_fixtures.rs b/tests/eval_fixtures.rs index 84c63a6..b64485f 100644 --- a/tests/eval_fixtures.rs +++ b/tests/eval_fixtures.rs @@ -1,6 +1,12 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fs; +#[cfg(unix)] +use std::io::Write; use std::io::{BufRead, BufReader, Read}; +#[cfg(unix)] +use std::os::unix::fs::symlink; +#[cfg(unix)] +use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::process::{Child, Command, Stdio}; use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; @@ -8,7 +14,11 @@ use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use serde::Deserialize; +#[cfg(unix)] +use serde_json::json; use serde_json::Value; +#[cfg(unix)] +use tempfile::tempdir; #[derive(Debug, Deserialize, Clone)] struct FixtureConfig { @@ -435,11 +445,338 @@ fn eval_runner_list_mode_serializes_parameter_defaults() { ); } +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_js_rows_and_trials() { + let _guard = test_lock(); + if !command_exists("node") { + if required_runtimes().contains("node") { + panic!("node runtime is required but unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_js_rows_and_trials (node not installed)." + ); + return; + } + + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixture_dir = root + .join("tests") + .join("evals") + .join("js") + .join("eval-ts-cjs"); + ensure_dependencies(&fixture_dir); + let runner = local_tsx_path(&fixture_dir).expect("resolve tsx runner"); + let runner_script = root.join("scripts").join("data-runner.ts"); + let temp_fixture_dir = tempdir().expect("create js rows tempdir"); + fs::copy( + fixture_dir.join("package.json"), + temp_fixture_dir.path().join("package.json"), + ) + .expect("copy js fixture package.json"); + symlink( + fixture_dir.join("node_modules"), + temp_fixture_dir.path().join("node_modules"), + ) + .expect("symlink js fixture node_modules"); + let fixture_name = "sandbox_rows.eval.ts"; + let fixture_path = temp_fixture_dir.path().join(fixture_name); + let fixture_source = r#"import { Eval } from "braintrust"; + +async function* rows() { + yield { input: { case_id: "row-1" }, expected: "alpha" }; + yield { input: { case_id: "row-2" }, expected: "bravo" }; +} + +Eval("sandbox-rows-js", { + data: rows, + task: async (input) => (input.case_id === "row-1" ? "alpha" : "bravo"), + scores: [ + ({ output, expected }) => ({ + name: "match", + score: output === expected ? 1 : 0, + }), + ], + maxConcurrency: 3, + trialCount: 2, +}); +"#; + fs::write(&fixture_path, fixture_source).expect("write js rows fixture"); + + let socket_path = unique_socket_path("bt-eval-js-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&runner) + .arg(&runner_script) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn js eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-js", + }), + ); + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-js"); + assert_eq!(ready["max_concurrency"], 3); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-js-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child.wait_with_output().expect("wait for js rows runner"); + if !output.status.success() { + panic!( + "js rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); +} + +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_python_rows_and_trials() { + let _guard = test_lock(); + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixtures_root = root.join("tests").join("evals"); + let python = match ensure_python_env(&fixtures_root.join("py")) { + Some(python) => python, + None => { + if required_runtimes().contains("python") { + panic!("python runtime unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_python_rows_and_trials (python runtime unavailable)." + ); + return; + } + }; + + let runner_script = root.join("scripts").join("data-runner.py"); + let temp_fixture_dir = tempdir().expect("create python rows tempdir"); + let fixture_name = "sandbox_rows.py"; + let fixture_path = temp_fixture_dir.path().join(fixture_name); + let fixture_source = r#"from braintrust import Eval + +def rows(): + yield {"input": {"case_id": "row-1"}, "expected": "alpha"} + yield {"input": {"case_id": "row-2"}, "expected": "bravo"} + +def task(input): + return "alpha" if input["case_id"] == "row-1" else "bravo" + +def match(output, expected): + return {"name": "match", "score": 1 if output == expected else 0} + +Eval( + "sandbox-rows-py", + data=rows, + task=task, + scores=[match], + max_concurrency=4, + trial_count=2, +) +"#; + fs::write(&fixture_path, fixture_source).expect("write python rows fixture"); + + let socket_path = unique_socket_path("bt-eval-py-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&python) + .arg(&runner_script) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn python eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-py", + }), + ); + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-py"); + assert_eq!(ready["max_concurrency"], 4); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-py-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child + .wait_with_output() + .expect("wait for python rows runner"); + if !output.status.success() { + panic!( + "python rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); +} + fn read_fixture_config(path: &Path) -> FixtureConfig { let raw = fs::read_to_string(path).expect("read fixture.json"); serde_json::from_str(&raw).expect("parse fixture.json") } +#[cfg(unix)] +fn unique_test_suffix() -> String { + format!( + "{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before epoch") + .as_nanos() + ) +} + +#[cfg(unix)] +fn unique_socket_path(prefix: &str) -> PathBuf { + std::env::temp_dir().join(format!("{prefix}-{}.sock", unique_test_suffix())) +} + +#[cfg(unix)] +fn accept_pull_stream(listener: &UnixListener, child: &mut Child, timeout: Duration) -> UnixStream { + let started = Instant::now(); + loop { + match listener.accept() { + Ok((stream, _)) => return stream, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {} + Err(err) => panic!("accept pull stream: {err}"), + } + + if let Some(status) = child.try_wait().expect("try_wait runner") { + panic!("runner exited early with status {status}"); + } + + if started.elapsed() > timeout { + panic!("timed out waiting for runner pull socket connection"); + } + + thread::sleep(Duration::from_millis(25)); + } +} + +#[cfg(unix)] +fn read_pull_message(reader: &mut BufReader) -> Value { + let mut line = String::new(); + let read = reader.read_line(&mut line).expect("read pull message"); + assert!(read > 0, "pull channel closed unexpectedly"); + serde_json::from_str(line.trim()).expect("parse pull message json") +} + +#[cfg(unix)] +fn write_pull_message(writer: &mut UnixStream, payload: &Value) { + writer + .write_all(format!("{payload}\n").as_bytes()) + .expect("write pull message"); + writer.flush().expect("flush pull message"); +} + fn collect_deno_eval_diagnostics(dir: &Path, files: &[String]) -> Option { if !command_exists("deno") { return None; diff --git a/tests/functions.rs b/tests/functions.rs index fa3de7c..419a351 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -2067,7 +2067,9 @@ exit 24 "expected 1 inserted function (sandbox only)" ); - let sandbox_obj = inserted[0].as_object().expect("sandbox should be an object"); + let sandbox_obj = inserted[0] + .as_object() + .expect("sandbox should be an object"); assert_eq!( sandbox_obj.get("slug").and_then(Value::as_str), Some("my-eval-my-eval-sandbox") @@ -2118,7 +2120,6 @@ exit 24 .and_then(Value::as_str), Some("my-eval") ); - } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]