Skip to content

Commit e086121

Browse files
committed
fix: clean up async helper file and combine all decorators into single file
Write all three async decorator implementations into one helper file to avoid overwrite issues when switching modes. Clean up the helper file in revert_code_and_helpers and early-exit paths so it doesn't persist in the user's project root after optimization.
1 parent 64a18c9 commit e086121

2 files changed

Lines changed: 15 additions & 46 deletions

File tree

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,11 +1497,11 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca
14971497
return False
14981498

14991499

1500-
def get_behavior_async_inline_code() -> str:
1501-
return """import asyncio
1500+
ASYNC_HELPER_INLINE_CODE = """import asyncio
15021501
import gc
15031502
import os
15041503
import sqlite3
1504+
import time
15051505
from functools import wraps
15061506
from pathlib import Path
15071507
from tempfile import TemporaryDirectory
@@ -1590,25 +1590,6 @@ async def async_wrapper(*args, **kwargs):
15901590
raise exception
15911591
return return_value
15921592
return async_wrapper
1593-
"""
1594-
1595-
1596-
def get_performance_async_inline_code() -> str:
1597-
return """import asyncio
1598-
import gc
1599-
import os
1600-
from functools import wraps
1601-
1602-
1603-
def extract_test_context_from_env():
1604-
test_module = os.environ["CODEFLASH_TEST_MODULE"]
1605-
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
1606-
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
1607-
if test_module and test_function:
1608-
return (test_module, test_class if test_class else None, test_function)
1609-
raise RuntimeError(
1610-
"Test context environment variables not set - ensure tests are run through codeflash test runner"
1611-
)
16121593
16131594
16141595
def codeflash_performance_async(func):
@@ -1649,15 +1630,6 @@ async def async_wrapper(*args, **kwargs):
16491630
raise exception
16501631
return return_value
16511632
return async_wrapper
1652-
"""
1653-
1654-
1655-
def get_concurrency_async_inline_code() -> str:
1656-
return """import asyncio
1657-
import gc
1658-
import os
1659-
import time
1660-
from functools import wraps
16611633
16621634
16631635
def codeflash_concurrency_async(func):
@@ -1691,15 +1663,6 @@ async def async_wrapper(*args, **kwargs):
16911663
return async_wrapper
16921664
"""
16931665

1694-
1695-
def get_async_inline_code(mode: TestingMode) -> str:
1696-
if mode == TestingMode.BEHAVIOR:
1697-
return get_behavior_async_inline_code()
1698-
if mode == TestingMode.CONCURRENCY:
1699-
return get_concurrency_async_inline_code()
1700-
return get_performance_async_inline_code()
1701-
1702-
17031666
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
17041667

17051668

@@ -1711,14 +1674,11 @@ def get_decorator_name_for_mode(mode: TestingMode) -> str:
17111674
return "codeflash_performance_async"
17121675

17131676

1714-
def write_async_helper_file(target_dir: Path, mode: TestingMode) -> Path:
1677+
def write_async_helper_file(target_dir: Path) -> Path:
17151678
"""Write the async decorator helper file to the target directory."""
17161679
helper_path = target_dir / ASYNC_HELPER_FILENAME
1717-
if helper_path.exists():
1718-
decorator_name = get_decorator_name_for_mode(mode)
1719-
if f"def {decorator_name}" in helper_path.read_text("utf-8"):
1720-
return helper_path
1721-
helper_path.write_text(get_async_inline_code(mode), "utf-8")
1680+
if not helper_path.exists():
1681+
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
17221682
return helper_path
17231683

17241684

@@ -1750,7 +1710,7 @@ def add_async_decorator_to_function(
17501710
if decorator_transformer.added_decorator:
17511711
# Write the helper file to project_root (on sys.path) or source dir as fallback
17521712
helper_dir = project_root if project_root is not None else source_path.parent
1753-
write_async_helper_file(helper_dir, mode)
1713+
write_async_helper_file(helper_dir)
17541714
# Add the import via CST so sort_imports can place it correctly
17551715
decorator_name = get_decorator_name_for_mode(mode)
17561716
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")

codeflash/optimization/function_optimizer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,7 @@ def setup_and_establish_baseline(
18971897
if self.args.override_fixtures:
18981898
restore_conftest(original_conftest_content)
18991899
cleanup_paths(paths_to_cleanup)
1900+
self.cleanup_async_helper_file()
19001901
return Failure(baseline_result.failure())
19011902

19021903
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
@@ -1908,6 +1909,7 @@ def setup_and_establish_baseline(
19081909
if self.args.override_fixtures:
19091910
restore_conftest(original_conftest_content)
19101911
cleanup_paths(paths_to_cleanup)
1912+
self.cleanup_async_helper_file()
19111913
return Failure("The threshold for test confidence was not met.")
19121914

19131915
return Success(
@@ -2279,6 +2281,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None
22792281
self.write_code_and_helpers(
22802282
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
22812283
)
2284+
self.cleanup_async_helper_file()
2285+
2286+
def cleanup_async_helper_file(self) -> None:
2287+
from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME
2288+
2289+
helper_path = self.project_root / ASYNC_HELPER_FILENAME
2290+
helper_path.unlink(missing_ok=True)
22822291

22832292
def establish_original_code_baseline(
22842293
self,

0 commit comments

Comments
 (0)