diff --git a/kloppy/domain/services/state_builder/builders/sequence.py b/kloppy/domain/services/state_builder/builders/sequence.py index e851f4309..ecae12719 100644 --- a/kloppy/domain/services/state_builder/builders/sequence.py +++ b/kloppy/domain/services/state_builder/builders/sequence.py @@ -3,6 +3,7 @@ from kloppy.domain import ( Event, Team, + Time, EventDataset, PassEvent, CarryEvent, @@ -22,30 +23,55 @@ class Sequence: sequence_id: int team: Team + start: Time + end: Time class SequenceStateBuilder(StateBuilder): + # current_sequence is mutable by design so every event in the sequence can be updated with the correct times + current_sequence: Sequence + def initial_state(self, dataset: EventDataset) -> Sequence: - for event in dataset.events: - if isinstance(event, OPEN_SEQUENCE): - return Sequence(sequence_id=0, team=event.team) - return Sequence(sequence_id=0, team=None) + self.current_sequence = Sequence( + sequence_id=0, team=None, start=None, end=None + ) + return self.current_sequence def reduce_before(self, state: Sequence, event: Event) -> Sequence: + # Set the start time of the sequence if it is not set yet + if self.current_sequence.start is None: + self.current_sequence.start = event.time + if isinstance(event, OPEN_SEQUENCE) and ( state.team != event.team or event.get_qualifier_value(SetPieceQualifier) ): - state = replace( - state, sequence_id=state.sequence_id + 1, team=event.team + # Start a new sequence + self.current_sequence = replace( + state, + sequence_id=state.sequence_id + 1, + team=event.team, + start=event.time, + end=None, ) + state = self.current_sequence return state def reduce_after(self, state: Sequence, event: Event) -> Sequence: + # Always update the end time of the sequence + # This ensures sequences without CLOSE_SEQUENCE events still have the correct time + self.current_sequence.end = event.time + if isinstance(event, CLOSE_SEQUENCE): - state = replace( - state, sequence_id=state.sequence_id + 1, team=None + # Start a new sequence + self.current_sequence = replace( + state, + sequence_id=state.sequence_id + 1, + team=None, + start=None, + end=None, ) + state = self.current_sequence return state diff --git a/kloppy/tests/test_state_builder.py b/kloppy/tests/test_state_builder.py index d27f6190a..eab868089 100644 --- a/kloppy/tests/test_state_builder.py +++ b/kloppy/tests/test_state_builder.py @@ -50,6 +50,12 @@ def test_sequence_state_builder(self, base_dir): events = list(events) events_per_sequence[sequence_id] = len(events) + # Check if the sequence start and end times match the first and last event + sequence_start = events[0].state["sequence"].start + sequence_end = events[-1].state["sequence"].end + assert events[0].time == sequence_start + assert events[-1].time == sequence_end + assert events_per_sequence[0] == 4 assert events_per_sequence[51] == 7