Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ cython_debug/
*.jpeg

assets/data
benchmark_cache/
7 changes: 4 additions & 3 deletions examples/noreset_shewhart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pysatl_cpd.benchmark.metrics.online.delay_metric import MeanDelayMetric, MedianDelayMetric
from pysatl_cpd.benchmark.noreset.noreset_benchmark_runner import NoResetBenchmarkRunner
from pysatl_cpd.benchmark.noreset.threshold_policy import EventBasedPolicy
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -169,7 +170,7 @@ def main() -> None:
print(f"Algorithm: ShewhartControlChart(learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
print(
f"Dataset (NoReset): {N_SERIES} series, length={SERIES_LENGTH}, change_point={CHANGE_POINT},"
"shift={MU_AFTER - MU_BEFORE:.1f}*sigma"
f"shift={MU_AFTER - MU_BEFORE:.1f}*sigma"
)
print(f"Dataset (ARL): {N_SERIES} series, length={SERIES_LENGTH}, no change points")
print(f"Error margin: {ERROR_MARGIN}")
Expand All @@ -192,7 +193,7 @@ def main() -> None:
policy = EventBasedPolicy(ERROR_MARGIN[1], strict_edge=False)

runner = NoResetBenchmarkRunner(
algorithms=[(algorithm, THRESHOLDS)],
entries=[AlgorithmEntry(algorithm, THRESHOLDS)],
providers=providers,
metrics=metrics,
solver=solver,
Expand All @@ -206,7 +207,7 @@ def main() -> None:
# RUN 2: Average Run Length (ARL)
# ==========================================
arl_runner = ARLBenchmarkRunner(
algorithms=[(algorithm, THRESHOLDS)],
entries=[AlgorithmEntry(algorithm, THRESHOLDS)],
providers=arl_providers,
solver=solver,
mode="noreset", # uses rapid point-based extraction behind the scenes
Expand Down
27 changes: 15 additions & 12 deletions pysatl_cpd/benchmark/arl_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pysatl_cpd.benchmark.noreset.threshold_policy import PointBasedPolicy
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
from pysatl_cpd.benchmark.reset_benchmark_runner import ResetBenchmarkRunner
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace

Expand All @@ -44,8 +44,9 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa

Parameters
----------
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
Sequence of (algorithm, thresholds) pairs to evaluate.
entries : Sequence[AlgorithmEntry]
Sequence of AlgorithmEntry objects containing algorithm, thresholds,
and an optional data transformer.
providers : list[ProviderT]
Labeled data providers to run against. Must have `change_points == []`.
solver : OnlineCpdSolver
Expand All @@ -55,6 +56,8 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
dump_dir : Path | str | None, optional
Directory for caching results via BenchmarkExecutor.
If None, caching is disabled. Default is None.
verbose : bool, default=False
If True, displays progress bars during execution.

Raises
------
Expand All @@ -66,7 +69,7 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa

def __init__(
self,
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
providers: list[ProviderT],
solver: OnlineCpdSolver,
mode: Literal["reset", "noreset"],
Expand All @@ -83,7 +86,7 @@ def __init__(
metrics = {"arl": ARLMetric[TraceT, ProviderT]()}

super().__init__(
algorithms=algorithms,
entries=entries,
providers=providers,
metrics=metrics, # type: ignore[arg-type]
solver=solver,
Expand All @@ -95,7 +98,7 @@ def __init__(
if mode == "reset":
# Delegate to standard ResetBenchmarkRunner
self._inner_runner: OnlineBenchmarkRunner[Any, ProviderT] = ResetBenchmarkRunner(
algorithms=algorithms,
entries=entries,
providers=providers,
metrics=cast(Any, metrics),
solver=solver,
Expand All @@ -104,7 +107,7 @@ def __init__(
elif mode == "noreset":
# Delegate to optimized NoResetBenchmarkRunner with PointBased policy
self._inner_runner = NoResetBenchmarkRunner(
algorithms=algorithms,
entries=entries,
providers=providers,
metrics=cast(Any, metrics),
solver=solver,
Expand All @@ -116,20 +119,20 @@ def __init__(

def _collect_runs(
self,
algorithm: OnlineAlgorithm[Any, Any, Any],
entry: AlgorithmEntry[Any, Any, Any],
threshold: float,
providers: Sequence[ProviderT],
) -> list[tuple[TraceT, ProviderT]]:
"""
Collect runs for a given algorithm and threshold using the configured mode.
Collect runs for a given algorithm entry and threshold using the configured mode.

Delegates the collection to either ResetBenchmarkRunner or
NoResetBenchmarkRunner depending on the initialized mode.

Parameters
----------
algorithm : OnlineAlgorithm[Any, Any, Any]
The algorithm to evaluate.
entry : AlgorithmEntry
The algorithm configuration entry to evaluate.
threshold : float
The detection threshold.
providers : Sequence[ProviderT]
Expand All @@ -140,5 +143,5 @@ def _collect_runs(
list[tuple[TraceT, ProviderT]]
Batch of (trace, provider) pairs.
"""
runs = self._inner_runner._collect_runs(algorithm, threshold, providers)
runs = self._inner_runner._collect_runs(entry, threshold, providers)
return cast(list[tuple[TraceT, ProviderT]], runs)
29 changes: 17 additions & 12 deletions pysatl_cpd/benchmark/core/benchmark_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from pathlib import Path
from typing import Any

from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace

Expand All @@ -37,7 +37,7 @@ class BenchmarkRecord:
Parameters
----------
algorithm : str
The string identifier or name of the online algorithm.
The string identifier or name of the online algorithm (and transformer).
configuration_hash : str
A hash string representing the algorithm's configuration.
data : str
Expand Down Expand Up @@ -79,9 +79,9 @@ class BenchmarkExecutor[DataT]:

Parameters
----------
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
A sequence of tuples, where each tuple contains an instantiated online
algorithm and a sequence of thresholds to test it against.
entries : Sequence[AlgorithmEntry]
A sequence of AlgorithmEntry objects, each grouping an algorithm,
its thresholds, and an optional data transformer.
providers : Sequence[DataProvider[DataT]]
A sequence of data providers to be fed into the algorithms.
solver : OnlineCpdSolver
Expand All @@ -94,12 +94,12 @@ class BenchmarkExecutor[DataT]:

def __init__(
self,
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
providers: Sequence[DataProvider[DataT]],
solver: OnlineCpdSolver,
dump_dir: str | Path | None = None,
) -> None:
self.__algorithms = algorithms
self.__entries = entries
self.__providers = providers
self.__solver = solver
self.__dump_dir = Path(dump_dir) if dump_dir is not None else None
Expand Down Expand Up @@ -141,12 +141,17 @@ def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
)
registry[record.key] = record

for (algorithm, thresholds), provider in itertools.product(self.__algorithms, self.__providers):
algo_name = str(algorithm)
config_hash = hash(algorithm.configuration)
for entry, provider in itertools.product(self.__entries, self.__providers):
algo_name = entry.full_name
config_hash = entry.full_hash
data_name = provider.name

for threshold in thresholds:
# Apply data transformer if specified in the entry
active_provider = provider
if entry.transformer is not None:
active_provider = entry.transformer.transform(provider)

for threshold in entry.thresholds:
key = (algo_name, config_hash, data_name, float(threshold))

if key in registry:
Expand All @@ -159,7 +164,7 @@ def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
results.append((registry[key], trace))
continue

steps = list(self.__solver.run(algorithm, provider, threshold))
steps = list(self.__solver.run(entry.algorithm, active_provider, threshold))
trace = OnlineDetectionTrace.from_run(steps, algo_name, config_hash)

record = BenchmarkRecord(algo_name, config_hash, data_name, threshold, None)
Expand Down
2 changes: 1 addition & 1 deletion pysatl_cpd/benchmark/core/benchmark_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from typing import Any

__author__ = "PySATL contributors"
__author__ = "Danil Totmyanin"
__copyright__ = "Copyright (c) 2026 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

Expand Down
34 changes: 21 additions & 13 deletions pysatl_cpd/benchmark/noreset/noreset_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
__copyright__ = "Copyright (c) 2026 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

import dataclasses
from collections.abc import Sequence
from pathlib import Path
from typing import Any
Expand All @@ -23,7 +24,7 @@
from pysatl_cpd.benchmark.noreset.noreset_detection_trace import NoResetDetectionTrace
from pysatl_cpd.benchmark.noreset.threshold_policy import ThresholdPolicy
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace

Expand All @@ -32,16 +33,17 @@ class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[
"""
Optimised benchmark runner for series with a single change point.

For each (algorithm, provider) pair the solver is executed exactly
For each (algorithm entry, provider) pair the solver is executed exactly
once with threshold=inf, producing a full detection function trace.
All threshold evaluations are then simulated by applying a
ThresholdPolicy to that cached trace, avoiding redundant solver runs.
Caching is handled entirely by BenchmarkExecutor.

Parameters
----------
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
Sequence of (algorithm, thresholds) pairs to evaluate.
entries : Sequence[AlgorithmEntry]
Sequence of AlgorithmEntry objects containing algorithm, thresholds,
and an optional data transformer.
providers : Sequence[ProviderT]
Labeled data providers to run against.
metrics : dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]]
Expand All @@ -54,11 +56,13 @@ class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[
dump_dir : Path | str | None, optional
Directory for caching inf traces via BenchmarkExecutor.
If None, caching is disabled. Default is None.
verbose : bool, default=False
If True, displays progress bars during execution.
"""

def __init__(
self,
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
providers: Sequence[ProviderT],
metrics: dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]],
solver: OnlineCpdSolver,
Expand All @@ -67,7 +71,7 @@ def __init__(
verbose: bool = False,
) -> None:
super().__init__(
algorithms=algorithms,
entries=entries,
providers=providers,
metrics=metrics,
solver=solver,
Expand All @@ -76,8 +80,11 @@ def __init__(
)
self._policy = policy

# Replace all thresholds with inf for initial pre-caching run
inf_entries = [dataclasses.replace(entry, thresholds=[float("inf")]) for entry in entries]

executor: BenchmarkExecutor[Any] = BenchmarkExecutor(
algorithms=[(algorithm, [float("inf")]) for algorithm, _ in algorithms],
entries=inf_entries,
providers=list(providers),
solver=self._solver,
dump_dir=self._dump_dir,
Expand All @@ -86,26 +93,27 @@ def __init__(
self._inf_trace_cache: dict[tuple[str, int, str], OnlineDetectionTrace[Any]] = {}

for record, trace in executor.execute():
# record.algorithm corresponds to entry.full_name; record.configuration_hash corresponds to entry.full_hash
key = (record.algorithm, record.configuration_hash, record.data)
self._inf_trace_cache[key] = trace

def _collect_runs(
self,
algorithm: OnlineAlgorithm[Any, Any, Any],
entry: AlgorithmEntry[Any, Any, Any],
threshold: float,
providers: Sequence[ProviderT],
) -> list[tuple[NoResetDetectionTrace[Any], ProviderT]]:
"""
Collect NoReset runs for a given algorithm and threshold.
Collect NoReset runs for a given algorithm entry and threshold.

For each provider, retrieves the inf trace via BenchmarkExecutor
and applies the ThresholdPolicy to produce a lightweight
NoResetDetectionTrace.

Parameters
----------
algorithm : OnlineAlgorithm[Any, Any, Any]
The algorithm to evaluate.
entry : AlgorithmEntry
The algorithm configuration entry to evaluate.
threshold : float
The detection threshold to simulate.
providers : Sequence[ProviderT]
Expand All @@ -119,8 +127,8 @@ def _collect_runs(
if not providers:
return []

algo_name = str(algorithm)
config_hash = hash(algorithm.configuration)
algo_name = entry.full_name
config_hash = entry.full_hash
runs: list[tuple[NoResetDetectionTrace[Any], ProviderT]] = []

for provider in providers:
Expand Down
Loading
Loading