Skip to content

Commit 7f1818b

Browse files
refactoring
1 parent 9cae2a1 commit 7f1818b

3 files changed

Lines changed: 24 additions & 29 deletions

File tree

codeflash/models/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
1111
from codeflash.lsp.lsp_message import LspMarkdownMessage
1212
from codeflash.models.test_type import TestType
13-
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
1413

1514
if TYPE_CHECKING:
1615
from collections.abc import Iterator
@@ -818,7 +817,9 @@ def total_passed_runtime(self) -> int:
818817
:return: The runtime in nanoseconds.
819818
"""
820819
# TODO: this sums over every test case; it does not restrict the sum to the intersection of tests present in both baseline and original
821-
return calculate_best_summed_runtime(self.usable_runtime_data_by_test_case())
820+
return sum(
821+
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
822+
)
822823

823824
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
824825
map_gen_test_file_to_no_of_tests = Counter()

codeflash/result/best_summed_runtime.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

codeflash/verification/pytest_plugin.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
STABILITY_SPREAD_TOLERANCE,
2525
STABILITY_WINDOW_SIZE,
2626
)
27-
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
2827

2928
if TYPE_CHECKING:
3029
from _pytest.config import Config, Parser
@@ -292,12 +291,11 @@ def get_runtime_from_stdout(stdout: str) -> Optional[int]:
292291
return None
293292

294293
payload = stdout[start + len(marker_start) : end]
295-
parts = payload.split(":")
296-
if len(parts) != 6:
294+
last_colon = payload.rfind(":")
295+
if last_colon == -1:
297296
return None
298-
299297
try:
300-
return int(parts[5])
298+
return int(payload[last_colon + 1 :])
301299
except ValueError:
302300
return None
303301

@@ -308,17 +306,22 @@ def get_runtime_from_stdout(stdout: str) -> Optional[int]:
308306
def should_stop(
309307
runtimes: list[int],
310308
window: int,
309+
min_window_size: int,
311310
center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
312311
spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
313312
) -> bool:
314313
if len(runtimes) < window:
315314
return False
316315

317-
# runtimes is already sorted descending
316+
if len(runtimes) < min_window_size:
317+
return False
318+
318319
recent = runtimes[-window:]
319320

321+
# Sort a copy of the window so the median and the min/max can be read off by index
322+
recent_sorted = sorted(recent)
320323
mid = window // 2
321-
m = recent[mid] if window % 2 else (recent[mid - 1] + recent[mid]) / 2
324+
m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2
322325

323326
# 1) All recent points close to the median
324327
centered = True
@@ -328,8 +331,9 @@ def should_stop(
328331
break
329332

330333
# 2) Window spread is small
331-
r_max = recent[0]
332-
r_min = recent[-1]
334+
r_min, r_max = recent_sorted[0], recent_sorted[-1]
335+
if r_min == 0:
336+
return False
333337
spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
334338

335339
return centered and spread_ok
@@ -343,7 +347,7 @@ def __init__(self, config: Config) -> None:
343347
level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
344348
logging.basicConfig(level=level)
345349
self.logger = logging.getLogger(self.name)
346-
self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}
350+
self.runtime_data_by_test_case: dict[str, list[int]] = {}
347351
self.enable_stability_check: bool = (
348352
str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
349353
)
@@ -356,7 +360,7 @@ def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
356360
duration_ns = get_runtime_from_stdout(report.capstdout)
357361
if duration_ns:
358362
clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
359-
self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
363+
self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
360364

361365
@hookspec(firstresult=True)
362366
def pytest_runtestloop(self, session: Session) -> bool:
@@ -373,7 +377,7 @@ def pytest_runtestloop(self, session: Session) -> bool:
373377

374378
count: int = 0
375379
runtimes = []
376-
elapsed = 0.0
380+
elapsed_ns = 0
377381

378382
while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one loop for normal tests
379383
count += 1
@@ -396,27 +400,19 @@ def pytest_runtestloop(self, session: Session) -> bool:
396400
raise session.Interrupted(session.shouldstop)
397401

398402
if self.enable_stability_check:
399-
loop_end = _ORIGINAL_PERF_COUNTER_NS()
400-
dt = loop_end - loop_start # nano-seconds
401-
402-
elapsed += dt
403-
404-
best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
403+
elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start
404+
best_runtime_until_now = sum([min(data) for data in self.runtime_data_by_test_case.values()])
405405
if best_runtime_until_now > 0:
406406
runtimes.append(best_runtime_until_now)
407407

408408
estimated_total_loops = 0
409-
if elapsed > 0:
410-
rate = count / elapsed # loops / nano-seconds
409+
if elapsed_ns > 0:
410+
rate = count / elapsed_ns
411411
total_time_ns = total_time * 1e9
412412
estimated_total_loops = int(rate * total_time_ns)
413413

414414
window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5)
415-
if (
416-
count >= session.config.option.codeflash_min_loops
417-
and window_size > 1
418-
and should_stop(runtimes, window_size)
419-
):
415+
if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops):
420416
break
421417

422418
if self._timed_out(session, start_time, count):

0 commit comments

Comments
 (0)