diff --git a/src/eva/metrics/base.py b/src/eva/metrics/base.py index 174273ff..f2ecbbcf 100644 --- a/src/eva/metrics/base.py +++ b/src/eva/metrics/base.py @@ -84,6 +84,7 @@ def __init__( response_speed_latencies: list[float] | None = None, assistant_interrupted_turns: set[int] | None = None, user_interrupted_turns: set[int] | None = None, + assistant_responding_to_user_turn: dict[int, int] | None = None, is_audio_native: bool = False, ): self.record_id = record_id @@ -134,6 +135,7 @@ def __init__( self.response_speed_latencies = response_speed_latencies or [] self.assistant_interrupted_turns = assistant_interrupted_turns or set() self.user_interrupted_turns = user_interrupted_turns or set() + self.assistant_responding_to_user_turn = assistant_responding_to_user_turn or {} self.is_audio_native = is_audio_native def to_dict(self) -> dict[str, Any]: diff --git a/src/eva/metrics/experience/turn_taking.py b/src/eva/metrics/experience/turn_taking.py index 92f6e50b..c7eb99e1 100644 --- a/src/eva/metrics/experience/turn_taking.py +++ b/src/eva/metrics/experience/turn_taking.py @@ -27,8 +27,21 @@ class TurnTakingMetric(TextJudgeMetric): @staticmethod def _get_turn_ids_with_turn_taking(context: MetricContext) -> list[int]: - """Return sorted turn IDs for user-assistant exchange pairs (excludes greeting).""" - return sorted(context.transcribed_user_turns.keys() & context.transcribed_assistant_turns.keys() - {0}) + """Return sorted assistant turn IDs to evaluate (excludes greeting). + + Includes assistant turns that have a responding_to user turn mapping, + even if there's no user transcript at the same turn ID (e.g., after tool calls). + """ + # All assistant turns except greeting + assistant_turns = set(context.transcribed_assistant_turns.keys()) - {0} + # Filter to those with a valid user turn to measure latency against + responding_to = context.assistant_responding_to_user_turn + valid_turns = { + t + for t in assistant_turns + if t in context.transcribed_user_turns or responding_to.get(t) in context.transcribed_user_turns + } + return sorted(valid_turns) def _format_conversation_context( self, @@ -52,14 +65,17 @@ def _format_conversation_context( all_starts = [segs[0][0] for segs in all_timestamps.values() if segs] t0 = min(all_starts) if all_starts else 0 + responding_to = context.assistant_responding_to_user_turn blocks = [] for turn_id in turn_keys: - user_heard = context.transcribed_user_turns.get(turn_id, "") + # Use responding_to user turn for user data (handles tool call advances) + user_turn_id = responding_to.get(turn_id, turn_id) + user_heard = context.transcribed_user_turns.get(user_turn_id, "") asst_heard = context.transcribed_assistant_turns.get(turn_id, "") - user_expected = context.intended_user_turns.get(turn_id, "") + user_expected = context.intended_user_turns.get(user_turn_id, "") asst_expected = context.intended_assistant_turns.get(turn_id, "") - u_segments = context.audio_timestamps_user_turns.get(turn_id) + u_segments = context.audio_timestamps_user_turns.get(user_turn_id) a_segments = context.audio_timestamps_assistant_turns.get(turn_id) latency = per_turn_latency.get(turn_id) @@ -136,6 +152,7 @@ def _compute_per_turn_latency_and_timing_labels( Turns with missing timestamps get None values. Latency is user_end -> asst_start in seconds. + For assistant turns after tool calls, uses the responding_to user turn's timestamps. Timing label thresholds: latency < 200 ms -> "Early / Interrupting" 200 ms <= latency < 4000 ms -> "On-Time" @@ -143,15 +160,18 @@ def _compute_per_turn_latency_and_timing_labels( """ user_ts = context.audio_timestamps_user_turns asst_ts = context.audio_timestamps_assistant_turns + responding_to = context.assistant_responding_to_user_turn latencies: dict[int, float | None] = {} labels: dict[int, str | None] = {} for turn_id in turn_keys: - u = user_ts.get(turn_id) + # Use the responding_to user turn for latency calculation (handles tool call advances) + user_turn_id = responding_to.get(turn_id, turn_id) + u = user_ts.get(user_turn_id) a = asst_ts.get(turn_id) if not u or not a: if turn_id != len(turn_keys): self.logger.warning( - f"[{context.record_id}] Missing audio timestamps at turn {turn_id}/{len(turn_keys)} (user={u}, assistant={a}); skipping turn taking for this turn." + f"[{context.record_id}] Missing audio timestamps at turn {turn_id}/{len(turn_keys)} (user={u} at turn {user_turn_id}, assistant={a}); skipping turn taking for this turn." ) latencies[turn_id] = None labels[turn_id] = None @@ -176,8 +196,11 @@ async def compute(self, context: MetricContext) -> MetricScore: # Identify turn keys where either timestamp is missing. _user_ts = context.audio_timestamps_user_turns _asst_ts = context.audio_timestamps_assistant_turns + _responding_to = context.assistant_responding_to_user_turn skipped_turn_ids = { - turn_id for turn_id in turn_keys if not _user_ts.get(turn_id) or not _asst_ts.get(turn_id) + turn_id + for turn_id in turn_keys + if not _user_ts.get(_responding_to.get(turn_id, turn_id)) or not _asst_ts.get(turn_id) } turns_missing_timestamps = sorted(skipped_turn_ids) if skipped_turn_ids: diff --git a/src/eva/metrics/processor.py b/src/eva/metrics/processor.py index ad4d5d6c..5cbf121b 100644 --- a/src/eva/metrics/processor.py +++ b/src/eva/metrics/processor.py @@ -66,6 +66,7 @@ class _TurnExtractionState: assistant_interrupted_turns: set[int] = field(default_factory=set) user_interrupted_turns: set[int] = field(default_factory=set) pending_user_interrupts_label: bool = False # Next user entry should get [user interrupts] prefix + current_assistant_audio_is_interruption: bool = False # Current assistant audio session started while user speaking # Track which turn each speaker's audio started at, so late-arriving speech transcripts land at the correct turn. last_assistant_audio_turn: int = 0 last_user_audio_turn: int = 0 @@ -95,6 +96,7 @@ def advance_turn_if_needed(self) -> None: Called on audio_start(elevenlabs_user) and audit_log/user events. After an interruption, hold_turn consumes one advance without incrementing. + The turn will advance when the assistant interrupts again (repeat interruption). """ if self.hold_turn: self.hold_turn = False @@ -254,8 +256,22 @@ def _handle_pipecat_event( """ if event["event_type"] not in ("tts_text", "llm_response"): return + # If a tool call happened in this turn, advance to a new turn for the tool response + # This creates a separate turn for measuring response speed accurately + if state.assistant_processed_in_turn: + # Only carry over the interrupted label if the CURRENT assistant audio session + # started while the user was speaking (i.e., this speech is the interruption) + is_interruption = state.current_assistant_audio_is_interruption + state.turn_num += 1 + if is_interruption: + state.assistant_interrupted_turns.add(state.turn_num) + state.assistant_processed_in_turn = False + state.user_audio_started_in_turn = False state.assistant_spoke_in_turn = True turn = state.turn_num + # Record which user turn this assistant turn is responding to (for latency calculation) + if turn not in context.assistant_responding_to_user_turn: + context.assistant_responding_to_user_turn[turn] = state.last_user_audio_turn existing = context.intended_assistant_turns.get(turn, "") if existing: @@ -268,6 +284,8 @@ def _handle_pipecat_event( sep = "" elif turn in state.user_interrupted_turns: sep = f" {AnnotationLabel.CUT_OFF_BY_USER} " + elif turn in state.assistant_interrupted_turns: + sep = f" {AnnotationLabel.ASSISTANT_INTERRUPTS} " elif state.assistant_processed_in_turn: sep = f" {AnnotationLabel.PAUSE_TOOL_CALL} " else: @@ -354,6 +372,9 @@ def _handle_audio_start( if state.user_audio_open and state.user_audio_started_in_turn and not state.assistant_processed_in_turn: state.assistant_interrupted_turns.add(state.turn_num) state.hold_turn = True + state.current_assistant_audio_is_interruption = True + else: + state.current_assistant_audio_is_interruption = False turn_idx = state.turn_num key = (role, turn_idx) @@ -678,6 +699,10 @@ def __init__(self): self.assistant_interrupted_turns: set[int] = set() self.user_interrupted_turns: set[int] = set() + # Maps each assistant turn to the user turn it's responding to + # (for latency calculation when turns advance after tool calls) + self.assistant_responding_to_user_turn: dict[int, int] = {} + # Conversation metadata self.conversation_finished: bool = False self.conversation_ended_reason: Optional[str] = None