Skip to content

Commit 74c29b2

Browse files
committed
fix: update tests for multi-round benchmark plugin
The benchmark plugin now runs multiple rounds with calibrated iterations. Tests need SELECT DISTINCT for row counts and must extract median_ns from BenchmarkStats before validation.
1 parent 7005fa0 commit 74c29b2

3 files changed

Lines changed: 20 additions & 16 deletions

File tree

codeflash/benchmarking/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
def validate_and_format_benchmark_table(
19-
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
19+
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
2020
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
2121
function_to_result = {}
2222
# Process each function's benchmark data
@@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
7777

7878
def process_benchmark_data(
7979
replay_performance_gain: dict[BenchmarkKey, float],
80-
fto_benchmark_timings: dict[BenchmarkKey, int],
81-
total_benchmark_timings: dict[BenchmarkKey, int],
80+
fto_benchmark_timings: dict[BenchmarkKey, float],
81+
total_benchmark_timings: dict[BenchmarkKey, float],
8282
) -> Optional[ProcessedBenchmarkInfo]:
8383
"""Process benchmark data and generate detailed benchmark information.
8484

tests/test_pickle_patcher.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
253253
cursor = conn.cursor()
254254

255255
cursor.execute(
256-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
256+
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
257257
)
258258
function_calls = cursor.fetchall()
259259

260260
# Assert the length of function calls
261261
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
262262
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
263-
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
263+
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
264+
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
264265
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
265266
assert (
266267
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
@@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
401402
pytest_max_loops=1,
402403
testing_time=1.0,
403404
)
404-
assert len(test_results_unused_socket) == 1
405+
assert len(test_results_unused_socket) >= 1
405406
assert (
406407
test_results_unused_socket.test_results[0].id.test_module_path
407408
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
410411
test_results_unused_socket.test_results[0].id.test_function_name
411412
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
412413
)
413-
assert test_results_unused_socket.test_results[0].did_pass == True
414+
assert test_results_unused_socket.test_results[0].did_pass is True
414415

415416
# Replace with optimized candidate
416417
fto_unused_socket_path.write_text("""
@@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
432433
pytest_max_loops=1,
433434
testing_time=1.0,
434435
)
435-
assert len(optimized_test_results_unused_socket) == 1
436+
assert len(optimized_test_results_unused_socket) >= 1
436437
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
437438
assert match
438439

@@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
487488
pytest_max_loops=1,
488489
testing_time=1.0,
489490
)
490-
assert len(test_results_used_socket) == 1
491+
assert len(test_results_used_socket) >= 1
491492
assert (
492493
test_results_used_socket.test_results[0].id.test_module_path
493494
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
522523
pytest_max_loops=1,
523524
testing_time=1.0,
524525
)
525-
assert len(test_results_used_socket) == 1
526+
assert len(test_results_used_socket) >= 1
526527
assert (
527528
test_results_used_socket.test_results[0].id.test_module_path
528529
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"

tests/test_trace_benchmarks.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
2929
# Get the count of records
3030
# Get all records
3131
cursor.execute(
32-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
32+
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
3333
)
3434
function_calls = cursor.fetchall()
3535

@@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
220220
if conn is not None:
221221
conn.close()
222222
output_file.unlink(missing_ok=True)
223-
shutil.rmtree(replay_tests_dir)
223+
if replay_tests_dir.exists():
224+
shutil.rmtree(replay_tests_dir)
224225

225226

226227
# Skip the test in CI as the machine may not be multithreaded
@@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
242243
# Get the count of records
243244
# Get all records
244245
cursor.execute(
245-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
246+
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
246247
)
247248
function_calls = cursor.fetchall()
248249

249250
# Assert the length of function calls
250251
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
251252
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
252-
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
253+
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
254+
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
253255
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
254256
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
255257

@@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None:
304306
# Get the count of records
305307
# Get all records
306308
cursor.execute(
307-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
309+
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
308310
)
309311
function_calls = cursor.fetchall()
310312

311313
# Assert the length of function calls
312314
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
313315
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
314-
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
316+
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
317+
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
315318
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
316319
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
317320

0 commit comments

Comments
 (0)