diff --git a/pysatl_cpd/core/algorithms/bayesian_linear_heuristic.py b/pysatl_cpd/core/algorithms/bayesian_linear_heuristic.py index add4cfbc..4f433ef5 100644 --- a/pysatl_cpd/core/algorithms/bayesian_linear_heuristic.py +++ b/pysatl_cpd/core/algorithms/bayesian_linear_heuristic.py @@ -89,7 +89,7 @@ def detect(self, observation: np.float64 | npt.NDArray[np.float64]) -> bool: :param observation: a new observation from a time series. Note: only univariate data is supported for now. :return: whether a change point was detected by a main algorithm. """ - if observation is npt.NDArray[np.float64]: + if isinstance(observation, np.ndarray): raise TypeError("Multivariate observations are not supported") assert self.__main_algorithm is not None, "Main algorithm must be initialized" @@ -111,7 +111,7 @@ def localize(self, observation: np.float64 | npt.NDArray[np.float64]) -> Optiona :param observation: a new observation from a time series. Note: only univariate data is supported for now. :return: a change point, if it was localized, None otherwise. """ - if observation is npt.NDArray[np.float64]: + if isinstance(observation, np.ndarray): raise TypeError("Multivariate observations are not supported") assert self.__main_algorithm is not None, "Main algorithm must be initialized" diff --git a/pysatl_cpd/core/algorithms/bayesian_online_algorithm.py b/pysatl_cpd/core/algorithms/bayesian_online_algorithm.py index a1a4ca73..cea6e5a2 100644 --- a/pysatl_cpd/core/algorithms/bayesian_online_algorithm.py +++ b/pysatl_cpd/core/algorithms/bayesian_online_algorithm.py @@ -188,7 +188,7 @@ def detect(self, observation: np.float64 | npt.NDArray[np.float64]) -> bool: :param observation: new observation of a time series. Note: multivariate time series aren't supported for now. :return: whether a change point was detected after processing the new observation. 
""" - if observation is npt.NDArray[np.float64]: + if isinstance(observation, np.ndarray): raise TypeError("Multivariate observations are not supported") self.__process_point(np.float64(observation), False) @@ -203,7 +203,7 @@ def localize(self, observation: np.float64 | npt.NDArray[np.float64]) -> Optiona :return: absolute location of a change point, acquired after processing the new observation, or None if there wasn't any. """ - if observation is npt.NDArray[np.float64]: + if isinstance(observation, np.ndarray): raise TypeError("Multivariate observations are not supported") self.__process_point(np.float64(observation), True) diff --git a/pysatl_cpd/core/algorithms/ssa/__init__.py b/pysatl_cpd/core/algorithms/ssa/__init__.py new file mode 100644 index 00000000..3c1bea4e --- /dev/null +++ b/pysatl_cpd/core/algorithms/ssa/__init__.py @@ -0,0 +1,33 @@ +""" +Module for SSA CPD algorithm's customization blocks. +""" + +__author__ = "Mark Dubrovchenko" +__copyright__ = "Copyright (c) 2026 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +from pysatl_cpd.core.algorithms.ssa.abstracts import ( + SVD, + IDecomposition, + IDetectorSSA, + IEmbedding, + IGrouping, +) +from pysatl_cpd.core.algorithms.ssa.decomposition import BasicSVD +from pysatl_cpd.core.algorithms.ssa.detectors import DistanceThreshold +from pysatl_cpd.core.algorithms.ssa.embedding import BasicEmbedding +from pysatl_cpd.core.algorithms.ssa.grouping import ConstantGrouping +from pysatl_cpd.core.algorithms.ssa.ssa import SSA + +__all__ = [ + "SSA", + "SVD", + "BasicEmbedding", + "BasicSVD", + "ConstantGrouping", + "DistanceThreshold", + "IDecomposition", + "IDetectorSSA", + "IEmbedding", + "IGrouping", +] diff --git a/pysatl_cpd/core/algorithms/ssa/abstracts/__init__.py b/pysatl_cpd/core/algorithms/ssa/abstracts/__init__.py new file mode 100644 index 00000000..c03ff6aa --- /dev/null +++ b/pysatl_cpd/core/algorithms/ssa/abstracts/__init__.py @@ -0,0 +1,14 @@ +""" +Module for abstract base 
classes for SSA CPD algorithm.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from pysatl_cpd.core.algorithms.ssa.abstracts.idecomposition import SVD, IDecomposition
+from pysatl_cpd.core.algorithms.ssa.abstracts.idetector import IDetectorSSA
+from pysatl_cpd.core.algorithms.ssa.abstracts.iembedding import IEmbedding
+from pysatl_cpd.core.algorithms.ssa.abstracts.igrouping import IGrouping
+
+__all__ = ["SVD", "IDecomposition", "IDetectorSSA", "IEmbedding", "IGrouping"]
diff --git a/pysatl_cpd/core/algorithms/ssa/abstracts/idecomposition.py b/pysatl_cpd/core/algorithms/ssa/abstracts/idecomposition.py
new file mode 100644
index 00000000..11320865
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/abstracts/idecomposition.py
@@ -0,0 +1,37 @@
+"""
+Module for the SSA decomposition step base class.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import numpy as np
+import numpy.typing as npt
+
+
+@dataclass
+class SVD:
+    """
+    Container for a singular value decomposition: matrix ``U`` and singular values ``sigma``.
+    """
+
+    U: npt.NDArray[np.float64]
+    sigma: npt.NDArray[np.float64]
+
+
+class IDecomposition(ABC):
+    """
+    Abstract class of the second step of SSA (Decomposition step).
+    """
+
+    @abstractmethod
+    def decompose(self, X: npt.NDArray[np.float64]) -> SVD:
+        """
+        Decomposes the trajectory matrix into elementary matrices.
+        :param X: trajectory matrix from embedding step.
+        :return: matrix decomposition for grouping step.
+        """
diff --git a/pysatl_cpd/core/algorithms/ssa/abstracts/idetector.py b/pysatl_cpd/core/algorithms/ssa/abstracts/idetector.py
new file mode 100644
index 00000000..2614edba
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/abstracts/idetector.py
@@ -0,0 +1,35 @@
+"""
+Module for SSA CPD algorithm detector's abstract base class.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import numpy.typing as npt
+
+
+class IDetectorSSA(ABC):
+    """
+    Abstract class for detectors that detect a change point.
+    """
+
+    @abstractmethod
+    def detect(
+        self, subspace: npt.NDArray[np.float64], test_data: list[np.float64]
+    ) -> bool:
+        """
+        Checks whether a change point has occurred at the start of the test data.
+        :param subspace: vectors of the training set subspace.
+        :param test_data: test sample for change point detection.
+        :return: boolean indicating whether a change point occurred.
+        """
+
+    @abstractmethod
+    def clean(self) -> None:
+        """
+        Clears the detector's state.
+        """
diff --git a/pysatl_cpd/core/algorithms/ssa/abstracts/iembedding.py b/pysatl_cpd/core/algorithms/ssa/abstracts/iembedding.py
new file mode 100644
index 00000000..31900628
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/abstracts/iembedding.py
@@ -0,0 +1,29 @@
+"""
+Module for the SSA embedding step base class.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import numpy.typing as npt
+
+
+class IEmbedding(ABC):
+    """
+    Abstract class of the first step of SSA (Embedding step).
+    """
+
+    @abstractmethod
+    def transform(
+        self, segment: npt.NDArray[np.float64], L: int
+    ) -> npt.NDArray[np.float64]:
+        """
+        Converts a segment of the time series into a trajectory matrix.
+        :param segment: segment of the time series.
+        :param L: window width.
+        :return: trajectory matrix for decomposition step.
+        """
diff --git a/pysatl_cpd/core/algorithms/ssa/abstracts/igrouping.py b/pysatl_cpd/core/algorithms/ssa/abstracts/igrouping.py
new file mode 100644
index 00000000..947ff912
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/abstracts/igrouping.py
@@ -0,0 +1,28 @@
+"""
+Module for the SSA grouping step base class.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts.idecomposition import SVD
+
+
+class IGrouping(ABC):
+    """
+    Abstract class of the third step of SSA (Grouping step).
+    """
+
+    @abstractmethod
+    def group(self, svd: SVD) -> npt.NDArray[np.float64]:
+        """
+        Groups vectors to define a subspace of the time series.
+        :param svd: decomposition of trajectory matrix.
+        :return: vector group characterizing a subspace.
+        """
diff --git a/pysatl_cpd/core/algorithms/ssa/decomposition/__init__.py b/pysatl_cpd/core/algorithms/ssa/decomposition/__init__.py
new file mode 100644
index 00000000..ddc4729b
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/decomposition/__init__.py
@@ -0,0 +1,11 @@
+"""
+Module for implementations of SSA decomposition step
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from pysatl_cpd.core.algorithms.ssa.decomposition.standard_svd import BasicSVD
+
+__all__ = ["BasicSVD"]
diff --git a/pysatl_cpd/core/algorithms/ssa/decomposition/standard_svd.py b/pysatl_cpd/core/algorithms/ssa/decomposition/standard_svd.py
new file mode 100644
index 00000000..86c93750
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/decomposition/standard_svd.py
@@ -0,0 +1,27 @@
+"""
+Module for implementation of basic SSA decomposition step.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts import SVD, IDecomposition
+
+
+class BasicSVD(IDecomposition):
+    """
+    Class of basic SSA decomposition step based on SVD.
+    """
+
+    def decompose(self, X: npt.NDArray[np.float64]) -> SVD:
+        """
+        Decomposes the trajectory matrix using SVD.
+        :param X: trajectory matrix from embedding step.
+        :return: matrix decomposition for grouping step.
+        """
+        U, sigma, _ = np.linalg.svd(X, full_matrices=False)
+        return SVD(U=U, sigma=sigma)
diff --git a/pysatl_cpd/core/algorithms/ssa/detectors/__init__.py b/pysatl_cpd/core/algorithms/ssa/detectors/__init__.py
new file mode 100644
index 00000000..5341698c
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/detectors/__init__.py
@@ -0,0 +1,13 @@
+"""
+Module for implementations of SSA CPD algorithm detectors.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from pysatl_cpd.core.algorithms.ssa.detectors.distance_threshold import (
+    DistanceThreshold,
+)
+
+__all__ = ["DistanceThreshold"]
diff --git a/pysatl_cpd/core/algorithms/ssa/detectors/distance_threshold.py b/pysatl_cpd/core/algorithms/ssa/detectors/distance_threshold.py
new file mode 100644
index 00000000..5b78cc49
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/detectors/distance_threshold.py
@@ -0,0 +1,65 @@
+"""
+Module for implementations of SSA CPD algorithm distance detector using a threshold.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts.idetector import IDetectorSSA
+
+
+class DistanceThreshold(IDetectorSSA):
+    """
+    Class of implementations of SSA CPD algorithm detector using
+    the normalized Euclidean distance and a threshold.
+    """
+
+    def __init__(self, threshold: float) -> None:
+        """
+        Initializes SSA CPD algorithm distance detector with given threshold.
+        :param threshold: distance in [0.0, 1.0] above which a change point is reported.
+        :raises ValueError: if the threshold lies outside [0.0, 1.0].
+        """
+        if not 0 <= threshold <= 1:
+            raise ValueError("Threshold must be in [0.0, 1.0]")
+        self._threshold = threshold
+
+    def detect(
+        self, subspace: npt.NDArray[np.float64], test_data: list[np.float64]
+    ) -> bool:
+        """
+        Checks whether a changepoint has occurred at the start of the test data using
+        the normalized Euclidean distance and comparing it to a threshold.
+        The distance is one minus the share of the squared norm of the lagged test
+        vectors that is captured by their products with the subspace vectors.
+        :param subspace: vectors of the training set subspace (one vector per column).
+        :param test_data: test sample for breakpoint detection.
+        :return: boolean indicating whether a changepoint occurred; False when the test
+        sample is shorter than one window or consists of zeros only.
+        """
+        L = subspace.shape[0]
+        if len(test_data) < L:
+            # Not a single full window of width L fits, so no distance can be measured.
+            return False
+
+        # Column i of X_test is the lagged vector test_data[i : i + L].
+        series = np.asarray(test_data, dtype=np.float64)
+        X_test = np.lib.stride_tricks.sliding_window_view(series, L).T
+
+        proj_sum = np.sum((subspace.T @ X_test) ** 2)
+        total_sum = np.sum(X_test**2)
+
+        if total_sum == 0:
+            return False
+
+        return bool(1 - (proj_sum / total_sum) > self._threshold)
+
+    def clean(self) -> None:
+        """
+        Clears the detector's state.
+        """
+        return
diff --git a/pysatl_cpd/core/algorithms/ssa/embedding/__init__.py b/pysatl_cpd/core/algorithms/ssa/embedding/__init__.py
new file mode 100644
index 00000000..39f9b5e8
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/embedding/__init__.py
@@ -0,0 +1,11 @@
+"""
+Module for implementations of SSA embedding step
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from pysatl_cpd.core.algorithms.ssa.embedding.basic_embedding import BasicEmbedding
+
+__all__ = ["BasicEmbedding"]
diff --git a/pysatl_cpd/core/algorithms/ssa/embedding/basic_embedding.py b/pysatl_cpd/core/algorithms/ssa/embedding/basic_embedding.py
new file mode 100644
index 00000000..cc54e597
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/embedding/basic_embedding.py
@@ -0,0 +1,34 @@
+"""
+Module for implementation of basic SSA embedding step.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts.iembedding import IEmbedding
+
+
+class BasicEmbedding(IEmbedding):
+    """
+    Class of basic SSA embedding step based on a sliding window.
+    """
+
+    def transform(
+        self, segment: npt.NDArray[np.float64], L: int
+    ) -> npt.NDArray[np.float64]:
+        """
+        Converts a segment of the time series into a trajectory matrix based on a sliding window.
+        :param segment: segment of the time series.
+        :param L: window width.
+        :return: trajectory matrix for decomposition step.
+        """
+        K = len(segment) - L + 1
+        X = np.zeros((L, K))
+        for i in range(K):
+            X[:, i] = segment[i : i + L]
+
+        return X
diff --git a/pysatl_cpd/core/algorithms/ssa/grouping/__init__.py b/pysatl_cpd/core/algorithms/ssa/grouping/__init__.py
new file mode 100644
index 00000000..36a61742
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/grouping/__init__.py
@@ -0,0 +1,11 @@
+"""
+Module for implementations of SSA grouping step
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+from pysatl_cpd.core.algorithms.ssa.grouping.constant_grouping import ConstantGrouping
+
+__all__ = ["ConstantGrouping"]
diff --git a/pysatl_cpd/core/algorithms/ssa/grouping/constant_grouping.py b/pysatl_cpd/core/algorithms/ssa/grouping/constant_grouping.py
new file mode 100644
index 00000000..20536b9e
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/grouping/constant_grouping.py
@@ -0,0 +1,40 @@
+"""
+Module for implementation of SSA constant grouping step.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts import SVD, IGrouping
+
+
+class ConstantGrouping(IGrouping):
+    """
+    Class of SSA constant grouping step.
+    """
+
+    def __init__(self, M: int) -> None:
+        """
+        Initializes SSA constant grouping step with given number of vectors.
+        :param M: number of vectors in the main group.
+        :raises ValueError: if M is not positive.
+        """
+        if M < 1:
+            raise ValueError("Number of vectors M must be positive")
+        self._M = M
+
+    def group(self, svd: SVD) -> npt.NDArray[np.float64]:
+        """
+        Groups vectors to define a subspace of the time series with constant number of vectors.
+        The vectors are the columns of ``svd.U``; the M columns with the largest singular values are kept.
+        :param svd: decomposition of trajectory matrix.
+        :return: vector group characterizing a subspace (one vector per column, at most M columns).
+        """
+        # Rank the columns (not the rows) of U by descending singular value; the stable sort keeps
+        # the incoming order for equal values, so an already sorted decomposition is left untouched.
+        order = np.argsort(-svd.sigma, kind="stable")
+        return svd.U[:, order[: self._M]]
diff --git a/pysatl_cpd/core/algorithms/ssa/ssa.py b/pysatl_cpd/core/algorithms/ssa/ssa.py
new file mode 100644
index 00000000..8119fed2
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa/ssa.py
@@ -0,0 +1,50 @@
+"""
+Module for the SSA method.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.ssa.abstracts.idecomposition import IDecomposition
+from pysatl_cpd.core.algorithms.ssa.abstracts.iembedding import IEmbedding
+from pysatl_cpd.core.algorithms.ssa.abstracts.igrouping import IGrouping
+
+
+class SSA:
+    """
+    Class for the SSA method.
+    """
+
+    def __init__(
+        self,
+        embedding_step: IEmbedding,
+        decomposition_step: IDecomposition,
+        grouping_step: IGrouping,
+    ) -> None:
+        """
+        Initializes the steps of the SSA method.
+        :param embedding_step: step for constructing the trajectory matrix.
+        :param decomposition_step: step of decomposing the trajectory matrix into elementary matrices.
+        :param grouping_step: step of grouping the main matrices.
+        """
+        self.__embedding_step = embedding_step
+        self.__decomposition_step = decomposition_step
+        self.__grouping_step = grouping_step
+
+    def subspace(
+        self, segment: npt.NDArray[np.float64], L: int
+    ) -> npt.NDArray[np.float64]:
+        """
+        Determines the subspace vectors based on a segment of the time series.
+        :param segment: segment of the time series.
+        :param L: window width for the SSA method.
+        :return: matrix of subspace vectors, where the number of columns equals the number of subspace vectors
+        """
+        X = self.__embedding_step.transform(segment, L)
+        svd = self.__decomposition_step.decompose(X)
+
+        return self.__grouping_step.group(svd)
diff --git a/pysatl_cpd/core/algorithms/ssa_online_algorithm.py b/pysatl_cpd/core/algorithms/ssa_online_algorithm.py
new file mode 100644
index 00000000..802b1f83
--- /dev/null
+++ b/pysatl_cpd/core/algorithms/ssa_online_algorithm.py
@@ -0,0 +1,157 @@
+"""
+Module for SSA online change point detection algorithm.
+"""
+
+__author__ = "Mark Dubrovchenko"
+__copyright__ = "Copyright (c) 2026 PySATL project"
+__license__ = "SPDX-License-Identifier: MIT"
+
+import numpy as np
+import numpy.typing as npt
+
+from pysatl_cpd.core.algorithms.online_algorithm import OnlineAlgorithm
+from pysatl_cpd.core.algorithms.ssa.abstracts import IDetectorSSA
+from pysatl_cpd.core.algorithms.ssa.ssa import SSA
+
+
+class SSAOnline(OnlineAlgorithm):
+    """
+    Class for SSA online change point detection algorithm.
+    """
+
+    def __init__(
+        self,
+        ssa: SSA,
+        detector: IDetectorSSA,
+        N: int,
+        L: int | None = None,
+        p: int | None = None,
+        Q: int | None = None,
+    ) -> None:
+        """
+        Initializes the SSA online change point detection algorithm.
+        :param ssa: SSA method used to build the subspace of the training segment.
+        :param detector: detector that compares the test segment with the training subspace.
+        :param N: length of the training segment (the first N buffered observations).
+        :param L: window width for the SSA method, N // 2 by default.
+        :param p: offset of the test segment from the start of the buffer, N by default.
+        :param Q: minimal number of windows in the test segment, 1 by default.
+        :raises ValueError: if L is not in [1, N] or if p or Q is negative.
+        """
+        self.__ssa = ssa
+        self.__detector = detector
+        self.__N = N
+        self.__L = N // 2 if L is None else L
+        self.__p = N if p is None else p
+        self.__Q = 1 if Q is None else Q
+        if not 1 <= self.__L <= N:
+            raise ValueError("Window width L must be in [1, N]")
+        if self.__p < 0 or self.__Q < 0:
+            raise ValueError("Parameters p and Q must be non-negative")
+        self.__q = self.__p + self.__Q
+
+        self.__buffer: list[np.float64] = []
+        self.__current_time = 0
+        self.__required_len = max(N, self.__q + self.__L - 1)
+
+        self.__is_ready = False
+        self.__was_changed = False
+        self.__change_point: int | None = None
+
+    def clear(self) -> None:
+        """
+        Clears the state of the algorithm's instance.
+        :return:
+        """
+        self.__buffer = []
+        self.__current_time = 0
+
+        self.__is_ready = False
+        self.__was_changed = False
+        self.__change_point = None
+
+    def __update_buffer(self) -> None:
+        """
+        Updates the buffer after detecting a breakpoint.
+        :return:
+        """
+        # Keep only the test segment, i.e. the observations from offset p onwards.
+        self.__buffer = self.__buffer[self.__p :]
+        if len(self.__buffer) == self.__required_len:
+            # Nothing was dropped (p == 0): free one slot so that the next observation refills the buffer.
+            self.__buffer.pop(0)
+        self.__is_ready = False
+
+    def __detect_breakpoint(self, with_localization: bool) -> None:
+        """
+        Checks for a breakpoint and determines when it occurred.
+        :param with_localization: whether the method was called for localization of a change point.
+        :return:
+        """
+        training_data: npt.NDArray[np.float64] = np.array(self.__buffer[: self.__N])
+        test_data = self.__buffer[self.__p :]
+
+        subspace = self.__ssa.subspace(training_data, self.__L)
+        detection = self.__detector.detect(subspace, test_data)
+
+        if detection:
+            self.__was_changed = True
+            if with_localization:
+                # Zero-based absolute index of the first observation of the test segment.
+                self.__change_point = self.__current_time - (
+                    len(self.__buffer) - self.__p
+                )
+
+            self.__update_buffer()
+            return
+
+        # No change point: slide the buffer forward by one observation.
+        self.__buffer.pop(0)
+
+    def __process_point(self, observation: np.float64, with_localization: bool) -> None:
+        """
+        Universal method for processing of another observation of a time series.
+        :param observation: new observation of a time series.
+        :param with_localization: whether the method was called for localization of a change point.
+        :return:
+        """
+        self.__buffer.append(observation)
+        self.__current_time += 1
+
+        if not self.__is_ready:
+            if len(self.__buffer) != self.__required_len:
+                return
+
+            self.__is_ready = True
+
+        self.__detect_breakpoint(with_localization)
+
+    def detect(self, observation: np.float64 | npt.NDArray[np.float64]) -> bool:
+        """
+        Performs a change point detection after processing another observation of a time series.
+        :param observation: new observation of a time series. Note: multivariate time series aren't supported for now.
+        :return: whether a change point was detected after processing the new observation.
+        :raises TypeError: if the observation is an array, i.e. multivariate.
+        """
+        if isinstance(observation, np.ndarray):
+            raise TypeError("Multivariate observations are not supported")
+        self.__process_point(np.float64(observation), False)
+        result = self.__was_changed
+        self.__was_changed = False
+        return result
+
+    def localize(self, observation: np.float64 | npt.NDArray[np.float64]) -> int | None:
+        """
+        Performs a change point localization after processing another observation of a time series.
+        :param observation: new observation of a time series.
+        :return: absolute location of a change point, acquired after processing the new observation,
+        or None if there wasn't any.
+        :raises TypeError: if the observation is an array, i.e. multivariate.
+        """
+        if isinstance(observation, np.ndarray):
+            raise TypeError("Multivariate observations are not supported")
+        self.__process_point(np.float64(observation), True)
+        result = self.__change_point
+        self.__was_changed = False
+        self.__change_point = None
+        return result
diff --git a/tests/test_core/test_algorithms/test_ssa_online_algorithm.py b/tests/test_core/test_algorithms/test_ssa_online_algorithm.py
new file mode 100644
index 00000000..201d8890
--- /dev/null
+++ b/tests/test_core/test_algorithms/test_ssa_online_algorithm.py
@@ -0,0 +1,109 @@
+import numpy as np
+import pytest
+
+from pysatl_cpd.core.algorithms.ssa.decomposition import BasicSVD
+from pysatl_cpd.core.algorithms.ssa.detectors import DistanceThreshold
+from pysatl_cpd.core.algorithms.ssa.embedding import BasicEmbedding
+from pysatl_cpd.core.algorithms.ssa.grouping import ConstantGrouping
+from pysatl_cpd.core.algorithms.ssa.ssa import SSA
+from pysatl_cpd.core.algorithms.ssa_online_algorithm import SSAOnline
+
+
+@pytest.fixture
+def experimental_params(distribution_type):
+    params = {
+        "size": 500,
+        "change_point": 250,
+        "tolerable_deviation": 25,
+    }
+    return params
+
+
+@pytest.fixture
+def confiure_algorithm(distribution_type):
+    match distribution_type:
+        case "normal":
+            m = 12
+        case "uniform":
+            m = 2
+        case _:
+            raise ValueError("Unsupported likelihood")
+
+    ssa = SSA(
+        embedding_step=BasicEmbedding(),
+        decomposition_step=BasicSVD(),
+        grouping_step=ConstantGrouping(m),
+    )
+    detector = DistanceThreshold(threshold=0.75)
+
+    ssa_online = SSAOnline(
+        ssa=ssa,
+        detector=detector,
+        N=40,
+    )
+    return ssa_online
+
+
+@pytest.fixture(scope="function")
+def generate_data(distribution_type, experimental_params):
+    def _generate():
+        np.random.seed(42)
+        cp = experimental_params["change_point"]
+        size = experimental_params["size"]
+
+        match distribution_type:
+            case "normal":
+                return np.concatenate([np.random.normal(0, 1, cp), np.random.normal(5, 2, size - cp)])
+            case "uniform":
+                return np.concatenate(
+                    [
+                        np.random.uniform(0.0, 0.1, cp),
+                        np.random.uniform(2.0, 2.1, size - cp),
+                    ]
+                )
+            case _:
+                raise ValueError("Unsupported likelihood")
+
+    return _generate
+
+
+@pytest.mark.parametrize("distribution_type", ["normal", "uniform"])
+class TestSSAOnlineAlgorithm:
+    def test_consecutive_detection(self, generate_data, confiure_algorithm, experimental_params):
+        online_ssa = confiure_algorithm
+        data = generate_data()
+        was_change_point = False
+        for value in data:
+            result = online_ssa.detect(value)
+            if result:
+                was_change_point = True
+
+        assert was_change_point, "There was undetected change point in data"
+        online_ssa.clear()
+
+    def test_consecutive_localization(self, generate_data, confiure_algorithm, experimental_params):
+        online_ssa = confiure_algorithm
+        data = generate_data()
+        was_change_point = False
+        for value in data:
+            result = online_ssa.localize(value)
+            if result is not None:
+                was_change_point = True
+                assert (
+                    experimental_params["change_point"] - experimental_params["tolerable_deviation"]
+                    <= result
+                    <= experimental_params["change_point"] + experimental_params["tolerable_deviation"]
+                ), "Incorrect change point localization"
+
+        assert was_change_point, "There was undetected change point in data"
+        online_ssa.clear()
+
+    def test_online_localization_correctness(self, generate_data, confiure_algorithm, experimental_params):
+        online_ssa = confiure_algorithm
+        data = generate_data()
+        for time, value in enumerate(data):
+            result = online_ssa.detect(value)
+            if result:
+                assert experimental_params["change_point"] <= time, "Change point cannot be detected beforehand"
+
+        online_ssa.clear()