88import pytest
99
1010from codeflash .code_utils .instrument_existing_tests import (
11+ ASYNC_HELPER_FILENAME ,
1112 add_async_decorator_to_function ,
12- get_async_inline_code ,
13+ get_decorator_name_for_mode ,
1314 inject_profiling_into_existing_test ,
1415)
1516from codeflash .discovery .functions_to_optimize import FunctionToOptimize
@@ -56,19 +57,22 @@ async def test_async_sort():
5657 func = FunctionToOptimize (function_name = "async_sorter" , parents = [], file_path = Path (fto_path ), is_async = True )
5758
5859 # For async functions, instrument the source module directly with decorators
59- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .BEHAVIOR )
60+ source_success = add_async_decorator_to_function (
61+ fto_path , func , TestingMode .BEHAVIOR , project_root = project_root_path
62+ )
6063
6164 assert source_success
6265
6366 # Verify the file was modified with exact expected output
6467 instrumented_source = fto_path .read_text ("utf-8" )
6568 from codeflash .code_utils .formatter import sort_imports
6669
67- inline_code = get_async_inline_code (TestingMode .BEHAVIOR )
70+ decorator_name = get_decorator_name_for_mode (TestingMode .BEHAVIOR )
6871 decorated_original = original_code .replace (
69- "async def async_sorter" , "@codeflash_behavior_async \n async def async_sorter"
72+ "async def async_sorter" , f"@ { decorator_name } \n async def async_sorter"
7073 )
71- expected = sort_imports (code = inline_code + decorated_original , float_to_top = True )
74+ code_with_import = f"from codeflash_async_wrapper import { decorator_name } \n { decorated_original } "
75+ expected = sort_imports (code = code_with_import , float_to_top = True )
7276 assert instrumented_source .strip () == expected .strip ()
7377
7478 # Add codeflash capture
@@ -147,6 +151,9 @@ async def test_async_sort():
147151 test_path .unlink ()
148152 if test_path_perf .exists ():
149153 test_path_perf .unlink ()
154+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
155+ if helper_path .exists ():
156+ helper_path .unlink ()
150157
151158
152159@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -187,7 +194,9 @@ async def test_async_class_sort():
187194 is_async = True ,
188195 )
189196
190- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .BEHAVIOR )
197+ source_success = add_async_decorator_to_function (
198+ fto_path , func , TestingMode .BEHAVIOR , project_root = project_root_path
199+ )
191200
192201 assert source_success
193202
@@ -269,6 +278,9 @@ async def test_async_class_sort():
269278 test_path .unlink ()
270279 if test_path_perf .exists ():
271280 test_path_perf .unlink ()
281+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
282+ if helper_path .exists ():
283+ helper_path .unlink ()
272284
273285
274286@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -299,19 +311,22 @@ async def test_async_perf():
299311 func = FunctionToOptimize (function_name = "async_sorter" , parents = [], file_path = Path (fto_path ), is_async = True )
300312
301313 # Instrument the source module with async performance decorators
302- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .PERFORMANCE )
314+ source_success = add_async_decorator_to_function (
315+ fto_path , func , TestingMode .PERFORMANCE , project_root = project_root_path
316+ )
303317
304318 assert source_success
305319
306320 # Verify the file was modified
307321 instrumented_source = fto_path .read_text ("utf-8" )
308322 from codeflash .code_utils .formatter import sort_imports
309323
310- inline_code = get_async_inline_code (TestingMode .PERFORMANCE )
324+ decorator_name = get_decorator_name_for_mode (TestingMode .PERFORMANCE )
311325 decorated_original = original_code .replace (
312- "async def async_sorter" , "@codeflash_performance_async \n async def async_sorter"
326+ "async def async_sorter" , f"@ { decorator_name } \n async def async_sorter"
313327 )
314- expected = sort_imports (code = inline_code + decorated_original , float_to_top = True )
328+ code_with_import = f"from codeflash_async_wrapper import { decorator_name } \n { decorated_original } "
329+ expected = sort_imports (code = code_with_import , float_to_top = True )
315330 assert instrumented_source .strip () == expected .strip ()
316331
317332 instrument_codeflash_capture (func , {}, tests_root )
@@ -368,6 +383,9 @@ async def test_async_perf():
368383 # Clean up test files
369384 if test_path .exists ():
370385 test_path .unlink ()
386+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
387+ if helper_path .exists ():
388+ helper_path .unlink ()
371389
372390
373391@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -413,7 +431,9 @@ async def async_error_function(lst):
413431 function_name = "async_error_function" , parents = [], file_path = Path (fto_path ), is_async = True
414432 )
415433
416- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .BEHAVIOR )
434+ source_success = add_async_decorator_to_function (
435+ fto_path , func , TestingMode .BEHAVIOR , project_root = project_root_path
436+ )
417437
418438 assert source_success
419439
@@ -422,11 +442,12 @@ async def async_error_function(lst):
422442
423443 from codeflash .code_utils .formatter import sort_imports
424444
425- inline_code = get_async_inline_code (TestingMode .BEHAVIOR )
445+ decorator_name = get_decorator_name_for_mode (TestingMode .BEHAVIOR )
426446 decorated_modified = modified_code .replace (
427- "async def async_error_function" , "@codeflash_behavior_async \n async def async_error_function"
447+ "async def async_error_function" , f"@ { decorator_name } \n async def async_error_function"
428448 )
429- expected = sort_imports (code = inline_code + decorated_modified , float_to_top = True )
449+ code_with_import = f"from codeflash_async_wrapper import { decorator_name } \n { decorated_modified } "
450+ expected = sort_imports (code = code_with_import , float_to_top = True )
430451 assert instrumented_source .strip () == expected .strip ()
431452 instrument_codeflash_capture (func , {}, tests_root )
432453
@@ -488,6 +509,9 @@ async def async_error_function(lst):
488509 test_path .unlink ()
489510 if test_path_perf .exists ():
490511 test_path_perf .unlink ()
512+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
513+ if helper_path .exists ():
514+ helper_path .unlink ()
491515
492516
493517@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -525,7 +549,9 @@ async def test_async_multi():
525549
526550 func = FunctionToOptimize (function_name = "async_sorter" , parents = [], file_path = Path (fto_path ), is_async = True )
527551
528- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .BEHAVIOR )
552+ source_success = add_async_decorator_to_function (
553+ fto_path , func , TestingMode .BEHAVIOR , project_root = project_root_path
554+ )
529555
530556 assert source_success
531557 instrument_codeflash_capture (func , {}, tests_root )
@@ -598,6 +624,9 @@ async def test_async_multi():
598624 test_path .unlink ()
599625 if test_path_perf .exists ():
600626 test_path_perf .unlink ()
627+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
628+ if helper_path .exists ():
629+ helper_path .unlink ()
601630
602631
603632@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -640,7 +669,9 @@ async def test_async_edge_cases():
640669
641670 func = FunctionToOptimize (function_name = "async_sorter" , parents = [], file_path = Path (fto_path ), is_async = True )
642671
643- source_success = add_async_decorator_to_function (fto_path , func , TestingMode .BEHAVIOR )
672+ source_success = add_async_decorator_to_function (
673+ fto_path , func , TestingMode .BEHAVIOR , project_root = project_root_path
674+ )
644675
645676 assert source_success
646677 instrument_codeflash_capture (func , {}, tests_root )
@@ -715,6 +746,9 @@ async def test_async_edge_cases():
715746 test_path .unlink ()
716747 if test_path_perf .exists ():
717748 test_path_perf .unlink ()
749+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
750+ if helper_path .exists ():
751+ helper_path .unlink ()
718752
719753
720754@pytest .mark .skipif (sys .platform == "win32" , reason = "pending support for asyncio on windows" )
@@ -949,7 +983,9 @@ async def test_mixed_sorting():
949983 function_name = "async_merge_sort" , parents = [], file_path = Path (mixed_fto_path ), is_async = True
950984 )
951985
952- source_success = add_async_decorator_to_function (mixed_fto_path , async_func , TestingMode .BEHAVIOR )
986+ source_success = add_async_decorator_to_function (
987+ mixed_fto_path , async_func , TestingMode .BEHAVIOR , project_root = project_root_path
988+ )
953989
954990 assert source_success
955991
@@ -1022,3 +1058,6 @@ async def test_mixed_sorting():
10221058 test_path .unlink ()
10231059 if test_path_perf .exists ():
10241060 test_path_perf .unlink ()
1061+ helper_path = project_root_path / ASYNC_HELPER_FILENAME
1062+ if helper_path .exists ():
1063+ helper_path .unlink ()
0 commit comments