Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pysatl_cpd/core/data_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__license__ = "SPDX-License-Identifier: MIT"


from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, SegmentInfo
from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, RealDatasetLoader, SegmentInfo
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
from pysatl_cpd.core.data_providers.numpy_data_provider import (
NDArrayMultivariateProvider,
Expand All @@ -25,6 +25,7 @@
"SegmentInfo",
"PandasLabeledDataProvider",
"Dataset",
"RealDatasetLoader",
"NDArrayMultivariateProvider",
"NDArrayUnivariateProvider",
]
129 changes: 80 additions & 49 deletions pysatl_cpd/core/data_providers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,22 @@
Если берется подтаблица, индекс должен быть приведен к непрерывному.
"""

__author__ = "Andrey"
__author__ = "Andrey Isakov"
__copyright__ = "Copyright (c) 2026 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Protocol

import numpy as np
import pandas as pd

from pysatl_cpd.analysis.labeled_data import LabeledData
from pysatl_cpd.core.typedefs import NumericArray

SEGMENT_COLUMN = "segments"
SEGMENT_COLUMN = "segment"
SEGMENT_ID_COLUMN = "segment"
SEGMENT_START_COLUMN = "start"
SEGMENT_END_COLUMN = "end"
Expand Down Expand Up @@ -144,43 +144,42 @@ def __init__(
missing_columns = required_segment_columns.difference(segment_info.columns)
raise ValueError(f"Segment info is missing required columns: {sorted(missing_columns)}")

# TODO: one underscore for private attributes
self.__dataset = dataset.copy().reset_index(drop=True)
self.__segment_info = segment_info.copy().reset_index(drop=True)
self.__annotation = annotation
self._dataset = dataset.copy().reset_index(drop=True)
self._segment_info = segment_info.copy().reset_index(drop=True)
self._annotation = annotation

self.__segment_info = self._normalize_segment_info()
self._segment_info = self._normalize_segment_info()
self._validate_segment_ranges()

raw_data = self.__dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
raw_data = self._dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
super().__init__(raw_data=raw_data, change_points=self.change_point, name=name)

def __iter__(self) -> Iterator[NumericArray]:
feature_values = self.__dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
feature_values = self._dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
return iter(feature_values)

def __len__(self) -> int:
return len(self.__dataset)
return len(self._dataset)

@property
def dataset(self) -> pd.DataFrame:
return self.__dataset.copy()
return self._dataset.copy()

@property
def segment_info(self) -> DatasetSegmentInfo:
return self.__segment_info.copy()
return self._segment_info.copy()

@property
def annotation(self) -> Annotation:
return self.__annotation
return self._annotation

@property
def feature_columns(self) -> list[str]:
return [column for column in self.__dataset.columns if column != SEGMENT_COLUMN]
return [column for column in self._dataset.columns if column != SEGMENT_COLUMN]

@property
def change_point(self) -> tuple[int, ...]:
segments = self.__dataset[SEGMENT_COLUMN].to_numpy(copy=False)
segments = self._dataset[SEGMENT_COLUMN].to_numpy(copy=False)
if len(segments) <= 1:
return ()

Expand All @@ -195,13 +194,13 @@ def select_columns(self, columns: Sequence[str]) -> "PandasLabeledDataProvider":
if not requested_columns:
raise ValueError("At least one feature column must be selected")

selected_dataset = self.__dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True)
selected_segment_info = self.__segment_info.copy().reset_index(drop=True)
selected_dataset = self._dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True)
selected_segment_info = self._segment_info.copy().reset_index(drop=True)

return PandasLabeledDataProvider(
dataset=selected_dataset,
segment_info=selected_segment_info,
annotation=self.__annotation,
annotation=self._annotation,
name=self.name,
)

Expand All @@ -215,7 +214,7 @@ def query_bisegments_indexes(self, filter_fn: SegmentFilter | None = None) -> li
def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["PandasLabeledDataProvider"]:
result: list[PandasLabeledDataProvider] = []
for current, next_segment in self._iter_segment_pairs(filter_fn):
sliced_dataset = self.__dataset.iloc[current.start : next_segment.end + 1].copy().reset_index(drop=True)
sliced_dataset = self._dataset.iloc[current.start : next_segment.end + 1].copy().reset_index(drop=True)
split_index = next_segment.start - current.start

sliced_segment_info = pd.DataFrame(
Expand All @@ -239,16 +238,16 @@ def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["Pand
PandasLabeledDataProvider(
dataset=sliced_dataset,
segment_info=sliced_segment_info,
annotation=self.__annotation,
annotation=self._annotation,
name=f"{self.name}:{current.segment}->{next_segment.segment}",
)
)

return result

def _normalize_segment_info(self) -> DatasetSegmentInfo:
unique_segments = self.__dataset[SEGMENT_COLUMN].drop_duplicates().tolist()
normalized_info = self.__segment_info.copy()
unique_segments = self._dataset[SEGMENT_COLUMN].drop_duplicates().tolist()
normalized_info = self._segment_info.copy()

if SEGMENT_ID_COLUMN in normalized_info.columns:
normalized_rows: list[pd.Series[Any]] = []
Expand All @@ -270,11 +269,11 @@ def _normalize_segment_info(self) -> DatasetSegmentInfo:
return normalized_info

def _validate_segment_ranges(self) -> None:
data_length = len(self.__dataset)
data_length = len(self._dataset)
if data_length == 0:
return

for _, segment_row in self.__segment_info.iterrows():
for _, segment_row in self._segment_info.iterrows():
start = int(segment_row[SEGMENT_START_COLUMN])
end = int(segment_row[SEGMENT_END_COLUMN])
if start < 0:
Expand All @@ -297,7 +296,7 @@ def _iter_segment_pairs(self, filter_fn: SegmentFilter | None) -> list[tuple[Seg

def _segment_infos(self) -> list[SegmentInfo]:
segment_infos: list[SegmentInfo] = []
for _, row in self.__segment_info.iterrows():
for _, row in self._segment_info.iterrows():
row_dict = row.to_dict()
start = int(row_dict.pop(SEGMENT_START_COLUMN))
end = int(row_dict.pop(SEGMENT_END_COLUMN))
Expand All @@ -322,43 +321,75 @@ class Dataset(Sequence[PandasLabeledDataProvider]):
def __init__(
self,
timeserieses: Sequence[PandasLabeledDataProvider],
timeseries_preprocessor: TimeseriesPreprocessor | None = None,
) -> None:
self.__timeserieses = list(timeserieses)
self.__timeseries_preprocessor = timeseries_preprocessor if timeseries_preprocessor is not None else _identity

@classmethod
def load_from_dir(
cls,
dir_path: Path,
timeseries_preprocessor: TimeseriesPreprocessor | None = None,
) -> "Dataset":
raise NotImplementedError(f"{cls.__name__}.load_from_dir is dataset-source specific. dir_path={dir_path}")
self._timeserieses = list(timeserieses)

def __getitem__(self, index: int) -> PandasLabeledDataProvider:
return self.__timeserieses[index]
return self._timeserieses[index]

def __len__(self) -> int:
return len(self.__timeserieses)
return len(self._timeserieses)

@property
def timeserieses(self) -> list[PandasLabeledDataProvider]:
return list(self.__timeserieses)

@property
def timeseries_preprocessor(self) -> TimeseriesPreprocessor:
return self.__timeseries_preprocessor
return list(self._timeserieses)

def filter_by_annotation(self, annotation_filter: AnnotationFilter) -> "Dataset":
filtered_timeserieses = [provider for provider in self.__timeserieses if annotation_filter(provider.annotation)]
return Dataset(filtered_timeserieses, timeseries_preprocessor=self.__timeseries_preprocessor)
filtered_timeserieses = [provider for provider in self._timeserieses if annotation_filter(provider.annotation)]
return Dataset(filtered_timeserieses, timeseries_preprocessor=self._timeseries_preprocessor)

def select_bisegments_by_filter(self, filter_fn: SegmentFilter | None = None) -> list[PandasLabeledDataProvider]:
bisegments: list[PandasLabeledDataProvider] = []
for provider in self.__timeserieses:
for provider in self._timeserieses:
bisegments.extend(provider.query_bisegments(filter_fn))
return bisegments


def _identity(frame: pd.DataFrame) -> pd.DataFrame:
return frame
class DatasetLoader(Protocol):
def load(self, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset: ...


class RealDatasetLoader(DatasetLoader):
@staticmethod
def prepare_annotation(file: str | Path) -> Annotation:
return Annotation(path=file)

@staticmethod
def prepare_segment_info(timeseries: pd.DataFrame) -> DatasetSegmentInfo:
if SEGMENT_COLUMN not in timeseries.columns:
raise ValueError(f"Timeseries must contain '{SEGMENT_COLUMN}' column")

segments = timeseries[SEGMENT_COLUMN].to_numpy(copy=False)
if len(segments) == 0:
return pd.DataFrame(columns=[SEGMENT_ID_COLUMN, SEGMENT_START_COLUMN, SEGMENT_END_COLUMN])

change_points = np.flatnonzero(segments[1:] != segments[:-1]) + 1
starts = np.concatenate([[0], change_points])
ends = np.concatenate([change_points - 1, [len(segments) - 1]])
segment_ids = segments[starts]

return pd.DataFrame(
{
SEGMENT_ID_COLUMN: segment_ids,
SEGMENT_START_COLUMN: starts,
SEGMENT_END_COLUMN: ends,
}
)

@classmethod
def load(cls, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset:
timeseries_files = list(path.glob("**/timeseries*.csv"))
if not timeseries_files:
raise FileNotFoundError(f"No timeseries files found in {path}")

timeserieses = []
for file in timeseries_files:
timeseries = pd.read_csv(file)
if timeseries_preprocessor is not None:
timeseries = timeseries_preprocessor(timeseries)

segment_info = cls.prepare_segment_info(timeseries)
annotation = cls.prepare_annotation(file)
timeserieses.append(PandasLabeledDataProvider(timeseries, segment_info=segment_info, annotation=annotation))

return Dataset(timeserieses)
Loading