Skip to content

Commit 50bde7e

Browse files
committed
remote scorers for evals + remote evals
1 parent e70508a commit 50bde7e

51 files changed

Lines changed: 1189 additions & 261 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

examples/src/main/java/dev/braintrust/examples/ExperimentExample.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ public static void main(String[] args) throws Exception {
5151
// .dataset(braintrust.fetchDataset("my-dataset-name"))
5252
.taskFunction(getFoodType)
5353
.scorers(
54+
// to fetch a remote scorer:
55+
// braintrust.fetchScorer("my-remote-scorer-6d9f"),
5456
Scorer.of(
5557
"exact_match",
5658
(expected, result) -> expected.equals(result) ? 1.0 : 0.0))

src/main/java/dev/braintrust/Braintrust.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dev.braintrust.config.BraintrustConfig;
55
import dev.braintrust.eval.Dataset;
66
import dev.braintrust.eval.Eval;
7+
import dev.braintrust.eval.Scorer;
78
import dev.braintrust.prompt.BraintrustPromptLoader;
89
import dev.braintrust.trace.BraintrustTracing;
910
import io.opentelemetry.api.OpenTelemetry;
@@ -173,4 +174,40 @@ public <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> fetchDataset(
173174
var projectName = apiClient.getOrCreateProjectAndOrgInfo(config).project().name();
174175
return Dataset.fetchFromBraintrust(apiClient(), projectName, datasetName, datasetVersion);
175176
}
177+
178+
/**
179+
* Fetch a scorer from Braintrust by slug, using the default project from configuration.
180+
*
181+
* @param scorerSlug the unique slug identifier for the scorer
182+
* @return a Scorer that invokes the remote function
183+
*/
184+
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(String scorerSlug) {
185+
return fetchScorer(scorerSlug, null);
186+
}
187+
188+
/**
189+
* Fetch a scorer from Braintrust by slug, using the default project from configuration.
190+
*
191+
* @param scorerSlug the unique slug identifier for the scorer
192+
* @param version optional version of the scorer to fetch
193+
* @return a Scorer that invokes the remote function
194+
*/
195+
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(
196+
String scorerSlug, @Nullable String version) {
197+
var projectName = apiClient.getOrCreateProjectAndOrgInfo(config).project().name();
198+
return Scorer.fetchFromBraintrust(apiClient, projectName, scorerSlug, version);
199+
}
200+
201+
/**
202+
* Fetch a scorer from Braintrust by project name and slug.
203+
*
204+
* @param projectName the name of the project containing the scorer
205+
* @param scorerSlug the unique slug identifier for the scorer
206+
* @param version optional version of the scorer to fetch
207+
* @return a Scorer that invokes the remote function
208+
*/
209+
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(
210+
String projectName, String scorerSlug, @Nullable String version) {
211+
return Scorer.fetchFromBraintrust(apiClient, projectName, scorerSlug, version);
212+
}
176213
}

src/main/java/dev/braintrust/api/BraintrustApiClient.java

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,26 @@ Optional<Prompt> getPrompt(
6666
/** Query datasets by project name and dataset name */
6767
List<Dataset> queryDatasets(String projectName, String datasetName);
6868

69+
/**
70+
* Get a function by project name and slug, with optional version.
71+
*
72+
* @param projectName the name of the project containing the function
73+
* @param slug the unique slug identifier for the function
74+
* @param version optional version identifier (transaction id or version string)
75+
* @return the function if found
76+
*/
77+
Optional<Function> getFunction(
78+
@Nonnull String projectName, @Nonnull String slug, @Nullable String version);
79+
80+
/**
81+
* Invoke a function (scorer, prompt, or tool) by its ID.
82+
*
83+
* @param functionId the ID of the function to invoke
84+
* @param request the invocation request containing input, expected output, etc.
85+
* @return the result of the function invocation
86+
*/
87+
Object invokeFunction(@Nonnull String functionId, @Nonnull FunctionInvokeRequest request);
88+
6989
static BraintrustApiClient of(BraintrustConfig config) {
7090
return new HttpImpl(config);
7191
}
@@ -296,6 +316,54 @@ public List<Dataset> queryDatasets(String projectName, String datasetName) {
296316
}
297317
}
298318

319+
@Override
320+
public Optional<Function> getFunction(
321+
@Nonnull String projectName, @Nonnull String slug, @Nullable String version) {
322+
Objects.requireNonNull(projectName, "projectName must not be null");
323+
Objects.requireNonNull(slug, "slug must not be null");
324+
try {
325+
var uriBuilder = new StringBuilder("/v1/function?");
326+
uriBuilder.append("slug=").append(slug);
327+
uriBuilder.append("&project_name=").append(projectName);
328+
329+
if (version != null && !version.isEmpty()) {
330+
uriBuilder.append("&version=").append(version);
331+
}
332+
333+
FunctionListResponse response =
334+
getAsync(uriBuilder.toString(), FunctionListResponse.class).get();
335+
336+
if (response.objects() == null || response.objects().isEmpty()) {
337+
return Optional.empty();
338+
}
339+
340+
if (response.objects().size() > 1) {
341+
throw new ApiException(
342+
"Multiple functions found for slug: "
343+
+ slug
344+
+ ", projectName: "
345+
+ projectName);
346+
}
347+
348+
return Optional.of(response.objects().get(0));
349+
} catch (InterruptedException | ExecutionException e) {
350+
throw new RuntimeException(e);
351+
}
352+
}
353+
354+
@Override
355+
public Object invokeFunction(
356+
@Nonnull String functionId, @Nonnull FunctionInvokeRequest request) {
357+
Objects.requireNonNull(functionId, "functionId must not be null");
358+
Objects.requireNonNull(request, "request must not be null");
359+
try {
360+
String path = "/v1/function/" + functionId + "/invoke";
361+
return postAsync(path, request, Object.class).get();
362+
} catch (InterruptedException | ExecutionException e) {
363+
throw new ApiException("Failed to invoke function: " + functionId, e);
364+
}
365+
}
366+
299367
private <T> CompletableFuture<T> getAsync(String path, Class<T> responseType) {
300368
var request =
301369
HttpRequest.newBuilder()
@@ -399,6 +467,9 @@ class InMemoryImpl implements BraintrustApiClient {
399467
private final Set<Experiment> experiments =
400468
Collections.newSetFromMap(new ConcurrentHashMap<>());
401469
private final List<Prompt> prompts = new ArrayList<>();
470+
private final List<Function> functions = new ArrayList<>();
471+
private final Map<String, java.util.function.Function<FunctionInvokeRequest, Object>>
472+
functionInvokers = new ConcurrentHashMap<>();
402473

403474
public InMemoryImpl(OrganizationAndProjectInfo... organizationAndProjectInfos) {
404475
this.organizationAndProjectInfos =
@@ -583,6 +654,18 @@ public Optional<Dataset> getDataset(String datasetId) {
583654
public List<Dataset> queryDatasets(String projectName, String datasetName) {
584655
return List.of();
585656
}
657+
658+
@Override
659+
public Optional<Function> getFunction(
660+
@Nonnull String projectName, @Nonnull String slug, @Nullable String version) {
661+
throw new RuntimeException("will not be invoked");
662+
}
663+
664+
@Override
665+
public Object invokeFunction(
666+
@Nonnull String functionId, @Nonnull FunctionInvokeRequest request) {
667+
throw new RuntimeException("will not be invoked");
668+
}
586669
}
587670

588671
// Request/Response DTOs
@@ -681,4 +764,59 @@ record Prompt(
681764
Optional<Object> metadata) {}
682765

683766
record PromptListResponse(List<Prompt> objects) {}
767+
768+
// Function models for remote scorers/prompts/tools
769+
770+
/**
 * A Braintrust function (scorer, prompt, tool, or task) as returned by the API. Functions can
 * be invoked remotely via the API.
 *
 * <p>Components declared as {@code Optional} may be absent from the API payload.
 *
 * @param id the function's ID; this is the value passed to {@code invokeFunction}
 * @param projectId ID of the owning project
 * @param orgId ID of the owning organization
 * @param name the function's name
 * @param slug the function's slug, as used by {@code getFunction}
 * @param description optional description
 * @param created creation timestamp as a string (format not established here)
 * @param functionData optional untyped function payload
 * @param promptData optional untyped prompt payload
 * @param tags optional list of tags
 * @param metadata optional untyped metadata
 * @param functionType optional function type
 * @param origin optional untyped origin information
 * @param functionSchema optional untyped schema
 */
record Function(
        String id,
        String projectId,
        String orgId,
        String name,
        String slug,
        Optional<String> description,
        String created,
        Optional<Object> functionData,
        Optional<Object> promptData,
        Optional<List<String>> tags,
        Optional<Object> metadata,
        Optional<String> functionType,
        Optional<Object> origin,
        Optional<Object> functionSchema) {}
789+
790+
/**
 * Response envelope for function list queries.
 *
 * @param objects the functions returned by the query; may be null or empty when nothing
 *     matched
 */
record FunctionListResponse(List<Function> objects) {}
791+
792+
/**
793+
* Request body for invoking a function. The input field wraps the function arguments.
794+
*
795+
* <p>For remote Python/TypeScript scorers, the scorer handler parameters (input, output,
796+
* expected, metadata) must be wrapped in the outer input field.
797+
*/
798+
record FunctionInvokeRequest(@Nullable Object input) {
799+
800+
/** Create a simple invoke request with just input */
801+
public static FunctionInvokeRequest of(Object input) {
802+
return new FunctionInvokeRequest(input);
803+
}
804+
805+
/**
806+
* Create an invoke request for a scorer with input, output, expected, and metadata. This
807+
* maps to the standard scorer handler signature: handler(input, output, expected, metadata)
808+
*
809+
* <p>The scorer args are wrapped in the outer input field as required by the invoke API.
810+
*/
811+
public static FunctionInvokeRequest forScorer(
812+
Object input, Object output, Object expected, Object metadata) {
813+
// Wrap scorer args in an inner map that becomes the outer "input" field
814+
var scorerArgs = new java.util.LinkedHashMap<String, Object>();
815+
scorerArgs.put("input", input);
816+
scorerArgs.put("output", output);
817+
scorerArgs.put("expected", expected);
818+
scorerArgs.put("metadata", metadata);
819+
return new FunctionInvokeRequest(scorerArgs);
820+
}
821+
}
684822
}

src/main/java/dev/braintrust/devserver/Devserver.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,18 @@ private void handleEval(HttpExchange exchange) throws IOException {
288288
return;
289289
}
290290

291-
// TODO: support remote scorers
291+
// Resolve remote scorers from the request
292+
List<Scorer<Object, Object>> remoteScorers = new ArrayList<>();
293+
if (request.getScores() != null) {
294+
var apiClient = context.getBraintrust().apiClient();
295+
for (var remoteScorer : request.getScores()) {
296+
remoteScorers.add(resolveRemoteScorer(remoteScorer, apiClient));
297+
}
298+
log.debug(
299+
"Resolved {} remote scorer(s): {}",
300+
remoteScorers.size(),
301+
remoteScorers.stream().map(Scorer::getName).toList());
302+
}
292303

293304
String datasetDescription =
294305
hasInlineData
@@ -308,7 +319,7 @@ private void handleEval(HttpExchange exchange) throws IOException {
308319
if (isStreaming) {
309320
// SSE streaming response - errors handled inside
310321
log.debug("Starting streaming evaluation for '{}'", request.getName());
311-
handleStreamingEval(exchange, eval, request, context);
322+
handleStreamingEval(exchange, eval, request, context, remoteScorers);
312323
} else {
313324
throw new NotSupportedYetException("non-streaming responses");
314325
}
@@ -325,7 +336,11 @@ private void handleEval(HttpExchange exchange) throws IOException {
325336

326337
@SuppressWarnings({"unchecked", "rawtypes"})
327338
private void handleStreamingEval(
328-
HttpExchange exchange, RemoteEval eval, EvalRequest request, RequestContext context)
339+
HttpExchange exchange,
340+
RemoteEval eval,
341+
EvalRequest request,
342+
RequestContext context,
343+
List<Scorer<Object, Object>> remoteScorers)
329344
throws Exception {
330345
// Set SSE headers
331346
exchange.getResponseHeaders().set("Content-Type", "text/event-stream");
@@ -423,7 +438,12 @@ private void handleStreamingEval(
423438
taskResult);
424439
}
425440
// run scorers - one score span per scorer
426-
for (var scorer : (List<Scorer<?, ?>>) eval.getScorers()) {
441+
// Combine local scorers from RemoteEval with remote scorers
442+
// from request
443+
List<Scorer<?, ?>> allScorers =
444+
new ArrayList<>(eval.getScorers());
445+
allScorers.addAll(remoteScorers);
446+
for (var scorer : allScorers) {
427447
var scoreSpan = tracer.spanBuilder("score").startSpan();
428448
try (var unused =
429449
Context.current()
@@ -1037,6 +1057,30 @@ private static ParentInfo extractParentInfo(EvalRequest request) {
10371057
}
10381058
}
10391059

1060+
/**
1061+
* Resolve a remote scorer from the eval request into a Scorer instance.
1062+
*
1063+
* @param remoteScorer the remote scorer specification from the request
1064+
* @param apiClient the API client to use for invoking the scorer function
1065+
* @return a Scorer that invokes the remote function
1066+
* @throws IllegalArgumentException if the function_id is missing
1067+
*/
1068+
private static Scorer<Object, Object> resolveRemoteScorer(
1069+
EvalRequest.RemoteScorer remoteScorer, BraintrustApiClient apiClient) {
1070+
var functionIdSpec = remoteScorer.getFunctionId();
1071+
1072+
if (functionIdSpec == null || functionIdSpec.getFunctionId() == null) {
1073+
throw new IllegalArgumentException(
1074+
"Remote scorer '" + remoteScorer.getName() + "' missing function_id");
1075+
}
1076+
1077+
return new ScorerBrainstoreImpl<>(
1078+
apiClient,
1079+
functionIdSpec.getFunctionId(),
1080+
remoteScorer.getName(),
1081+
functionIdSpec.getVersion());
1082+
}
1083+
10401084
public static class Builder {
10411085
private @Nullable BraintrustConfig config = null;
10421086
private String host = "localhost";

src/main/java/dev/braintrust/eval/Scorer.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package dev.braintrust.eval;
22

3+
import dev.braintrust.api.BraintrustApiClient;
34
import java.util.List;
45
import java.util.function.BiFunction;
56
import java.util.function.Function;
7+
import javax.annotation.Nullable;
68

79
/**
810
* A scorer evaluates the result of a test case with a score between 0 (inclusive) and 1
@@ -49,4 +51,33 @@ public List<Score> score(TaskResult<INPUT, OUTPUT> taskResult) {
4951
}
5052
};
5153
}
54+
55+
/**
56+
* Fetch a scorer from Braintrust by project name and slug.
57+
*
58+
* @param apiClient the API client to use
59+
* @param projectName the name of the project containing the scorer
60+
* @param scorerSlug the unique slug identifier for the scorer
61+
* @param version optional version of the scorer to fetch
62+
* @return a Scorer that invokes the remote function
63+
* @throws RuntimeException if the scorer is not found
64+
*/
65+
static <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchFromBraintrust(
66+
BraintrustApiClient apiClient,
67+
String projectName,
68+
String scorerSlug,
69+
@Nullable String version) {
70+
var function =
71+
apiClient
72+
.getFunction(projectName, scorerSlug, version)
73+
.orElseThrow(
74+
() ->
75+
new RuntimeException(
76+
"Scorer not found: project="
77+
+ projectName
78+
+ ", slug="
79+
+ scorerSlug));
80+
81+
return new ScorerBrainstoreImpl<>(apiClient, function.id(), function.name(), version);
82+
}
5283
}

0 commit comments

Comments
 (0)