Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e4cbf95
exp
mohammedahmed18 Sep 8, 2025
01ad626
still experimenting
mohammedahmed18 Sep 9, 2025
2b01faf
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 9, 2025
ce89905
reset
mohammedahmed18 Dec 9, 2025
1f367be
dynamic tolerance
mohammedahmed18 Dec 9, 2025
0d819a8
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 11, 2025
5c4a6d9
get the duration from the pytest overridden methods
mohammedahmed18 Dec 12, 2025
ecd21d5
remove debug log
mohammedahmed18 Dec 12, 2025
30c89ce
respect the min loop count -just in case-
mohammedahmed18 Dec 12, 2025
ce2c05b
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 12, 2025
89fc939
closer method
mohammedahmed18 Dec 15, 2025
cc94694
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 16, 2025
a67dad3
working version
mohammedahmed18 Dec 16, 2025
d52aae4
even better
mohammedahmed18 Dec 16, 2025
244f9ca
better stability algorithm
mohammedahmed18 Dec 17, 2025
a890d4f
should stop metrics
mohammedahmed18 Dec 19, 2025
3159eb6
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 22, 2025
95f22ee
better stability with sum the min of all prev loops
mohammedahmed18 Dec 22, 2025
9f311cd
Optimize should_stop
codeflash-ai[bot] Dec 22, 2025
83dff02
best summed runtime helper
mohammedahmed18 Dec 23, 2025
a8e93c7
Merge branch 'main' of github.com:codeflash-ai/codeflash into codefla…
mohammedahmed18 Dec 23, 2025
e49ba13
linting
mohammedahmed18 Dec 23, 2025
91cbc74
Merge pull request #984 from codeflash-ai/codeflash/optimize-pr967-20…
mohammedahmed18 Dec 23, 2025
0b3be3f
some enhancements from claude pr review
mohammedahmed18 Dec 23, 2025
b57fa1a
Merge branch 'exp/consistent-loop-break' of github.com:codeflash-ai/c…
mohammedahmed18 Dec 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions codeflash/code_utils/config_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
N_CANDIDATES_LP = 6

# pytest loop stability
STABILITY_WARMUP_LOOPS = 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how did you determine these magic numbers @mohammedahmed18 ?

STABILITY_WINDOW_SIZE = 6
STABILITY_CENTER_TOLERANCE = 0.01 # ±1% around median
STABILITY_SPREAD_TOLERANCE = 0.02 # 2% window spread
STABILITY_SLOPE_TOLERANCE = 0.01 # 1% improvement allowed

# Refinement
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
Expand Down
1 change: 0 additions & 1 deletion codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True

first_cmd = formatter_cmds[0]
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]

Expand Down
5 changes: 2 additions & 3 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
from codeflash.lsp.lsp_message import LspMarkdownMessage
from codeflash.models.test_type import TestType
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -817,9 +818,7 @@ def total_passed_runtime(self) -> int:
:return: The runtime in nanoseconds.
"""
# TODO this doesn't look at the intersection of tests of baseline and original
return sum(
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
)
return calculate_best_summed_runtime(self.usable_runtime_data_by_test_case())

def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests = Counter()
Expand Down
1 change: 0 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,7 +1887,6 @@ def establish_original_code_baseline(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
console.rule()

if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
Expand Down
2 changes: 2 additions & 0 deletions codeflash/result/best_summed_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def calculate_best_summed_runtime(grouped_runtime_info: dict[object, list[int]]) -> int:
    """Sum the fastest observed runtime of each test case.

    For every test case, the minimum of its recorded runtimes is taken as the
    best achievable measurement, and those per-case minima are summed into a
    single comparable total.

    :param grouped_runtime_info: mapping of test-case identifier to its list of
        recorded runtimes in nanoseconds; each list must be non-empty.
    :return: sum of the per-test-case minimum runtimes (0 for an empty mapping).
    """
    # Iterate values() directly (keys were unused) and feed sum() a generator
    # instead of materializing an intermediate list.
    return sum(min(runtimes) for runtimes in grouped_runtime_info.values())
107 changes: 99 additions & 8 deletions codeflash/verification/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,32 @@

import contextlib
import inspect

# System Imports
import logging
import os
import platform
import re
import sys
import time as _time_module
import warnings

# System Imports
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
from unittest import TestCase

# PyTest Imports
import pytest
from pluggy import HookspecMarker

from codeflash.code_utils.config_consts import (
STABILITY_CENTER_TOLERANCE,
STABILITY_SLOPE_TOLERANCE,
STABILITY_SPREAD_TOLERANCE,
STABILITY_WARMUP_LOOPS,
STABILITY_WINDOW_SIZE,
)
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime

if TYPE_CHECKING:
from _pytest.config import Config, Parser
from _pytest.main import Session
Expand Down Expand Up @@ -77,6 +86,7 @@ class UnexpectedError(Exception):
# Store references to original functions before any patching
_ORIGINAL_TIME_TIME = _time_module.time
_ORIGINAL_PERF_COUNTER = _time_module.perf_counter
_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns
_ORIGINAL_TIME_SLEEP = _time_module.sleep


Expand Down Expand Up @@ -260,6 +270,71 @@ def pytest_configure(config: Config) -> None:
_apply_deterministic_patches()


def get_runtime_from_stdout(stdout: str) -> Optional[int]:
    """Extract the instrumented runtime in nanoseconds from captured stdout.

    The instrumentation emits a marker of the form ``!######<id>:<ns>######!``.
    The *last* complete marker in ``stdout`` wins, so repeated loop iterations
    report their most recent measurement.

    :param stdout: captured stdout of a test call (may be empty).
    :return: the parsed runtime, or ``None`` when no well-formed marker exists.
    """
    marker_start = "!######"
    marker_end = "######!"

    if not stdout:
        return None

    # Search backwards so the most recent measurement is used.
    end = stdout.rfind(marker_end)
    if end == -1:
        return None

    start = stdout.rfind(marker_start, 0, end)
    if start == -1:
        return None

    payload = stdout[start + len(marker_start) : end]
    last_colon = payload.rfind(":")
    if last_colon == -1:
        return None

    # Guard against a corrupted/truncated payload: return None like every
    # other malformed-input path instead of letting ValueError escape into
    # the pytest reporting hook.
    try:
        return int(payload[last_colon + 1 :])
    except ValueError:
        return None


_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")


def should_stop(
    runtimes: list[int],
    warmup: int = STABILITY_WARMUP_LOOPS,
    window: int = STABILITY_WINDOW_SIZE,
    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
    slope_rel_tol: float = STABILITY_SLOPE_TOLERANCE,
) -> bool:
    """Decide whether the benchmarking loop has converged on stable timings.

    Looks at the last ``window`` runtimes (after at least ``warmup`` extra
    samples exist) and requires three conditions simultaneously:

    1. every recent point lies within ``center_rel_tol`` of the window median,
    2. the window's min-to-max spread is at most ``spread_rel_tol``,
    3. there is no downward trend stronger than ``slope_rel_tol`` (i.e. the
       measurements are no longer improving).

    :param runtimes: sequence of summed runtimes in nanoseconds, oldest first.
    :return: ``True`` when all three stability criteria hold.
    """
    if len(runtimes) < warmup + window:
        return False

    recent = runtimes[-window:]
    recent_sorted = sorted(recent)

    # Guard: the relative tolerances below divide by the median, the window
    # minimum, and the first-half mean. Non-positive runtimes are nonsensical
    # measurements, so treat them as "not stable" instead of raising
    # ZeroDivisionError.
    if recent_sorted[0] <= 0:
        return False

    # Median from the sorted window (avoids a statistics.median call per loop).
    mid = window // 2
    median = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2

    # 1) All recent points close to the median; bail out on the first outlier.
    if any(abs(r - median) / median > center_rel_tol for r in recent):
        return False

    # 2) Window spread is small.
    r_min, r_max = recent_sorted[0], recent_sorted[-1]
    if (r_max - r_min) / r_min > spread_rel_tol:
        return False

    # 3) No strong downward trend (still improving): compare the mean of the
    # first half against the mean of the second half.
    half = window // 2
    first = sum(recent[:half]) / half
    second = sum(recent[half:]) / (window - half)
    return (first - second) / first <= slope_rel_tol


class PytestLoops:
name: str = "pytest-loops"

Expand All @@ -268,6 +343,15 @@ def __init__(self, config: Config) -> None:
level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
logging.basicConfig(level=level)
self.logger = logging.getLogger(self.name)
self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}

@pytest.hookimpl
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
    """Record the instrumented runtime of every passing test call.

    Only the ``call`` phase of a passed test is considered; setup and
    teardown reports are ignored.
    """
    if report.when == "call" and report.passed:
        # The runtime (ns) is embedded in captured stdout via instrumentation
        # markers; None (no marker) or 0 means no usable measurement.
        duration_ns = get_runtime_from_stdout(report.capstdout)
        if duration_ns:
            # Strip a trailing "[N]" bracket suffix from the nodeid so repeated
            # loop invocations of the same test aggregate under one key.
            clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
            self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)

@hookspec(firstresult=True)
def pytest_runtestloop(self, session: Session) -> bool:
Expand All @@ -283,11 +367,9 @@ def pytest_runtestloop(self, session: Session) -> bool:
total_time: float = self._get_total_time(session)

count: int = 0

while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one for normal tests
runtimes = []
while total_time >= SHORTEST_AMOUNT_OF_TIME:
count += 1
total_time = self._get_total_time(session)

for index, item in enumerate(session.items):
item: pytest.Item = item # noqa: PLW0127, PLW2901
item._report_sections.clear() # clear reports for new test # noqa: SLF001
Expand All @@ -304,8 +386,17 @@ def pytest_runtestloop(self, session: Session) -> bool:
raise session.Failed(session.shouldfail)
if session.shouldstop:
raise session.Interrupted(session.shouldstop)

best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
if best_runtime_until_now > 0:
runtimes.append(best_runtime_until_now)

if should_stop(runtimes):
break

if self._timed_out(session, start_time, count):
break # exit loop
break

_ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
return True

Expand Down
Loading