diff --git a/docs/examples/enrich_with_tracking.md b/docs/examples/enrich_with_tracking.md new file mode 100644 index 000000000..307a28500 --- /dev/null +++ b/docs/examples/enrich_with_tracking.md @@ -0,0 +1,197 @@ +# Enriching Event Data with Tracking Data + +This guide demonstrates how to enrich an `EventDataset` with `PressureEvent`s derived from a `TrackingDataset`. We will use Sportec data from the IDSSE dataset to identify sequences of frames where a player is within a radius $r$ of the player in possession, create `PressureEvent`s for these sequences, and insert them into the `EventDataset`. + +## Loading Data + +First, we load both the event and tracking data for a specific match from the IDSSE dataset. When working with both datasets together, it is important to transform them to a common orientation and coordinate system to ensure the coordinates are comparable. We will use the `HOME_AWAY` orientation, where the home team plays from left to right in the first half. + +```python exec="true" source="above" session="enrich-tracking" +from kloppy import sportec +from kloppy.domain import Orientation + +match_id = "J03WPY" # Fortuna Düsseldorf vs. 1. FC Nürnberg + +# Load event data and transform to HOME_AWAY orientation +event_dataset = sportec.load_open_event_data(match_id=match_id) +event_dataset = event_dataset.transform( + to_orientation=Orientation.HOME_AWAY +) + +# Load tracking data and transform to the same orientation +tracking_dataset = sportec.load_open_tracking_data(match_id=match_id, sample_rate=1/10) +tracking_dataset = tracking_dataset.transform( + to_orientation=Orientation.HOME_AWAY +) + +print(f"Loaded {len(event_dataset)} events and {len(tracking_dataset)} tracking frames.") +``` + +## Identifying Pressure Sequences + +We define "pressure" as an opponent being within 3 meters of the player in possession. To find the player in possession in each frame, we look for the player from the ball-owning team who is closest to the ball. We use the `distance_between` method from the `PitchDimensions` class to calculate distances in meters. + +```python exec="true" source="above" session="enrich-tracking" +from kloppy.domain import BallState + +pitch_dimensions = tracking_dataset.metadata.pitch_dimensions + +RADIUS = 3.0 # meters +BALL_DIST_THRESHOLD = 1.0 # meters to consider a player has possession + +pressure_sequences = [] # List of (player, start_frame, end_frame) +active_pressures = {} # Mapping player -> start_frame + +for frame in tracking_dataset: + # Skip frames where the ball is dead or coordinates are missing + if not frame.ball_coordinates or frame.ball_state != BallState.ALIVE: + # Close all active pressures when the ball is dead + for player, start_frame in active_pressures.items(): + pressure_sequences.append((player, start_frame, frame)) + active_pressures = {} + continue + + # Find player in possession (closest to ball from the owning team) + possessor = None + min_ball_dist = float('inf') + for player, player_data in frame.players_data.items(): + if player.team != frame.ball_owning_team: + continue + dist = pitch_dimensions.distance_between(player_data.coordinates, frame.ball_coordinates) + if dist < min_ball_dist: + min_ball_dist = dist + possessor = player + + if not possessor or min_ball_dist > BALL_DIST_THRESHOLD: + # No clear possession, close all active pressures + for player, start_frame in active_pressures.items(): + pressure_sequences.append((player, start_frame, frame)) + active_pressures = {} + continue + + # Find opponents within RADIUS of the possessor + possessor_coords = frame.players_data[possessor].coordinates + for player, player_data in frame.players_data.items(): + if player.team == possessor.team: + continue + + dist = pitch_dimensions.distance_between(player_data.coordinates, possessor_coords) + if dist < RADIUS: + if player not in active_pressures: + active_pressures[player] = frame + else: + if player in active_pressures: + pressure_sequences.append((player, active_pressures.pop(player), frame)) + +# Close any remaining active pressures at the end of the dataset +for player, start_frame in active_pressures.items(): + pressure_sequences.append((player, start_frame, tracking_dataset[-1])) + +print(f"Identified {len(pressure_sequences)} potential pressure sequences.") +``` + +## Creating and Inserting PressureEvents + +Now we create `PressureEvent` objects from the identified sequences and insert them into our `event_dataset`. We use the `EventFactory` to simplify event creation and the `insert` method to ensure they are placed in the correct chronological order. + +```python exec="true" source="above" session="enrich-tracking" +from kloppy.domain import EventFactory, PressureEvent + +factory = EventFactory() + +inserted_count = 0 +for player, start_frame, end_frame in pressure_sequences: + # Only consider sequences that lasted at least 0.5 seconds to filter out noise + duration = (end_frame.timestamp - start_frame.timestamp).total_seconds() + if duration < 0.5: + continue + + pressure_event = factory.build_pressure_event( + event_id=f"pressure-{player.player_id}-{start_frame.frame_id}", + period=start_frame.period, + timestamp=start_frame.timestamp, + end_timestamp=end_frame.time, # PressureEvent requires end_timestamp as a Time object + team=player.team, + player=player, + coordinates=start_frame.players_data[player].coordinates, + ball_owning_team=start_frame.ball_owning_team, + ball_state=start_frame.ball_state, + result=None, + qualifiers=[], + raw_event=None + ) + + # Insert the event into the dataset based on its timestamp + event_dataset.insert(pressure_event, timestamp=pressure_event.timestamp) + inserted_count += 1 + +print(f"Inserted {inserted_count} PressureEvents into the EventDataset.") +``` + +## Contextualizing Events: Pressured vs. Unpressured Passes + +We can now use these `PressureEvent`s to contextualize existing events. For example, we can flag every `PassEvent` that occurred while an opponent was applying pressure to the passer. Instead of searching the entire dataset for each pass, we can efficiently use the `prev` and `next` records. + +```python exec="true" source="above" session="enrich-tracking" +from kloppy.domain import UnderPressureQualifier, PassEvent, PassResult +from datetime import timedelta + +# Flag passes as being under pressure +for event in event_dataset: + if not isinstance(event, PassEvent): + continue + + is_pressured = False + + # Check predecessors: since PressureEvents are inserted by start time, + # the containing pressure must be at or before this event in the sequence. + other = event.prev(lambda x: isinstance(x, PressureEvent)) + while other and event.time - other.time < timedelta(seconds=20): + if other.team != event.team and other.time <= event.time <= other.end_timestamp: + is_pressured = True + break + other = other.prev(lambda x: isinstance(x, PressureEvent)) + + # Check successors: only needed if they have the exact same timestamp + if not is_pressured: + other = event.next(lambda x: isinstance(x, PressureEvent)) + while other and other.time == event.time: + if other.team != event.team and other.time <= event.time <= other.end_timestamp: + is_pressured = True + break + other = other.next(lambda x: isinstance(x, PressureEvent)) + + if is_pressured: + event.qualifiers.append(UnderPressureQualifier(value=True)) + +# Analysis: Completion Rate +pressured_passes = [e for e in event_dataset if isinstance(e, PassEvent) and e.get_qualifier_value(UnderPressureQualifier)] +unpressured_passes = [e for e in event_dataset if isinstance(e, PassEvent) and not e.get_qualifier_value(UnderPressureQualifier)] + +def get_completion_rate(passes): + if not passes: + return 0 + complete = sum(1 for p in passes if p.result == PassResult.COMPLETE) + return (complete / len(passes)) * 100 + +print(f"Pressured Pass Completion: {get_completion_rate(pressured_passes):.1f}% ({len(pressured_passes)} passes)") +print(f"Unpressured Pass Completion: {get_completion_rate(unpressured_passes):.1f}% ({len(unpressured_passes)} passes)") +``` + +## Analysis: Who Pressed Most? + +Finally, we can analyze the enriched dataset to find out which player performed the most pressure actions. + +```python exec="true" source="above" session="enrich-tracking" +from collections import Counter + +pressure_counts = Counter( + event.player.full_name + for event in event_dataset + if isinstance(event, PressureEvent) +) + +print("Top 10 players by number of pressure actions:") +for name, count in pressure_counts.most_common(10): + print(f"{name}: {count}") +``` diff --git a/docs/user-guide/transformations/coordinates/index.md b/docs/user-guide/transformations/coordinates/index.md index 6d7ea08b2..878f57795 100644 --- a/docs/user-guide/transformations/coordinates/index.md +++ b/docs/user-guide/transformations/coordinates/index.md @@ -1,4 +1,4 @@ -# Dataset transformations +# Changing coordinate systems Kloppy's [`.transform()`][kloppy.domain.Dataset.transform] method allows you to adapt the [spatial representation](../../concepts/coordinates/index.md) of a dataset. This can be useful if you need to align data from different providers or to run analyses that assume a standard pitch size or attacking direction. diff --git a/docs/user-guide/transformations/insert/index.md b/docs/user-guide/transformations/insert/index.md new file mode 100644 index 000000000..20a7e0998 --- /dev/null +++ b/docs/user-guide/transformations/insert/index.md @@ -0,0 +1,108 @@ +# Inserting events + +Sometimes event data is incomplete, and you need to manually +inject events into a dataset. The [`.insert()`][kloppy.domain.EventDataset.insert] +method allows you to add [`Event`][kloppy.domain.Event] objects to an existing +[`EventDataset`][kloppy.domain.EventDataset]. + +Common use cases include: + +- Deduce and insert synthetic events that providers don't annotate (e.g., generating "Carry" events). +- Insert events that are provided in metadata rather than the event stream (e.g., substitutions with approximate timestamps). +- Insert events derived from a tracking dataset (e.g., adding "Pressing" events). + +The method automatically handles the re-linking of the dataset (updating +`prev_record` and `next_record` references) to ensure the integrity of the +event stream. + +## Basic setup + +To insert an event, you first need to create the +[`Event`][kloppy.domain.Event] object. You can use +the [`EventFactory`][kloppy.domain.EventFactory] to build specific event types. + +```python +from datetime import timedelta +from kloppy.domain import EventFactory, CarryResult + +# Create a new event +new_event = EventFactory().build_carry( + event_id="added-carry-1", + timestamp=timedelta(seconds=700), + result=CarryResult.COMPLETE, + period=dataset.metadata.periods[0], + ball_owning_team=dataset.metadata.teams[0], + team=dataset.metadata.teams[0], + player=dataset.metadata.teams[0].players[0], + coordinates=(0.2, 0.3), + end_coordinates=(0.22, 0.33) +) +``` + +## Insertion methods + +There are four ways to determine where the new event is placed in the dataset. + +### By position (index) + +If you know the exact index where the event should be located, you can use the +`position` argument. This works exactly like a standard Python list insertion. + +```python +# Insert the event at index 3 +dataset.insert(new_event, position=3) +``` + +### By event ID + +If you do not know the index but know the context (e.g., the event should +happen immediately before or after a specific action), you can use the +`before_event_id` or `after_event_id` arguments. + +```python +# Insert immediately before a specific event +dataset.insert(new_event, before_event_id="event-id-100") + +# Insert immediately after a specific event +dataset.insert(new_event, after_event_id="event-id-305") +``` + +### By timestamp + +To insert the event chronologically, provide the `timestamp` argument. The +dataset will be searched to find the correct location based on the time +provided. + +```python +# Insert based on the timestamp defined in the new_event +dataset.insert(new_event, timestamp=new_event.timestamp) +``` + +### Using a scoring function + +For complex insertion logic—such as ''insert after the closest event belonging +to Team A''—you can provide a `scoring_function`. + +The function iterates over events in the dataset. It must accept an `event` +and the `dataset` as arguments and return a number: + +- **Positive Score:** Insert **after** the event with the highest score. +- **Negative Score:** Insert **before** the event with the highest absolute score. +- **Zero:** No match. + +**Example: Insert after the closest timestamp** + +```python +def insert_after_closest_match(event, dataset): + # Filter logic: only check events for the same team and period + if event.ball_owning_team != dataset.metadata.teams[0]: + return 0 + if event.period != new_event.period: + return 0 + + # Scoring logic: The smaller the time difference, the higher the score + time_diff = abs(event.timestamp.total_seconds() - new_event.timestamp.total_seconds()) + return 1 / time_diff if time_diff != 0 else 0 + +dataset.insert(new_event, scoring_function=insert_after_closest_match) +``` diff --git a/kloppy/domain/models/common.py b/kloppy/domain/models/common.py index 3449a3559..f2dd08d68 100644 --- a/kloppy/domain/models/common.py +++ b/kloppy/domain/models/common.py @@ -1787,8 +1787,29 @@ def filter(self, filter_: Union[str, Callable[[T], bool]]): new_cls_name = f"Filtered{current_class.__name__}" # We inherit from FilteredDataset first, then the original class + disabled_methods_map = { + DatasetType.EVENT: ["insert"], + } + methods_to_disable = disabled_methods_map.get( + self.dataset_type, [] + ) + + attrs = {} + for method_name in methods_to_disable: + + def make_disabled_method(name): + def disabled_method(self, *args, **kwargs): + raise NotImplementedError( + f"Method '{name}' is not supported on filtered datasets." + ) + + disabled_method.__name__ = name + return disabled_method + + attrs[method_name] = make_disabled_method(method_name) + _FILTERED_CLASS_CACHE[current_class] = type( - new_cls_name, (FilteredDataset, current_class), {} + new_cls_name, (FilteredDataset, current_class), attrs ) target_class = _FILTERED_CLASS_CACHE[current_class] diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 5c597c425..bf2fb49ea 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1469,6 +1469,144 @@ def _update_formations_and_positions(self): else: event.team.formations.set(event.time, event.formation_type) + def insert( + self, + event: Event, + position: Optional[int] = None, + before_event_id: Optional[str] = None, + after_event_id: Optional[str] = None, + timestamp: Optional[timedelta] = None, + scoring_function: Optional[ + Callable[[Event, "EventDataset"], float] + ] = None, + ): + """Inserts an event into the dataset at the appropriate position. + + Args: + event (Event): The event to be inserted into the dataset. + position (Optional[int]): The exact index where the event should be inserted. + If provided, overrides all other positioning parameters. Defaults to None. + before_event_id (Optional[str]): The ID of the event before which the new event + should be inserted. Ignored if `position` is provided. Defaults to None. + after_event_id (Optional[str]): The ID of the event after which the new event + should be inserted. Ignored if `position` or `before_event_id` is provided. + Defaults to None. + timestamp (Optional[timedelta]): The timestamp of the event, used to determine + its position based on chronological order if no other positional parameters + are specified. Defaults to None. + scoring_function (Optional[Callable[[Event, EventDataset], float]]): A custom + function that takes an event from the dataset and the dataset itself as + arguments and returns a score. Negative scores mean insertion should happen + **before** the highest-scoring event, while positive scores mean insertion + should happen **after** the highest-scoring event. If all scores are zero, + the insertion will fail with a ValueError. + + Raises: + ValueError: If the insertion position cannot be determined or is invalid. + + Examples: + Insert an event at a specific index: + >>> dataset.insert(new_event, position=10) + + Insert an event based on its timestamp: + >>> dataset.insert(new_event, timestamp=timedelta(seconds=120)) + + Insert an event relative to another event: + >>> dataset.insert(new_event, after_event_id="event-789") + + Insert an event using a custom scoring function: + >>> def score_fn(existing_event, ds): + ... # Score based on proximity to a specific timestamp + ... return 1.0 / (1.0 + abs(existing_event.timestamp - target_ts).total_seconds()) + >>> dataset.insert(new_event, scoring_function=score_fn) + + Notes: + - If multiple parameters are provided to specify the position, the precedence is: + 1. `position` + 2. `before_event_id` + 3. `after_event_id` + 4. `timestamp` + 5. `scoring_function` + - If none of the above parameters are specified, the method raises a `ValueError`. + """ + if position is not None: + # If position is provided, use it directly + insert_position = position + + elif before_event_id is not None: + # Find the event with the matching `before_event_id` and insert before it + try: + insert_position = next( + ( + i + for i, e in enumerate(self.records) + if e.event_id == before_event_id + ), + ) + except StopIteration: + raise ValueError(f"No event found with ID {before_event_id}.") + + elif after_event_id is not None: + # Find the event with the matching `after_event_id` and insert after it + try: + insert_position = next( + ( + i + 1 + for i, e in enumerate(self.records) + if e.event_id == after_event_id + ), + ) + except StopIteration: + raise ValueError(f"No event found with ID {after_event_id}.") + + elif timestamp is not None: + # If no position or event IDs are specified, insert based on timestamp + insert_position = next( + ( + i + for i, e in enumerate(self.records) + if e.timestamp > timestamp + ), + len(self.records), + ) + + elif scoring_function is not None: + # Evaluate all possible positions using the constraint function + scores = [ + (i, scoring_function(event, self)) + for i, event in enumerate(self.records) + ] + # Select the best position with the highest score + best_index, best_score = max( + scores, key=lambda x: abs(x[1]), default=(0, -1) + ) + if best_score == 0: + raise ValueError( + "No valid insertion position found based on the provided scoring function." + ) + + # Insert after if score is positive, before if score is negative + insert_position = best_index + 1 if best_score > 0 else best_index + + else: + raise ValueError( + "Unable to determine insertion position for the event." + ) + + # Insert the event at the determined position + self.records.insert(insert_position, event) + + # Update the event's references + self.records[insert_position].dataset = self + for i in range( + max(0, insert_position - 1), + min(insert_position + 2, len(self.records)), + ): + self.records[i].prev_record = self.records[i - 1] if i > 0 else None + self.records[i].next_record = ( + self.records[i + 1] if i + 1 < len(self.records) else None + ) + @property def events(self): return self.records diff --git a/kloppy/domain/models/time.py b/kloppy/domain/models/time.py index 6609f1dee..d7650b50c 100644 --- a/kloppy/domain/models/time.py +++ b/kloppy/domain/models/time.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta +from functools import total_ordering from typing import ( Generic, Literal, @@ -14,6 +15,7 @@ from kloppy.exceptions import KloppyError +@total_ordering @dataclass class Period: """ @@ -82,9 +84,6 @@ def __eq__(self, other): def __lt__(self, other: "Period"): return self.id < other.id - def __ge__(self, other): - return self == other or other < self - def __hash__(self): return id(self.id) @@ -104,6 +103,7 @@ def set_refs( self.next_period = next_ +@total_ordering @dataclass class Time: """ diff --git a/kloppy/domain/services/mutators/__init__.py b/kloppy/domain/services/mutators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kloppy/domain/services/mutators/base.py b/kloppy/domain/services/mutators/base.py new file mode 100644 index 000000000..836ccecdf --- /dev/null +++ b/kloppy/domain/services/mutators/base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from kloppy.domain import Dataset + +D = TypeVar("D", bound=Dataset) + + +class DatasetMutator(ABC, Generic[D]): + def __init__(self, *, inplace: bool = False): + self.inplace = inplace + + def mutate(self, dataset: D) -> D: + if self.inplace: + return self._mutate_inplace(dataset) + else: + return self._mutate_inplace(self._copy_dataset(dataset)) + + @abstractmethod + def _mutate_inplace(self, dataset: D) -> D: + raise NotImplementedError + + def _copy_dataset(self, dataset: D) -> D: + from dataclasses import replace + + return replace(dataset, records=list(dataset.records)) diff --git a/kloppy/domain/services/mutators/helpers/insert.py b/kloppy/domain/services/mutators/helpers/insert.py new file mode 100644 index 000000000..3c0690239 --- /dev/null +++ b/kloppy/domain/services/mutators/helpers/insert.py @@ -0,0 +1,81 @@ +from datetime import timedelta +from typing import Callable, Optional, TypeVar + +from kloppy.domain import DataRecord, Dataset + +D = TypeVar("D", bound=Dataset) # any Dataset subclass +R = TypeVar("R", bound=DataRecord) # record type within the dataset + + +def insert_record( + dataset: D, + record: R, + *, + position: Optional[int] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + timestamp: Optional[timedelta] = None, + scoring_function: Optional[Callable[[R, D], float]] = None, +) -> int: + """ + Generic insertion function for any Dataset subclass. + + Returns the index where the record was inserted. + """ + records = dataset.records # type: ignore + + # Determine insert position + if position is not None: + insert_position = position + elif before_id is not None: + insert_position = next( + i + for i, r in enumerate(records) + if getattr(r, "record_id", getattr(r, "event_id", None)) + == before_id + ) + elif after_id is not None: + insert_position = next( + i + 1 + for i, r in enumerate(records) + if getattr(r, "record_id", getattr(r, "event_id", None)) == after_id + ) + elif timestamp is not None: + insert_position = next( + ( + i + for i, r in enumerate(records) + if getattr(r, "timestamp", None) > timestamp + ), + len(records), + ) + elif scoring_function is not None: + scores = [ + (i, scoring_function(record, dataset)) + for i, r in enumerate(records) + ] + best_index, best_score = max( + scores, key=lambda x: abs(x[1]), default=(0, 0) + ) + if best_score == 0: + raise ValueError("No valid insertion position found.") + insert_position = best_index + 1 if best_score > 0 else best_index + else: + raise ValueError("Cannot determine insertion position") + + # Insert record + records.insert(insert_position, record) + record.dataset = dataset # type: ignore + + # Update references if they exist (prev/next) + for i in range( + max(0, insert_position - 1), min(insert_position + 2, len(records)) + ): + if hasattr(records[i], "prev_record"): + records[i].prev_record = records[i - 1] if i > 0 else None + if hasattr(records[i], "next_record"): + records[i].next_record = ( + records[i + 1] if i + 1 < len(records) else None + ) + + return insert_position diff --git a/kloppy/domain/services/mutators/insert_carries.py b/kloppy/domain/services/mutators/insert_carries.py new file mode 100644 index 000000000..8bcdaa9d7 --- /dev/null +++ b/kloppy/domain/services/mutators/insert_carries.py @@ -0,0 +1,159 @@ +from datetime import timedelta +from typing import Optional + +from kloppy.domain import ( + BodyPart, + CarryResult, + ClearanceEvent, + DuelEvent, + EventDataset, + EventFactory, + GenericEvent, + GoalkeeperEvent, + InterceptionEvent, + MiscontrolEvent, + PassEvent, + PressureEvent, + RecoveryEvent, + SetPieceQualifier, + ShotEvent, + TakeOnEvent, + Unit, +) + +from .base import DatasetMutator +from .helpers.insert import insert_record + +VALID_EVENT = ( + PassEvent, + ShotEvent, + TakeOnEvent, + ClearanceEvent, + InterceptionEvent, + DuelEvent, + RecoveryEvent, + MiscontrolEvent, + GoalkeeperEvent, +) + + +class SyntheticCarryMutator(DatasetMutator[EventDataset]): + def __init__( + self, + *, + event_factory: Optional[EventFactory] = None, + min_length_meters: float = 3, + max_length_meters: float = 60, + max_duration: timedelta = timedelta(seconds=10), + inplace: bool = False, + ): + super().__init__(inplace=inplace) + self.event_factory = event_factory or EventFactory() + self.min_length_meters = min_length_meters + self.max_length_meters = max_length_meters + self.max_duration = max_duration + + def _mutate_inplace(self, dataset: EventDataset) -> EventDataset: + pitch = dataset.metadata.pitch_dimensions + events = dataset.records + + i = 0 + while i < len(events) - 1: + event = events[i] + + if not isinstance(event, VALID_EVENT): + i += 1 + continue + + j = i + 1 + while j < len(events) and isinstance( + events[j], (GenericEvent, PressureEvent) + ): + j += 1 + + if j >= len(events): + break + + next_event = events[j] + + if not isinstance(next_event, VALID_EVENT): + i += 1 + continue + if event.team.team_id != next_event.team.team_id: + i += 1 + continue + if next_event.get_qualifier_value(SetPieceQualifier) is not None: + i += 1 + continue + + if hasattr(event, "end_coordinates"): + last_coord = event.end_coordinates + elif hasattr(event, "receiver_coordinates"): + last_coord = event.receiver_coordinates + else: + last_coord = event.coordinates + + new_coord = next_event.coordinates + distance = pitch.distance_between( + new_coord, last_coord, Unit.METERS + ) + + if not ( + self.min_length_meters <= distance <= self.max_length_meters + ): + i += 1 + continue + + dt = next_event.timestamp - event.timestamp + if ( + dt > self.max_duration + or event.period.id != next_event.period.id + ): + i += 1 + continue + + if isinstance(next_event, ShotEvent) and any( + q.name == "body_part" + and q.value in (BodyPart.HEAD, BodyPart.HEAD_OTHER) + for q in next_event.qualifiers or [] + ): + i += 1 + continue + + if getattr(event, "end_timestamp", None): + last_timestamp = event.end_timestamp + elif getattr(event, "receive_timestamp", None): + last_timestamp = event.receive_timestamp + else: + last_timestamp = ( + event.timestamp + + (next_event.timestamp - event.timestamp) / 10 + ) + + generic_args = { + "event_id": f"carry-{event.event_id}", + "coordinates": last_coord, + "team": next_event.team, + "player": next_event.player, + "ball_owning_team": next_event.ball_owning_team, + "ball_state": event.ball_state, + "period": next_event.period, + "timestamp": last_timestamp, + "raw_event": None, + } + + carry_args = { + "result": CarryResult.COMPLETE, + "qualifiers": None, + "end_coordinates": new_coord, + "end_timestamp": next_event.timestamp, + } + + carry = self.event_factory.build_carry(**carry_args, **generic_args) + + # use generic insert_record helper + insert_record(dataset, carry, position=j) + + i = j + 1 + + return dataset diff --git a/kloppy/domain/services/mutators/insert_event.py b/kloppy/domain/services/mutators/insert_event.py new file mode 100644 index 000000000..4a0489627 --- /dev/null +++ b/kloppy/domain/services/mutators/insert_event.py @@ -0,0 +1,40 @@ +from datetime import timedelta +from typing import Callable + +from kloppy.domain import Event, EventDataset + +from .base import DatasetMutator +from .helpers.insert import insert_record + + +class EventDatasetInsertMutator(DatasetMutator[EventDataset]): + def __init__( + self, + event: Event, + *, + position: int | None = None, + before_event_id: str | None = None, + after_event_id: str | None = None, + timestamp: timedelta | None = None, + scoring_function: Callable[[Event, EventDataset], float] | None = None, + inplace: bool = False, + ): + super().__init__(inplace=inplace) + self.event = event + self.position = position + self.before_event_id = before_event_id + self.after_event_id = after_event_id + self.timestamp = timestamp + self.scoring_function = scoring_function + + def _mutate_inplace(self, dataset: EventDataset) -> EventDataset: + insert_record( + dataset, + self.event, + position=self.position, + before_event_id=self.before_event_id, + after_event_id=self.after_event_id, + timestamp=self.timestamp, + scoring_function=self.scoring_function, + ) + return dataset diff --git a/kloppy/infra/serializers/event/sportec/deserializer.py b/kloppy/infra/serializers/event/sportec/deserializer.py index ca84895bd..5b4f1b693 100644 --- a/kloppy/infra/serializers/event/sportec/deserializer.py +++ b/kloppy/infra/serializers/event/sportec/deserializer.py @@ -686,7 +686,7 @@ def _deserialize(self, inputs: SportecEventDataInputs) -> EventDataset: score=sportec_metadata.score, frame_rate=None, orientation=orientation, - flags=~(DatasetFlag.BALL_STATE | DatasetFlag.BALL_OWNING_TEAM), + flags=DatasetFlag(0), provider=Provider.SPORTEC, coordinate_system=transformer.get_to_coordinate_system(), date=date, diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py index 2542fa589..33e6f480c 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py @@ -10,6 +10,7 @@ CardQualifier, CardType, CounterAttackQualifier, + DatasetFlag, DuelQualifier, DuelResult, DuelType, @@ -719,7 +720,7 @@ def _deserialize(self, inputs: WyscoutInputs) -> EventDataset: score=None, frame_rate=None, orientation=Orientation.ACTION_EXECUTING_TEAM, - flags=None, + flags=DatasetFlag(0), provider=Provider.WYSCOUT, coordinate_system=transformer.get_to_coordinate_system(), game_id=game_id, diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py index 546747278..3009a42ef 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py @@ -12,6 +12,7 @@ CardType, CarryResult, CounterAttackQualifier, + DatasetFlag, DuelQualifier, DuelResult, DuelType, @@ -1037,7 +1038,7 @@ def _deserialize(self, inputs: WyscoutInputs) -> EventDataset: score=None, frame_rate=None, orientation=Orientation.ACTION_EXECUTING_TEAM, - flags=None, + flags=DatasetFlag(0), provider=Provider.WYSCOUT, coordinate_system=transformer.get_to_coordinate_system(), date=date, diff --git a/kloppy/infra/serializers/tracking/hawkeye/deserializer.py b/kloppy/infra/serializers/tracking/hawkeye/deserializer.py index 197d6fead..cac668323 100644 --- a/kloppy/infra/serializers/tracking/hawkeye/deserializer.py +++ b/kloppy/infra/serializers/tracking/hawkeye/deserializer.py @@ -17,6 +17,7 @@ from kloppy.domain import ( AttackingDirection, + DatasetFlag, Frame, Ground, Metadata, @@ -435,7 +436,7 @@ def deserialize(self, inputs: HawkEyeInputs) -> TrackingDataset: frame_rate=frame_rate, orientation=orientation, provider=Provider.HAWKEYE, - flags=None, + flags=DatasetFlag(0), coordinate_system=transformer.get_to_coordinate_system(), game_id=self._game_id, date=self._game_date, diff --git a/kloppy/infra/serializers/tracking/signality.py b/kloppy/infra/serializers/tracking/signality.py index 4aa602975..1d334ac07 100644 --- a/kloppy/infra/serializers/tracking/signality.py +++ b/kloppy/infra/serializers/tracking/signality.py @@ -8,6 +8,7 @@ from kloppy.domain import ( AttackingDirection, BallState, + DatasetFlag, DatasetTransformer, Ground, Metadata, @@ -264,7 +265,7 @@ def deserialize(self, inputs: SignalityInputs) -> TrackingDataset: orientation=orientation, pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, coordinate_system=transformer.get_to_coordinate_system(), - flags=None, + flags=DatasetFlag(0), provider=Provider.SIGNALITY, ) diff --git a/kloppy/tests/test_event.py b/kloppy/tests/test_event.py index 2a492ea7a..b2c54b982 100644 --- a/kloppy/tests/test_event.py +++ b/kloppy/tests/test_event.py @@ -1,7 +1,15 @@ +from datetime import timedelta + import pytest from kloppy import statsbomb -from kloppy.domain import EventDataset, FilteredDataset +from kloppy.domain import ( + CarryResult, + Event, + EventDataset, + EventFactory, + FilteredDataset, +) class TestEvent: @@ -98,3 +106,118 @@ def test_find_all(self, dataset: EventDataset): assert goals[0].next("shot.goal") == goals[1] assert goals[0].next("shot.goal") == goals[2].prev("shot.goal") assert goals[2].next("shot.goal") is None + + def test_insert(self, dataset: EventDataset): + new_event = EventFactory().build_carry( + qualifiers=None, + timestamp=timedelta(seconds=700), + end_timestamp=timedelta(seconds=701), + result=CarryResult.COMPLETE, + period=dataset.metadata.periods[0], + ball_owning_team=dataset.metadata.teams[0], + ball_state="alive", + event_id="test-insert-1234", + team=dataset.metadata.teams[0], + player=dataset.metadata.teams[0].players[0], + coordinates=(0.2, 0.3), + end_coordinates=(0.22, 0.33), + raw_event=None, + ) + + # insert by position + dataset.insert(new_event, position=3) + assert dataset.events[3].event_id == "test-insert-1234" + del dataset.events[3] # Remove by index to restore the dataset + + # insert by before_event_id + dataset.insert(new_event, before_event_id=dataset.events[100].event_id) + assert dataset.events[100].event_id == "test-insert-1234" + del dataset.events[100] # Remove by index to restore the dataset + + # insert by after_event_id + dataset.insert(new_event, after_event_id=dataset.events[305].event_id) + assert dataset.events[306].event_id == "test-insert-1234" + del dataset.events[306] # Remove by index to restore the dataset + + # insert by timestamp + dataset.insert(new_event, timestamp=new_event.timestamp) + assert dataset.events[609].event_id == "test-insert-1234" + del dataset.events[609] # Remove by index to restore the dataset + + # insert using scoring function + def insert_after_scoring_function(event: Event, dataset: EventDataset): + if event.ball_owning_team != dataset.metadata.teams[0]: + return 0 + if event.period != new_event.period: + return 0 + return 1 / abs( + event.timestamp.total_seconds() + - new_event.timestamp.total_seconds() + ) + + dataset.insert( + new_event, scoring_function=insert_after_scoring_function + ) + assert dataset.events[608].event_id == "test-insert-1234" + del dataset.events[608] # Remove by index to restore the dataset + + # insert using scoring function + def insert_before_scoring_function(event: Event, dataset: EventDataset): + if event.ball_owning_team != dataset.metadata.teams[0]: + return 0 + if event.period != new_event.period: + return 0 + return -1 / abs( + event.timestamp.total_seconds() + - new_event.timestamp.total_seconds() + ) + + dataset.insert( + new_event, scoring_function=insert_before_scoring_function + ) + assert dataset.events[607].event_id == "test-insert-1234" + del dataset.events[607] # Remove by index to restore the dataset + + def no_match_scoring_function(event: Event, dataset: EventDataset): + return 0 + + with pytest.raises(ValueError): + dataset.insert( + new_event, scoring_function=no_match_scoring_function + ) + + # update references + dataset.insert(new_event, position=1) + assert dataset.events[0].next_record.event_id == "test-insert-1234" + assert ( + dataset.events[1].prev_record.event_id == dataset.events[0].event_id + ) + assert dataset.events[1].event_id == "test-insert-1234" + assert ( + dataset.events[1].next_record.event_id == dataset.events[2].event_id + ) + assert dataset.events[2].prev_record.event_id == "test-insert-1234" + + dataset.insert(new_event, position=0) + assert dataset.events[0].prev_record is None + assert dataset.events[0].event_id == "test-insert-1234" + assert ( + dataset.events[0].next_record.event_id == dataset.events[1].event_id + ) + assert dataset.events[1].prev_record.event_id == "test-insert-1234" + + dataset.insert(new_event, position=len(dataset)) + assert dataset.events[-2].next_record.event_id == "test-insert-1234" + assert ( + dataset.events[-1].prev_record.event_id + == dataset.events[-2].event_id + ) + assert dataset.events[-1].event_id == "test-insert-1234" + assert dataset.events[-1].next_record is None + + def test_filtered_insert(self, dataset: EventDataset): + goals_dataset = dataset.filter("shot.goal") + assert hasattr(goals_dataset, "insert") + + with pytest.raises(NotImplementedError): + goals_dataset.insert(dataset.records[0], position=0) diff --git a/mkdocs.yml b/mkdocs.yml index d8b62bf64..7b3370bfa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,7 +60,9 @@ nav: # - Selections: user-guide/querying-data/selections/index.md # - Pattern matching: user-guide/querying-data/pattern-matching/index.md - Transforming data: - - user-guide/transformations/coordinates/index.md + - user-guide/transformations/index.md + - Changing coordinates: user-guide/transformations/coordinates/index.md + - Inserting events: user-guide/transformations/insert/index.md # - Aggregating data: # - user-guide/aggregating-data/index.md # - Enriching data: @@ -77,6 +79,7 @@ nav: - examples/aggregations.ipynb - examples/event_factory.ipynb - examples/adapter.ipynb + - Enriching with tracking data: examples/enrich_with_tracking.md - Reference: - reference/index.md - Data Loaders: