Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
```
"""

import asyncio
import json
import os
import re
Expand Down Expand Up @@ -219,7 +220,94 @@ def _build_args(self, output, expected, **kwargs):
tool_choice={"type": "function", "function": {"name": "select_choice"}},
)

async def _render_messages_async(self, **kwargs):
    """Render the message templates, resolving thread variables from a trace.

    Async counterpart of ``_render_messages``: when ``kwargs`` carries a
    ``trace`` and any message template references thread variables, the
    thread is fetched from the trace (an awaitable call) and its computed
    variables are merged into the render context.

    Args:
        **kwargs: Template variables. ``trace`` (optional) is inspected for
            thread-variable support; explicit kwargs override both
            ``self.render_args`` and any computed thread variables.

    Returns:
        list[dict]: Copies of ``self.messages`` with ``content`` rendered
        through chevron (mustache) using the merged variables.
    """
    trace = kwargs.get("trace")
    thread_vars = {}

    if trace:
        # Imported lazily (once, not per message) to avoid a circular
        # dependency between autoevals and braintrust.
        from braintrust.thread_utils import template_uses_thread_variables

        # Only pay for the thread fetch if some template actually uses
        # thread variables. `any` short-circuits on the first match.
        uses_thread_vars = any(
            isinstance(m.get("content", ""), str)
            and template_uses_thread_variables(m.get("content", ""))
            for m in self.messages
        )

        if uses_thread_vars:
            try:
                # Fetch the thread and compute its template variables.
                thread = await trace.get_thread()

                from braintrust.thread_utils import (
                    THREAD_VARIABLE_NAMES,
                    compute_thread_template_vars,
                )

                computed = compute_thread_template_vars(thread)
                for name in THREAD_VARIABLE_NAMES:
                    thread_vars[name] = getattr(computed, name, None)
            except Exception as e:
                # Best-effort: a thread-variable failure must not fail the
                # whole evaluation, so warn and render without them.
                import warnings

                warnings.warn(f"Failed to compute thread variables: {e}")

    # thread_vars before kwargs so explicit arguments win on collisions.
    merged_kwargs = {**self.render_args, **thread_vars, **kwargs}

    from braintrust.thread_utils import smart_escape_value

    # NOTE(review): patching chevron's module-level escape hook is not
    # thread-safe — concurrent renders share this global. The try/finally
    # at least guarantees restoration even if rendering raises.
    original_escape = chevron.renderer._html_escape
    chevron.renderer._html_escape = smart_escape_value

    try:
        return [
            {
                **m,
                "content": chevron.render(m["content"].strip(), merged_kwargs, warn=True),
            }
            for m in self.messages
        ]
    finally:
        chevron.renderer._html_escape = original_escape

def _render_messages(self, **kwargs):
"""Sync version of _render_messages that handles thread variables."""
# For sync version, we need to check if we have async context
trace = kwargs.get("trace")

if trace:
# Try to run the async version if we can
try:
# Check if we're already in an event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, create a task
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self._render_messages_async(**kwargs))
return future.result()
else:
# No running loop, we can create one
return asyncio.run(self._render_messages_async(**kwargs))
except RuntimeError:
# Can't get event loop, fall back to sync version without thread vars
pass

# Fallback: render without thread variables
kwargs.update(self.render_args)
return [
{
Expand Down Expand Up @@ -268,7 +356,29 @@ def _postprocess_response(self, resp):
raise ValueError("Empty response from OpenAI")

async def _run_eval_async(self, output, expected, **kwargs):
    """Run the classification request asynchronously and score the result.

    When a ``trace`` is supplied, messages are rendered via the async
    renderer so thread variables can be resolved before the request is
    issued; otherwise the standard sync-built request arguments are used.
    """
    trace = kwargs.get("trace")

    if not trace:
        # No trace: sync rendering is sufficient, reuse the shared builder.
        args = self._request_args(output, expected, **kwargs)
        return self._postprocess_response(await arun_cached_request(**args))

    # Trace present: render asynchronously so thread variables resolve.
    rendered = await self._render_messages_async(output=output, expected=expected, **kwargs)

    # Assemble the request around the pre-rendered messages. extra_args is
    # spread after `client` but before the explicit keys, preserving the
    # original override order.
    request_args = {
        "client": self.client,
        **self.extra_args,
        "model": self.model,
        "messages": rendered,
        "tools": self.classification_tools,
        "tool_choice": {"type": "function", "function": {"name": "select_choice"}},
    }
    if self.engine is not None:
        request_args["engine"] = self.engine

    response = await arun_cached_request(**request_args)
    return self._postprocess_response(response)

def _run_eval_sync(self, output, expected, **kwargs):
    """Run the classification request synchronously and score the result."""
    # Build the request, issue it through the cache, then parse the choice.
    args = self._request_args(output, expected, **kwargs)
    response = run_cached_request(**args)
    return self._postprocess_response(response)
Expand Down