From b22aebfbd22b2982290ba5ee2a68d33885146285 Mon Sep 17 00:00:00 2001 From: Andrey Isakov Date: Fri, 17 Apr 2026 00:34:15 +0300 Subject: [PATCH] feat: Protocol for DataLoader Added --- pysatl_cpd/core/data_providers/__init__.py | 3 +- pysatl_cpd/core/data_providers/dataset.py | 129 +++++++++++++-------- 2 files changed, 82 insertions(+), 50 deletions(-) diff --git a/pysatl_cpd/core/data_providers/__init__.py b/pysatl_cpd/core/data_providers/__init__.py index 85214ba..2d4cd2e 100644 --- a/pysatl_cpd/core/data_providers/__init__.py +++ b/pysatl_cpd/core/data_providers/__init__.py @@ -12,7 +12,7 @@ __license__ = "SPDX-License-Identifier: MIT" -from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, SegmentInfo +from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, RealDatasetLoader, SegmentInfo from pysatl_cpd.core.data_providers.idata_provider import DataProvider from pysatl_cpd.core.data_providers.numpy_data_provider import ( NDArrayMultivariateProvider, @@ -25,6 +25,7 @@ "SegmentInfo", "PandasLabeledDataProvider", "Dataset", + "RealDatasetLoader", "NDArrayMultivariateProvider", "NDArrayUnivariateProvider", ] diff --git a/pysatl_cpd/core/data_providers/dataset.py b/pysatl_cpd/core/data_providers/dataset.py index d2a37ea..0d56330 100644 --- a/pysatl_cpd/core/data_providers/dataset.py +++ b/pysatl_cpd/core/data_providers/dataset.py @@ -69,14 +69,14 @@ Если берется подтаблица, индекс должен быть приведен к непрерывному. 
""" -__author__ = "Andrey" +__author__ = "Andrey Isakov" __copyright__ = "Copyright (c) 2026 PySATL project" __license__ = "SPDX-License-Identifier: MIT" from collections.abc import Callable, Iterator, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Protocol import numpy as np import pandas as pd @@ -84,7 +84,7 @@ from pysatl_cpd.analysis.labeled_data import LabeledData from pysatl_cpd.core.typedefs import NumericArray -SEGMENT_COLUMN = "segments" +SEGMENT_COLUMN = "segment" SEGMENT_ID_COLUMN = "segment" SEGMENT_START_COLUMN = "start" SEGMENT_END_COLUMN = "end" @@ -144,43 +144,42 @@ def __init__( missing_columns = required_segment_columns.difference(segment_info.columns) raise ValueError(f"Segment info is missing required columns: {sorted(missing_columns)}") - # TODO: one underscore for private attributes - self.__dataset = dataset.copy().reset_index(drop=True) - self.__segment_info = segment_info.copy().reset_index(drop=True) - self.__annotation = annotation + self._dataset = dataset.copy().reset_index(drop=True) + self._segment_info = segment_info.copy().reset_index(drop=True) + self._annotation = annotation - self.__segment_info = self._normalize_segment_info() + self._segment_info = self._normalize_segment_info() self._validate_segment_ranges() - raw_data = self.__dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False) + raw_data = self._dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False) super().__init__(raw_data=raw_data, change_points=self.change_point, name=name) def __iter__(self) -> Iterator[NumericArray]: - feature_values = self.__dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False) + feature_values = self._dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False) return iter(feature_values) def __len__(self) -> int: - return len(self.__dataset) + return len(self._dataset) @property def 
dataset(self) -> pd.DataFrame: - return self.__dataset.copy() + return self._dataset.copy() @property def segment_info(self) -> DatasetSegmentInfo: - return self.__segment_info.copy() + return self._segment_info.copy() @property def annotation(self) -> Annotation: - return self.__annotation + return self._annotation @property def feature_columns(self) -> list[str]: - return [column for column in self.__dataset.columns if column != SEGMENT_COLUMN] + return [column for column in self._dataset.columns if column != SEGMENT_COLUMN] @property def change_point(self) -> tuple[int, ...]: - segments = self.__dataset[SEGMENT_COLUMN].to_numpy(copy=False) + segments = self._dataset[SEGMENT_COLUMN].to_numpy(copy=False) if len(segments) <= 1: return () @@ -195,13 +194,13 @@ def select_columns(self, columns: Sequence[str]) -> "PandasLabeledDataProvider": if not requested_columns: raise ValueError("At least one feature column must be selected") - selected_dataset = self.__dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True) - selected_segment_info = self.__segment_info.copy().reset_index(drop=True) + selected_dataset = self._dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True) + selected_segment_info = self._segment_info.copy().reset_index(drop=True) return PandasLabeledDataProvider( dataset=selected_dataset, segment_info=selected_segment_info, - annotation=self.__annotation, + annotation=self._annotation, name=self.name, ) @@ -215,7 +214,7 @@ def query_bisegments_indexes(self, filter_fn: SegmentFilter | None = None) -> li def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["PandasLabeledDataProvider"]: result: list[PandasLabeledDataProvider] = [] for current, next_segment in self._iter_segment_pairs(filter_fn): - sliced_dataset = self.__dataset.iloc[current.start : next_segment.end + 1].copy().reset_index(drop=True) + sliced_dataset = self._dataset.iloc[current.start : next_segment.end + 
1].copy().reset_index(drop=True) split_index = next_segment.start - current.start sliced_segment_info = pd.DataFrame( @@ -239,7 +238,7 @@ def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["Pand PandasLabeledDataProvider( dataset=sliced_dataset, segment_info=sliced_segment_info, - annotation=self.__annotation, + annotation=self._annotation, name=f"{self.name}:{current.segment}->{next_segment.segment}", ) ) @@ -247,8 +246,8 @@ def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["Pand return result def _normalize_segment_info(self) -> DatasetSegmentInfo: - unique_segments = self.__dataset[SEGMENT_COLUMN].drop_duplicates().tolist() - normalized_info = self.__segment_info.copy() + unique_segments = self._dataset[SEGMENT_COLUMN].drop_duplicates().tolist() + normalized_info = self._segment_info.copy() if SEGMENT_ID_COLUMN in normalized_info.columns: normalized_rows: list[pd.Series[Any]] = [] @@ -270,11 +269,11 @@ def _normalize_segment_info(self) -> DatasetSegmentInfo: return normalized_info def _validate_segment_ranges(self) -> None: - data_length = len(self.__dataset) + data_length = len(self._dataset) if data_length == 0: return - for _, segment_row in self.__segment_info.iterrows(): + for _, segment_row in self._segment_info.iterrows(): start = int(segment_row[SEGMENT_START_COLUMN]) end = int(segment_row[SEGMENT_END_COLUMN]) if start < 0: @@ -297,7 +296,7 @@ def _iter_segment_pairs(self, filter_fn: SegmentFilter | None) -> list[tuple[Seg def _segment_infos(self) -> list[SegmentInfo]: segment_infos: list[SegmentInfo] = [] - for _, row in self.__segment_info.iterrows(): + for _, row in self._segment_info.iterrows(): row_dict = row.to_dict() start = int(row_dict.pop(SEGMENT_START_COLUMN)) end = int(row_dict.pop(SEGMENT_END_COLUMN)) @@ -322,43 +321,75 @@ class Dataset(Sequence[PandasLabeledDataProvider]): def __init__( self, timeserieses: Sequence[PandasLabeledDataProvider], - timeseries_preprocessor: TimeseriesPreprocessor 
| None = None, ) -> None: - self.__timeserieses = list(timeserieses) - self.__timeseries_preprocessor = timeseries_preprocessor if timeseries_preprocessor is not None else _identity - - @classmethod - def load_from_dir( - cls, - dir_path: Path, - timeseries_preprocessor: TimeseriesPreprocessor | None = None, - ) -> "Dataset": - raise NotImplementedError(f"{cls.__name__}.load_from_dir is dataset-source specific. dir_path={dir_path}") + self._timeserieses = list(timeserieses) def __getitem__(self, index: int) -> PandasLabeledDataProvider: - return self.__timeserieses[index] + return self._timeserieses[index] def __len__(self) -> int: - return len(self.__timeserieses) + return len(self._timeserieses) @property def timeserieses(self) -> list[PandasLabeledDataProvider]: - return list(self.__timeserieses) - - @property - def timeseries_preprocessor(self) -> TimeseriesPreprocessor: - return self.__timeseries_preprocessor + return list(self._timeserieses) def filter_by_annotation(self, annotation_filter: AnnotationFilter) -> "Dataset": - filtered_timeserieses = [provider for provider in self.__timeserieses if annotation_filter(provider.annotation)] - return Dataset(filtered_timeserieses, timeseries_preprocessor=self.__timeseries_preprocessor) + filtered_timeserieses = [provider for provider in self._timeserieses if annotation_filter(provider.annotation)] + return Dataset(filtered_timeserieses, timeseries_preprocessor=self._timeseries_preprocessor) def select_bisegments_by_filter(self, filter_fn: SegmentFilter | None = None) -> list[PandasLabeledDataProvider]: bisegments: list[PandasLabeledDataProvider] = [] - for provider in self.__timeserieses: + for provider in self._timeserieses: bisegments.extend(provider.query_bisegments(filter_fn)) return bisegments -def _identity(frame: pd.DataFrame) -> pd.DataFrame: - return frame +class DatasetLoader(Protocol): + def load(self, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset: ... 
class RealDatasetLoader(DatasetLoader):
    """Load a ``Dataset`` from ``timeseries*.csv`` files found under a directory.

    Every matching CSV file becomes one ``PandasLabeledDataProvider``. Its
    segment table is derived from runs of consecutive equal values in the
    ``SEGMENT_COLUMN`` column of that file.
    """

    @staticmethod
    def prepare_annotation(file: str | Path) -> Annotation:
        """Build the annotation attached to one time-series file.

        Args:
            file: Path of the CSV file the time series was read from.

        Returns:
            An ``Annotation`` constructed with ``path=file``.
        """
        return Annotation(path=file)

    @staticmethod
    def prepare_segment_info(timeseries: pd.DataFrame) -> DatasetSegmentInfo:
        """Derive the segment table from the segment column of a time series.

        One row is produced per run of consecutive equal values in
        ``SEGMENT_COLUMN``; a value that reappears after a different one
        therefore yields a separate row. ``start`` and ``end`` are positional
        row indexes and ``end`` is inclusive.

        Args:
            timeseries: Frame holding the ``SEGMENT_COLUMN`` column.

        Returns:
            A frame with the ``SEGMENT_ID_COLUMN``, ``SEGMENT_START_COLUMN``
            and ``SEGMENT_END_COLUMN`` columns; it has no rows when
            ``timeseries`` is empty.

        Raises:
            ValueError: If ``timeseries`` has no ``SEGMENT_COLUMN`` column.
        """
        if SEGMENT_COLUMN not in timeseries.columns:
            raise ValueError(f"Timeseries must contain '{SEGMENT_COLUMN}' column")

        segments = timeseries[SEGMENT_COLUMN].to_numpy(copy=False)
        if len(segments) == 0:
            return pd.DataFrame(columns=[SEGMENT_ID_COLUMN, SEGMENT_START_COLUMN, SEGMENT_END_COLUMN])

        # A change point is the index of the first row whose label differs
        # from the label of the row right before it.
        change_points = np.flatnonzero(segments[1:] != segments[:-1]) + 1
        starts = np.concatenate([[0], change_points])
        # Each segment ends on the row preceding the next start; the last
        # segment ends on the final row.
        ends = np.concatenate([change_points - 1, [len(segments) - 1]])

        return pd.DataFrame(
            {
                SEGMENT_ID_COLUMN: segments[starts],
                SEGMENT_START_COLUMN: starts,
                SEGMENT_END_COLUMN: ends,
            }
        )

    @classmethod
    def load(cls, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset:
        """Read every ``timeseries*.csv`` file below ``path`` into a ``Dataset``.

        Files are searched recursively and processed in sorted path order so
        that the resulting dataset has a reproducible order.

        Args:
            path: Directory to search; a plain string is accepted as well.
            timeseries_preprocessor: Optional callable applied to each frame
                right after it is read and before segment information is
                derived from it.

        Returns:
            A ``Dataset`` holding one provider per matching file.

        Raises:
            FileNotFoundError: If no matching file exists below ``path``.
            ValueError: If a (preprocessed) frame lacks the segment column.
        """
        root = Path(path)
        # Sort the matches: glob yields them in file-system order, which
        # would make the order of providers differ between machines.
        timeseries_files = sorted(root.glob("**/timeseries*.csv"))
        if not timeseries_files:
            raise FileNotFoundError(f"No timeseries files found in {path}")

        timeserieses: list[PandasLabeledDataProvider] = []
        for file in timeseries_files:
            timeseries = pd.read_csv(file)
            if timeseries_preprocessor is not None:
                timeseries = timeseries_preprocessor(timeseries)

            segment_info = cls.prepare_segment_info(timeseries)
            annotation = cls.prepare_annotation(file)
            timeserieses.append(
                PandasLabeledDataProvider(timeseries, segment_info=segment_info, annotation=annotation)
            )

        return Dataset(timeserieses)