Skip to content

Commit 46701c7

Browse files
window percentage
1 parent b57fa1a commit 46701c7

2 files changed

Lines changed: 65 additions & 22 deletions

File tree

codeflash/code_utils/config_consts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
N_CANDIDATES_LP = 6
1616

1717
# pytest loop stability
18-
STABILITY_WARMUP_LOOPS = 4
19-
STABILITY_WINDOW_SIZE = 6
20-
STABILITY_CENTER_TOLERANCE = 0.01 # ±1% around median
21-
STABILITY_SPREAD_TOLERANCE = 0.02 # 2% window spread
22-
STABILITY_SLOPE_TOLERANCE = 0.01 # 1% improvement allowed
18+
# For now, we use strict thresholds (large windows and tight tolerances), since this is still experimental.
19+
STABILITY_WARMUP_LOOPS = 0.05 # 5% of total window
20+
STABILITY_WINDOW_SIZE = 0.35 # 35% of total window
21+
STABILITY_CENTER_TOLERANCE = 0.005 # ±0.5% around median
22+
STABILITY_SPREAD_TOLERANCE = 0.005 # 0.5% window spread
2323

2424
# Refinement
2525
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations

codeflash/verification/pytest_plugin.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import inspect
55
import logging
6+
import math
67
import os
78
import platform
89
import re
@@ -21,11 +22,11 @@
2122

2223
from codeflash.code_utils.config_consts import (
2324
STABILITY_CENTER_TOLERANCE,
24-
STABILITY_SLOPE_TOLERANCE,
2525
STABILITY_SPREAD_TOLERANCE,
2626
STABILITY_WARMUP_LOOPS,
2727
STABILITY_WINDOW_SIZE,
2828
)
29+
from codeflash.code_utils.time_utils import humanize_runtime
2930
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
3031

3132
if TYPE_CHECKING:
@@ -298,11 +299,10 @@ def get_runtime_from_stdout(stdout: str) -> Optional[int]:
298299

299300
def should_stop(
300301
runtimes: list[int],
301-
warmup: int = STABILITY_WARMUP_LOOPS,
302-
window: int = STABILITY_WINDOW_SIZE,
302+
warmup: int,
303+
window: int,
303304
center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
304305
spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
305-
slope_rel_tol: float = STABILITY_SLOPE_TOLERANCE,
306306
) -> bool:
307307
if len(runtimes) < warmup + window:
308308
return False
@@ -325,14 +325,7 @@ def should_stop(
325325
r_min, r_max = recent_sorted[0], recent_sorted[-1]
326326
spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
327327

328-
# 3) No strong downward trend (still improving)
329-
# Compare first half vs second half
330-
half = window // 2
331-
first = sum(recent[:half]) / half
332-
second = sum(recent[half:]) / (window - half)
333-
slope_ok = (first - second) / first <= slope_rel_tol
334-
335-
return centered and spread_ok and slope_ok
328+
return centered and spread_ok
336329

337330

338331
class PytestLoops:
@@ -344,6 +337,7 @@ def __init__(self, config: Config) -> None:
344337
logging.basicConfig(level=level)
345338
self.logger = logging.getLogger(self.name)
346339
self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}
340+
self.total_loop_runtimes: list[int] = []
347341

348342
@pytest.hookimpl
349343
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
@@ -359,6 +353,7 @@ def pytest_runtestloop(self, session: Session) -> bool:
359353
if session.testsfailed and not session.config.option.continue_on_collection_errors:
360354
msg = "{} error{} during collection".format(session.testsfailed, "s" if session.testsfailed != 1 else "")
361355
raise session.Interrupted(msg)
356+
is_perf_test = bool(session.config.option.codeflash_max_loops > 1)
362357

363358
if session.config.option.collectonly:
364359
return True
@@ -368,8 +363,12 @@ def pytest_runtestloop(self, session: Session) -> bool:
368363

369364
count: int = 0
370365
runtimes = []
366+
break_at = -1
367+
elapsed = 0.0
368+
371369
while total_time >= SHORTEST_AMOUNT_OF_TIME:
372370
count += 1
371+
loop_start = _ORIGINAL_PERF_COUNTER_NS()
373372
for index, item in enumerate(session.items):
374373
item: pytest.Item = item # noqa: PLW0127, PLW2901
375374
item._report_sections.clear() # clear reports for new test # noqa: SLF001
@@ -387,14 +386,58 @@ def pytest_runtestloop(self, session: Session) -> bool:
387386
if session.shouldstop:
388387
raise session.Interrupted(session.shouldstop)
389388

390-
best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
391-
if best_runtime_until_now > 0:
392-
runtimes.append(best_runtime_until_now)
389+
if is_perf_test:
390+
loop_end = _ORIGINAL_PERF_COUNTER_NS()
391+
dt = loop_end - loop_start # nano-seconds
393392

394-
if should_stop(runtimes):
395-
break
393+
elapsed += dt
394+
self.total_loop_runtimes.append(dt)
395+
396+
best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
397+
if best_runtime_until_now > 0:
398+
runtimes.append(best_runtime_until_now)
399+
400+
estimated_total_loops = 0
401+
if elapsed > 0:
402+
rate = count / elapsed # loops / nano-seconds
403+
estimated_total_loops = int(rate * total_time * 1e9)
404+
405+
warmup_loops = math.floor(STABILITY_WARMUP_LOOPS * estimated_total_loops)
406+
window_size = math.floor(STABILITY_WINDOW_SIZE * estimated_total_loops)
407+
if ( # noqa: SIM102
408+
warmup_loops > 1 and window_size > 1 and should_stop(runtimes, warmup_loops, window_size)
409+
):
410+
if break_at == -1:
411+
break_at = count
412+
# break
396413

397414
if self._timed_out(session, start_time, count):
415+
if is_perf_test:
416+
did_break = "true" if break_at != -1 else "false"
417+
best_of_all = min(runtimes)
418+
best_before_break = min(runtimes[:break_at])
419+
420+
runtimes_after_break = self.total_loop_runtimes[break_at:]
421+
total_after_break = str(sum(runtimes_after_break)) if did_break == "true" else "NA"
422+
total_runtime = str(sum(self.total_loop_runtimes))
423+
424+
accuracy_str = "NA"
425+
if did_break == "true":
426+
accuracy = best_of_all / best_before_break * 100
427+
accuracy_str = f"{accuracy:.2f}"
428+
Path(
429+
f"/home/mohammed/Documents/test-results/optimize-me-exp/exp-{int(_ORIGINAL_TIME_TIME())}.json"
430+
).write_text(f"""{{
431+
"runtimes": {runtimes},
432+
"did_break": {did_break},
433+
"accuracy": "{accuracy_str}%",
434+
"time_saved": "{total_after_break}",
435+
"total_runtime": "{total_runtime}",
436+
"total_loops": {count},
437+
"breaked_at": {break_at},
438+
"best_of_all": "{humanize_runtime(best_of_all)}",
439+
"best_before_break": "{humanize_runtime(best_before_break)}"
440+
}}""")
398441
break
399442

400443
_ORIGINAL_TIME_SLEEP(self._get_delay_time(session))

0 commit comments

Comments
 (0)