Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
chunk_text=chunk["text"],
timestamp=chunk["timestamp"],
duration=chunk["duration"],
words=chunk["words"],
)
)
for chunk in chunks
Expand All @@ -732,31 +731,41 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
TopicChunkResult(**result[TaskName.DETECT_CHUNK_TOPIC]) for result in results
]

# Build index-to-words map from local chunks (words not in child workflow results)
chunks_by_index = {chunk["index"]: chunk["words"] for chunk in chunks}

async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)
if not transcript:
raise ValueError(f"Transcript {input.transcript_id} not found")

# Clear topics for idempotency on retry (each topic gets a fresh UUID,
# so upsert_topic would append duplicates without this)
await transcripts_controller.update(transcript, {"topics": []})

for chunk in topic_chunks:
chunk_words = chunks_by_index[chunk.chunk_index]
topic = TranscriptTopic(
title=chunk.title,
summary=chunk.summary,
timestamp=chunk.timestamp,
transcript=" ".join(w.text for w in chunk.words),
words=chunk.words,
transcript=" ".join(w.text for w in chunk_words),
words=chunk_words,
)
await transcripts_controller.upsert_topic(transcript, topic)
await append_event_and_broadcast(
input.transcript_id, transcript, "TOPIC", topic, logger=logger
)

# Words omitted from TopicsResult — already persisted to DB above.
# Downstream tasks that need words refetch from DB.
topics_list = [
TitleSummary(
title=chunk.title,
summary=chunk.summary,
timestamp=chunk.timestamp,
duration=chunk.duration,
transcript=TranscriptType(words=chunk.words),
transcript=TranscriptType(words=[]),
)
for chunk in topic_chunks
]
Expand Down Expand Up @@ -842,9 +851,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
ctx.log(f"extract_subjects: starting for transcript_id={input.transcript_id}")

topics_result = ctx.task_output(detect_topics)
topics = topics_result.topics

if not topics:
if not topics_result.topics:
ctx.log("extract_subjects: no topics, returning empty subjects")
return SubjectsResult(
subjects=[],
Expand All @@ -857,11 +865,13 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
# sharing DB connections and LLM HTTP pools across forks
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
from reflector.llm import LLM # noqa: PLC0415
from reflector.processors.types import words_to_segments # noqa: PLC0415

async with fresh_db_connection():
transcript = await transcripts_controller.get_by_id(input.transcript_id)

# Build transcript text from topics (same logic as TranscriptFinalSummaryProcessor)
# Build transcript text from DB topics (words omitted from task output
# to reduce Hatchet payload size — refetch from DB where they were persisted)
speakermap = {}
if transcript and transcript.participants:
speakermap = {
Expand All @@ -871,8 +881,8 @@ async def extract_subjects(input: PipelineInput, ctx: Context) -> SubjectsResult
}

text_lines = []
for topic in topics:
for segment in topic.transcript.as_segments():
for db_topic in transcript.topics:
for segment in words_to_segments(db_topic.words):
name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}")
text_lines.append(f"{name}: {segment.text}")

Expand Down
1 change: 0 additions & 1 deletion server/reflector/hatchet/workflows/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ class TopicChunkResult(BaseModel):
summary: str
timestamp: float
duration: float
words: list[Word]


class TopicsResult(BaseModel):
Expand Down
3 changes: 0 additions & 3 deletions server/reflector/hatchet/workflows/topic_chunk_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from reflector.hatchet.workflows.models import TopicChunkResult
from reflector.logger import logger
from reflector.processors.prompts import TOPIC_PROMPT
from reflector.processors.types import Word


class TopicChunkInput(BaseModel):
Expand All @@ -30,7 +29,6 @@ class TopicChunkInput(BaseModel):
chunk_text: str
timestamp: float
duration: float
words: list[Word]


hatchet = HatchetClientManager.get_client()
Expand Down Expand Up @@ -99,5 +97,4 @@ async def detect_chunk_topic(input: TopicChunkInput, ctx: Context) -> TopicChunk
summary=response.summary,
timestamp=input.timestamp,
duration=input.duration,
words=input.words,
)
185 changes: 185 additions & 0 deletions server/tests/test_hatchet_payload_thinning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""
Tests for Hatchet payload thinning optimizations.

Verifies that:
1. TopicChunkInput no longer carries words
2. TopicChunkResult no longer carries words
3. words_to_segments() matches Transcript.as_segments(is_multitrack=False) — behavioral equivalence
for the extract_subjects refactoring
4. TopicsResult can be constructed with empty transcript words
"""

from reflector.hatchet.workflows.models import TopicChunkResult
from reflector.hatchet.workflows.topic_chunk_processing import TopicChunkInput
from reflector.processors.types import Word


def _make_words(speaker: int = 0, start: float = 0.0) -> list[Word]:
    """Build a two-word fixture ("Hello", " world.") attributed to one speaker.

    Each word spans half a second: the first begins at ``start`` and the
    second begins exactly where the first ends.
    """
    midpoint = start + 0.5
    finish = start + 1.0
    hello = Word(text="Hello", start=start, end=midpoint, speaker=speaker)
    world = Word(text=" world.", start=midpoint, end=finish, speaker=speaker)
    return [hello, world]


class TestTopicChunkInputNoWords:
    """TopicChunkInput must not have a words field."""

    def test_no_words_field(self):
        """The model schema declares no ``words`` field."""
        assert "words" not in TopicChunkInput.model_fields

    def test_construction_without_words(self):
        """The model can be built from the four remaining fields alone."""
        inp = TopicChunkInput(
            chunk_index=0, chunk_text="Hello world.", timestamp=0.0, duration=1.0
        )
        assert inp.chunk_index == 0
        assert inp.chunk_text == "Hello world."

    def test_rejects_words_kwarg(self):
        """A stray ``words=`` kwarg must never reach the serialized payload.

        Two outcomes are acceptable:

        * the model raises ``pydantic.ValidationError`` (extras forbidden), or
        * the model accepts the call but drops the value, so ``words`` is
          absent from both the schema and the dumped instance.

        An instance that retains ``words`` in ``model_dump()`` fails the test,
        since that would put the words back into the workflow payload.
        """
        import pydantic

        try:
            inp = TopicChunkInput(
                chunk_index=0,
                chunk_text="text",
                timestamp=0.0,
                duration=1.0,
                words=_make_words(),
            )
        except pydantic.ValidationError:
            # Extras are forbidden: the kwarg was rejected outright.
            return

        # Extras were tolerated: verify the value was discarded, not stored.
        assert "words" not in TopicChunkInput.model_fields
        assert "words" not in inp.model_dump()


class TestTopicChunkResultNoWords:
    """TopicChunkResult must not have a words field."""

    def test_no_words_field(self):
        """The model schema declares no ``words`` field."""
        declared_fields = TopicChunkResult.model_fields
        assert "words" not in declared_fields

    def test_construction_without_words(self):
        """The model can be built from the five remaining fields alone."""
        field_values = {
            "chunk_index": 0,
            "title": "Test",
            "summary": "Summary",
            "timestamp": 0.0,
            "duration": 1.0,
        }
        chunk_result = TopicChunkResult(**field_values)
        assert chunk_result.title == "Test"
        assert chunk_result.chunk_index == 0

    def test_serialization_roundtrip(self):
        """Serialized TopicChunkResult has no words key."""
        field_values = {
            "chunk_index": 0,
            "title": "Test",
            "summary": "Summary",
            "timestamp": 0.0,
            "duration": 1.0,
        }
        before = TopicChunkResult(**field_values)
        dumped = before.model_dump()
        assert "words" not in dumped
        # Rebuilding from the dump must yield an equal model.
        after = TopicChunkResult(**dumped)
        assert after == before


class TestWordsToSegmentsBehavioralEquivalence:
    """words_to_segments() must produce same output as Transcript.as_segments(is_multitrack=False).

    This ensures the extract_subjects refactoring (from task output topic.transcript.as_segments()
    to words_to_segments(db_topic.words)) preserves identical behavior.
    """

    @staticmethod
    def _segment_both_ways(words):
        """Return ``(direct, via_transcript)`` segment lists for ``words``.

        ``direct`` comes from ``words_to_segments(words)``; ``via_transcript``
        comes from wrapping the same words in a ``Transcript`` and calling
        ``as_segments(is_multitrack=False)``.
        """
        from reflector.processors.types import Transcript as TranscriptType
        from reflector.processors.types import words_to_segments

        direct = words_to_segments(words)
        via_transcript = TranscriptType(words=words).as_segments(is_multitrack=False)
        return direct, via_transcript

    @staticmethod
    def _assert_equivalent(direct, via_transcript):
        """Assert both segment lists match pairwise on text, speaker and timing."""
        # Length is checked first so the zip below cannot hide a missing segment.
        assert len(direct) == len(via_transcript)
        for d, v in zip(direct, via_transcript):
            assert d.text == v.text
            assert d.speaker == v.speaker
            assert d.start == v.start
            assert d.end == v.end

    def test_single_speaker(self):
        """Both paths agree for words that all belong to one speaker."""
        direct, via_transcript = self._segment_both_ways(_make_words(speaker=0))
        self._assert_equivalent(direct, via_transcript)

    def test_multiple_speakers(self):
        """Both paths agree, including timing, across a speaker change."""
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text=" world.", start=0.5, end=1.0, speaker=0),
            Word(text=" How", start=1.0, end=1.5, speaker=1),
            Word(text=" are", start=1.5, end=2.0, speaker=1),
            Word(text=" you?", start=2.0, end=2.5, speaker=1),
        ]

        direct, via_transcript = self._segment_both_ways(words)
        self._assert_equivalent(direct, via_transcript)

    def test_empty_words(self):
        """Both paths return an empty list when given no words."""
        direct, via_transcript = self._segment_both_ways([])
        assert direct == []
        assert via_transcript == []


class TestTopicsResultEmptyWords:
    """TopicsResult can carry topics with empty transcript words."""

    def test_construction_with_empty_words(self):
        """Two word-less topics survive construction of a TopicsResult."""
        from reflector.hatchet.workflows.models import TopicsResult
        from reflector.processors.types import TitleSummary
        from reflector.processors.types import Transcript as TranscriptType

        # (title, summary, timestamp) for each topic; both last five seconds.
        specs = (
            ("Topic A", "Summary A", 0.0),
            ("Topic B", "Summary B", 5.0),
        )
        wordless_topics = [
            TitleSummary(
                title=title,
                summary=summary,
                timestamp=starts_at,
                duration=5.0,
                transcript=TranscriptType(words=[]),
            )
            for title, summary, starts_at in specs
        ]

        bundle = TopicsResult(topics=wordless_topics)

        assert len(bundle.topics) == 2
        assert all(topic.transcript.words == [] for topic in bundle.topics)

    def test_serialization_roundtrip(self):
        """A word-less topic is still word-less after dump and rebuild."""
        from reflector.hatchet.workflows.models import TopicsResult
        from reflector.processors.types import TitleSummary
        from reflector.processors.types import Transcript as TranscriptType

        lone_topic = TitleSummary(
            title="Topic",
            summary="Summary",
            timestamp=0.0,
            duration=1.0,
            transcript=TranscriptType(words=[]),
        )
        before = TopicsResult(topics=[lone_topic])

        after = TopicsResult(**before.model_dump())

        assert len(after.topics) == 1
        assert after.topics[0].transcript.words == []
Loading