Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/eva/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
response_speed_latencies: list[float] | None = None,
assistant_interrupted_turns: set[int] | None = None,
user_interrupted_turns: set[int] | None = None,
assistant_responding_to_user_turn: dict[int, int] | None = None,
is_audio_native: bool = False,
):
self.record_id = record_id
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(
self.response_speed_latencies = response_speed_latencies or []
self.assistant_interrupted_turns = assistant_interrupted_turns or set()
self.user_interrupted_turns = user_interrupted_turns or set()
self.assistant_responding_to_user_turn = assistant_responding_to_user_turn or {}
self.is_audio_native = is_audio_native

def to_dict(self) -> dict[str, Any]:
Expand Down
39 changes: 31 additions & 8 deletions src/eva/metrics/experience/turn_taking.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,21 @@ class TurnTakingMetric(TextJudgeMetric):

@staticmethod
def _get_turn_ids_with_turn_taking(context: MetricContext) -> list[int]:
"""Return sorted turn IDs for user-assistant exchange pairs (excludes greeting)."""
return sorted(context.transcribed_user_turns.keys() & context.transcribed_assistant_turns.keys() - {0})
"""Return sorted assistant turn IDs to evaluate (excludes greeting).

Includes assistant turns that have a responding_to user turn mapping,
even if there's no user transcript at the same turn ID (e.g., after tool calls).
"""
# All assistant turns except greeting
assistant_turns = set(context.transcribed_assistant_turns.keys()) - {0}
# Filter to those with a valid user turn to measure latency against
responding_to = context.assistant_responding_to_user_turn
valid_turns = {
t
for t in assistant_turns
if t in context.transcribed_user_turns or responding_to.get(t) in context.transcribed_user_turns
}
return sorted(valid_turns)

def _format_conversation_context(
self,
Expand All @@ -52,14 +65,17 @@ def _format_conversation_context(
all_starts = [segs[0][0] for segs in all_timestamps.values() if segs]
t0 = min(all_starts) if all_starts else 0

responding_to = context.assistant_responding_to_user_turn
blocks = []
for turn_id in turn_keys:
user_heard = context.transcribed_user_turns.get(turn_id, "")
# Use responding_to user turn for user data (handles tool call advances)
user_turn_id = responding_to.get(turn_id, turn_id)
user_heard = context.transcribed_user_turns.get(user_turn_id, "")
asst_heard = context.transcribed_assistant_turns.get(turn_id, "")
user_expected = context.intended_user_turns.get(turn_id, "")
user_expected = context.intended_user_turns.get(user_turn_id, "")
asst_expected = context.intended_assistant_turns.get(turn_id, "")

u_segments = context.audio_timestamps_user_turns.get(turn_id)
u_segments = context.audio_timestamps_user_turns.get(user_turn_id)
a_segments = context.audio_timestamps_assistant_turns.get(turn_id)
latency = per_turn_latency.get(turn_id)

Expand Down Expand Up @@ -136,22 +152,26 @@ def _compute_per_turn_latency_and_timing_labels(
Turns with missing timestamps get None values.

Latency is user_end -> asst_start in seconds.
For assistant turns after tool calls, uses the responding_to user turn's timestamps.
Timing label thresholds:
latency < 200 ms -> "Early / Interrupting"
200 ms <= latency < 4000 ms -> "On-Time"
latency >= 4000 ms -> "Late"
"""
user_ts = context.audio_timestamps_user_turns
asst_ts = context.audio_timestamps_assistant_turns
responding_to = context.assistant_responding_to_user_turn
latencies: dict[int, float | None] = {}
labels: dict[int, str | None] = {}
for turn_id in turn_keys:
u = user_ts.get(turn_id)
# Use the responding_to user turn for latency calculation (handles tool call advances)
user_turn_id = responding_to.get(turn_id, turn_id)
u = user_ts.get(user_turn_id)
a = asst_ts.get(turn_id)
if not u or not a:
if turn_id != len(turn_keys):
self.logger.warning(
f"[{context.record_id}] Missing audio timestamps at turn {turn_id}/{len(turn_keys)} (user={u}, assistant={a}); skipping turn taking for this turn."
f"[{context.record_id}] Missing audio timestamps at turn {turn_id}/{len(turn_keys)} (user={u} at turn {user_turn_id}, assistant={a}); skipping turn taking for this turn."
)
latencies[turn_id] = None
labels[turn_id] = None
Expand All @@ -176,8 +196,11 @@ async def compute(self, context: MetricContext) -> MetricScore:
# Identify turn keys where either timestamp is missing.
_user_ts = context.audio_timestamps_user_turns
_asst_ts = context.audio_timestamps_assistant_turns
_responding_to = context.assistant_responding_to_user_turn
skipped_turn_ids = {
turn_id for turn_id in turn_keys if not _user_ts.get(turn_id) or not _asst_ts.get(turn_id)
turn_id
for turn_id in turn_keys
if not _user_ts.get(_responding_to.get(turn_id, turn_id)) or not _asst_ts.get(turn_id)
}
turns_missing_timestamps = sorted(skipped_turn_ids)
if skipped_turn_ids:
Expand Down
25 changes: 25 additions & 0 deletions src/eva/metrics/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class _TurnExtractionState:
assistant_interrupted_turns: set[int] = field(default_factory=set)
user_interrupted_turns: set[int] = field(default_factory=set)
pending_user_interrupts_label: bool = False # Next user entry should get [user interrupts] prefix
current_assistant_audio_is_interruption: bool = False # Current assistant audio session started while user speaking
# Track which turn each speaker's audio started at, so late-arriving speech transcripts land at the correct turn.
last_assistant_audio_turn: int = 0
last_user_audio_turn: int = 0
Expand Down Expand Up @@ -95,6 +96,7 @@ def advance_turn_if_needed(self) -> None:

Called on audio_start(elevenlabs_user) and audit_log/user events.
After an interruption, hold_turn consumes one advance without incrementing.
The turn will advance when the assistant interrupts again (repeat interruption).
"""
if self.hold_turn:
self.hold_turn = False
Expand Down Expand Up @@ -254,8 +256,22 @@ def _handle_pipecat_event(
"""
if event["event_type"] not in ("tts_text", "llm_response"):
return
# If a tool call happened in this turn, advance to a new turn for the tool response
# This creates a separate turn for measuring response speed accurately
if state.assistant_processed_in_turn:
# Only carry over the interrupted label if the CURRENT assistant audio session
# started while the user was speaking (i.e., this speech is the interruption)
is_interruption = state.current_assistant_audio_is_interruption
state.turn_num += 1
if is_interruption:
state.assistant_interrupted_turns.add(state.turn_num)
state.assistant_processed_in_turn = False
state.user_audio_started_in_turn = False
state.assistant_spoke_in_turn = True
turn = state.turn_num
# Record which user turn this assistant turn is responding to (for latency calculation)
if turn not in context.assistant_responding_to_user_turn:
context.assistant_responding_to_user_turn[turn] = state.last_user_audio_turn
existing = context.intended_assistant_turns.get(turn, "")

if existing:
Expand All @@ -268,6 +284,8 @@ def _handle_pipecat_event(
sep = ""
elif turn in state.user_interrupted_turns:
sep = f" {AnnotationLabel.CUT_OFF_BY_USER} "
elif turn in state.assistant_interrupted_turns:
sep = f" {AnnotationLabel.ASSISTANT_INTERRUPTS} "
elif state.assistant_processed_in_turn:
sep = f" {AnnotationLabel.PAUSE_TOOL_CALL} "
else:
Expand Down Expand Up @@ -354,6 +372,9 @@ def _handle_audio_start(
if state.user_audio_open and state.user_audio_started_in_turn and not state.assistant_processed_in_turn:
state.assistant_interrupted_turns.add(state.turn_num)
state.hold_turn = True
state.current_assistant_audio_is_interruption = True
else:
state.current_assistant_audio_is_interruption = False

turn_idx = state.turn_num
key = (role, turn_idx)
Expand Down Expand Up @@ -678,6 +699,10 @@ def __init__(self):
self.assistant_interrupted_turns: set[int] = set()
self.user_interrupted_turns: set[int] = set()

# Maps each assistant turn to the user turn it's responding to
# (for latency calculation when turns advance after tool calls)
self.assistant_responding_to_user_turn: dict[int, int] = {}

# Conversation metadata
self.conversation_finished: bool = False
self.conversation_ended_reason: Optional[str] = None
Expand Down
Loading