diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index 0bbc7d4..d3f713c 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -45,6 +45,7 @@
 ```
 """
 
+import asyncio
 import json
 import os
 import re
@@ -219,7 +220,93 @@ def _build_args(self, output, expected, **kwargs):
             tool_choice={"type": "function", "function": {"name": "select_choice"}},
         )
 
+    async def _render_messages_async(self, **kwargs):
+        """Async version of _render_messages that can fetch thread data."""
+        # Check if we have a trace and if any message uses thread variables
+        trace = kwargs.get("trace")
+        thread_vars = {}
+
+        if trace:
+            # Check if any message template uses thread variables
+            uses_thread_vars = False
+            for m in self.messages:
+                content = m.get("content", "")
+                if isinstance(content, str):
+                    # Import here to avoid a circular dependency
+                    from braintrust.thread_utils import template_uses_thread_variables
+
+                    if template_uses_thread_variables(content):
+                        uses_thread_vars = True
+                        break
+
+            if uses_thread_vars:
+                try:
+                    # Get the thread from the trace
+                    thread = await trace.get_thread()
+
+                    # Compute thread template variables
+                    from braintrust.thread_utils import (
+                        THREAD_VARIABLE_NAMES,
+                        compute_thread_template_vars,
+                    )
+
+                    computed = compute_thread_template_vars(thread)
+
+                    # Build thread vars dict
+                    for name in THREAD_VARIABLE_NAMES:
+                        thread_vars[name] = getattr(computed, name, None)
+                except Exception as e:
+                    # Log a warning but continue - don't fail the evaluation
+                    import warnings
+
+                    warnings.warn(f"Failed to compute thread variables: {e}")
+
+        # Merge render args, thread vars, and explicit kwargs; later entries win
+        merged_kwargs = {**self.render_args, **thread_vars, **kwargs}
+
+        # Set up a custom escape function for chevron to handle LLM messages
+        from braintrust.thread_utils import smart_escape_value
+
+        original_escape = chevron.renderer._html_escape
+        chevron.renderer._html_escape = smart_escape_value
+
+        try:
+            return [
+                {
+                    **m,
+                    "content": chevron.render(m["content"].strip(), merged_kwargs, warn=True),
+                }
+                for m in self.messages
+            ]
+        finally:
+            # Restore original escape function
+            chevron.renderer._html_escape = original_escape
+
     def _render_messages(self, **kwargs):
+        """Sync version of _render_messages that handles thread variables."""
+        # For the sync version, we need to check whether we are in an async context
+        trace = kwargs.get("trace")
+
+        if trace:
+            # Check whether this thread already has a running event loop
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = None
+
+            if loop is not None:
+                # We're in an async context; asyncio.run would raise here,
+                # so run the coroutine on a fresh loop in a worker thread
+                import concurrent.futures
+
+                with concurrent.futures.ThreadPoolExecutor() as executor:
+                    future = executor.submit(asyncio.run, self._render_messages_async(**kwargs))
+                    return future.result()
+            else:
+                # No running loop, so we can create one
+                return asyncio.run(self._render_messages_async(**kwargs))
+
+        # Fallback: render without thread variables
         kwargs.update(self.render_args)
         return [
             {
@@ -268,7 +355,29 @@ def _postprocess_response(self, resp):
         raise ValueError("Empty response from OpenAI")
 
     async def _run_eval_async(self, output, expected, **kwargs):
-        return self._postprocess_response(await arun_cached_request(**self._request_args(output, expected, **kwargs)))
+        # For async evaluation, render messages asynchronously so that
+        # thread variables can be resolved properly
+        trace = kwargs.get("trace")
+        if trace:
+            # Render messages with async support for thread variables
+            messages = await self._render_messages_async(output=output, expected=expected, **kwargs)
+            # Build args with pre-rendered messages
+            request_args = {
+                "client": self.client,
+                **self.extra_args,
+                "model": self.model,
+                "messages": messages,
+                "tools": self.classification_tools,
+                "tool_choice": {"type": "function", "function": {"name": "select_choice"}},
+            }
+            if self.engine is not None:
+                request_args["engine"] = self.engine
+            return self._postprocess_response(await arun_cached_request(**request_args))
+        else:
+            # Fall back to sync rendering when no trace is available
+            return self._postprocess_response(
+                await arun_cached_request(**self._request_args(output, expected, **kwargs))
+            )
 
     def _run_eval_sync(self, output, expected, **kwargs):
         return self._postprocess_response(run_cached_request(**self._request_args(output, expected, **kwargs)))
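
Note on the sync/async bridge in `_render_messages`: when an event loop is already running in the calling thread, `asyncio.run` would raise, so the coroutine is handed to a worker thread that gets a fresh loop. A minimal standalone sketch of the same pattern (the `fetch_value` coroutine is a hypothetical stand-in for an awaitable dependency like `trace.get_thread()`, not part of this diff):

```python
import asyncio
import concurrent.futures


async def fetch_value() -> str:
    """Hypothetical stand-in for an awaitable dependency such as trace.get_thread()."""
    await asyncio.sleep(0)
    return "rendered"


def fetch_value_sync() -> str:
    """Run fetch_value() from sync code, with or without a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to start one.
        return asyncio.run(fetch_value())
    # A loop is already running; asyncio.run would raise RuntimeError here,
    # so run the coroutine on a fresh loop in a worker thread instead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, fetch_value()).result()


if __name__ == "__main__":
    print(fetch_value_sync())  # -> rendered
```

The caveat carries over to the diff: when `_render_messages` is called from inside a running loop, blocking on `future.result()` stalls that loop until the worker thread finishes.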
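The `chevron.renderer._html_escape` swap works because `chevron.render` resolves the escape helper from module globals at call time, so a temporary global patch (restored in `finally`) changes how `{{var}}` values are escaped. A sketch of the mechanism with a no-op escape; the real `smart_escape_value` lives in `braintrust.thread_utils`, and `no_html_escape` below is illustrative only:

```python
import chevron
import chevron.renderer


def no_html_escape(text):
    # chevron stringifies the value before calling this hook, so it
    # receives and returns a str. For LLM prompts we usually want the
    # raw text rather than HTML entities.
    return text


template = "User said: {{message}}"
data = {"message": "use <tool> & report back"}

print(chevron.render(template, data))  # default: "<", ">", "&" become HTML entities

original = chevron.renderer._html_escape
chevron.renderer._html_escape = no_html_escape
try:
    print(chevron.render(template, data))  # patched: use <tool> & report back
finally:
    chevron.renderer._html_escape = original
```

Because the patch is process-global, concurrent renders expecting different escape behavior could interleave; restoring in `finally` bounds the window but does not eliminate it.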