diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py
index 88758455e..bd4726a0b 100644
--- a/codeflash/code_utils/config_consts.py
+++ b/codeflash/code_utils/config_consts.py
@@ -14,6 +14,13 @@
 DEFAULT_IMPORTANCE_THRESHOLD = 0.001
 N_CANDIDATES_LP = 6
 
+# pytest loop stability
+STABILITY_WARMUP_LOOPS = 4
+STABILITY_WINDOW_SIZE = 6
+STABILITY_CENTER_TOLERANCE = 0.01  # ±1% around median
+STABILITY_SPREAD_TOLERANCE = 0.02  # 2% window spread
+STABILITY_SLOPE_TOLERANCE = 0.01  # 1% improvement allowed
+
 # Refinement
 REFINE_ALL_THRESHOLD = 2  # when valid optimizations count is 2 or less, refine all optimizations
 REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1)  # (runtime, diff), runtime is more important than diff by a factor of 2
diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py
index 450c023b3..76621327c 100644
--- a/codeflash/code_utils/env_utils.py
+++ b/codeflash/code_utils/env_utils.py
@@ -19,7 +19,6 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:  # noqa
     if not formatter_cmds or formatter_cmds[0] == "disabled":
         return True
-
     first_cmd = formatter_cmds[0]
     cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
 
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 1db09bc12..cec3c88b1 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -10,6 +10,7 @@
 from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
 from codeflash.lsp.lsp_message import LspMarkdownMessage
 from codeflash.models.test_type import TestType
+from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -817,9 +818,7 @@ def total_passed_runtime(self) -> int:
         :return: The runtime in nanoseconds.
         """
         # TODO this doesn't look at the intersection of tests of baseline and original
-        return sum(
-            [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
-        )
+        return calculate_best_summed_runtime(self.usable_runtime_data_by_test_case())
 
     def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
         map_gen_test_file_to_no_of_tests = Counter()
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 416bdc8df..e7ef0829f 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1887,7 +1887,6 @@ def establish_original_code_baseline(
                 benchmarking_results, self.function_to_optimize.function_name
             )
             logger.debug(f"Original async function throughput: {async_throughput} calls/second")
-            console.rule()
 
         if self.args.benchmark:
             replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
diff --git a/codeflash/result/best_summed_runtime.py b/codeflash/result/best_summed_runtime.py
new file mode 100644
index 000000000..29972b394
--- /dev/null
+++ b/codeflash/result/best_summed_runtime.py
@@ -0,0 +1,6 @@
+from typing import Any
+
+
+def calculate_best_summed_runtime(grouped_runtime_info: dict[Any, list[int]]) -> int:
+    """Sum the fastest runtime observed for each test case, in nanoseconds."""
+    return sum(min(runtimes) for runtimes in grouped_runtime_info.values())
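
Annotation, not part of the patch: the helper extracted above sums, per test case, the fastest runtime observed across loops, so one noisy invocation cannot inflate the total. A minimal sketch with made-up nanosecond timings:

# Sketch only; hypothetical per-test runtimes in nanoseconds.
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime

runtimes_by_test = {
    "tests/test_parse.py::test_small": [12_400, 11_900, 11_850],
    "tests/test_parse.py::test_large": [98_700, 97_300, 99_100],
}

# min(test_small) + min(test_large) = 11_850 + 97_300
assert calculate_best_summed_runtime(runtimes_by_test) == 109_150
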
""" # TODO this doesn't look at the intersection of tests of baseline and original - return sum( - [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] - ) + return calculate_best_summed_runtime(self.usable_runtime_data_by_test_case()) def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]: map_gen_test_file_to_no_of_tests = Counter() diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 416bdc8df..e7ef0829f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1887,7 +1887,6 @@ def establish_original_code_baseline( benchmarking_results, self.function_to_optimize.function_name ) logger.debug(f"Original async function throughput: {async_throughput} calls/second") - console.rule() if self.args.benchmark: replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks( diff --git a/codeflash/result/best_summed_runtime.py b/codeflash/result/best_summed_runtime.py new file mode 100644 index 000000000..29972b394 --- /dev/null +++ b/codeflash/result/best_summed_runtime.py @@ -0,0 +1,2 @@ +def calculate_best_summed_runtime(grouped_runtime_info: dict[any, list[int]]) -> int: + return sum([min(usable_runtime_data) for _, usable_runtime_data in grouped_runtime_info.items()]) diff --git a/codeflash/verification/pytest_plugin.py b/codeflash/verification/pytest_plugin.py index 20ef8624a..693e0d8b3 100644 --- a/codeflash/verification/pytest_plugin.py +++ b/codeflash/verification/pytest_plugin.py @@ -2,8 +2,6 @@ import contextlib import inspect - -# System Imports import logging import os import platform @@ -11,14 +9,25 @@ import sys import time as _time_module import warnings + +# System Imports from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional from unittest import TestCase # PyTest Imports import pytest from pluggy import HookspecMarker +from codeflash.code_utils.config_consts import ( + STABILITY_CENTER_TOLERANCE, + STABILITY_SLOPE_TOLERANCE, + STABILITY_SPREAD_TOLERANCE, + STABILITY_WARMUP_LOOPS, + STABILITY_WINDOW_SIZE, +) +from codeflash.result.best_summed_runtime import calculate_best_summed_runtime + if TYPE_CHECKING: from _pytest.config import Config, Parser from _pytest.main import Session @@ -77,6 +86,7 @@ class UnexpectedError(Exception): # Store references to original functions before any patching _ORIGINAL_TIME_TIME = _time_module.time _ORIGINAL_PERF_COUNTER = _time_module.perf_counter +_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns _ORIGINAL_TIME_SLEEP = _time_module.sleep @@ -260,6 +270,71 @@ def pytest_configure(config: Config) -> None: _apply_deterministic_patches() +def get_runtime_from_stdout(stdout: str) -> Optional[int]: + marker_start = "!######" + marker_end = "######!" 
@@ -268,6 +344,17 @@ def __init__(self, config: Config) -> None:
         level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
         logging.basicConfig(level=level)
         self.logger = logging.getLogger(self.name)
+        self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}
+
+    @pytest.hookimpl
+    def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
+        # Record the runtime of each passing test call, keyed by node id with
+        # the trailing bracketed loop index (e.g. "[2]") stripped.
+        if report.when == "call" and report.passed:
+            duration_ns = get_runtime_from_stdout(report.capstdout)
+            if duration_ns:
+                clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
+                self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
 
     @hookspec(firstresult=True)
     def pytest_runtestloop(self, session: Session) -> bool:
@@ -283,11 +370,9 @@ def pytest_runtestloop(self, session: Session) -> bool:
 
         total_time: float = self._get_total_time(session)
         count: int = 0
-
-        while total_time >= SHORTEST_AMOUNT_OF_TIME:  # need to run at least one for normal tests
+        runtimes: list[int] = []
+        while total_time >= SHORTEST_AMOUNT_OF_TIME:
             count += 1
-            total_time = self._get_total_time(session)
-
             for index, item in enumerate(session.items):
                 item: pytest.Item = item  # noqa: PLW0127, PLW2901
                 item._report_sections.clear()  # clear reports for new test  # noqa: SLF001
@@ -304,8 +389,19 @@ def pytest_runtestloop(self, session: Session) -> bool:
                     raise session.Failed(session.shouldfail)
                 if session.shouldstop:
                     raise session.Interrupted(session.shouldstop)
+
+            # After each full pass over the suite, record the current best
+            # summed runtime and stop early once the series has stabilized.
+            best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
+            if best_runtime_until_now > 0:
+                runtimes.append(best_runtime_until_now)
+
+            if should_stop(runtimes):
+                break
+
             if self._timed_out(session, start_time, count):
-                break  # exit loop
+                break
+
             _ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
         return True
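
Annotation, not part of the patch: how one loop's captured stdout feeds the bookkeeping above. The payload fields between the markers are an assumption here; the parser relies only on the !######...######! delimiters and on the duration being the digits after the final colon:

# Sketch only; the fields before the last colon are made up.
from codeflash.verification.pytest_plugin import get_runtime_from_stdout

captured = "test output...\n!######tests/test_parse.py::test_small:1:11850######!\n"
assert get_runtime_from_stdout(captured) == 11_850
assert get_runtime_from_stdout("no timing markers here") is None
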