Skip to content

Commit 0da87cf

Browse files
Optimize get_async_inline_code
The optimization achieves a **32% runtime improvement** by eliminating redundant work on every function call through two key changes:

## Primary Optimization: Module-Level Constants + Caching

**What changed:**

1. **Module-level string constants**: The large inline code strings (1000+ characters each) are now defined once as module-level constants (`_BEHAVIOR_ASYNC_INLINE_CODE`, `_PERFORMANCE_ASYNC_INLINE_CODE`, `_CONCURRENCY_ASYNC_INLINE_CODE`) instead of being produced by a separate per-mode helper function on every call.
2. **Cached dispatcher with dictionary lookup**: The `get_async_inline_code()` function is decorated with `@cache` and uses a pre-built dictionary (`_INLINE_CODE_MAP`) for O(1) mode lookups, replacing the sequential if-statement chain.

**Why this is faster:**

- **Eliminates the per-call helper invocation**: In the original code, every call to `get_async_inline_code()` dispatched to a second function (`get_behavior_async_inline_code()`, `get_concurrency_async_inline_code()`, or `get_performance_async_inline_code()`) just to obtain the multi-line string. Note that, assuming those helpers return plain string literals, CPython stores each literal as a constant on the function's code object, so no new string object is allocated per call; the saving comes from skipping the extra function call, not from avoiding string construction. The optimized version references the same string object stored at module initialization.
- **Reduces CPU instruction count**: The original sequential if-checks required evaluating up to 2 enum comparisons per call. The optimized dictionary lookup is a single hash table access (~O(1)), and with `@cache` memoization, subsequent calls with the same `TestingMode` return the cached result without re-executing the function body at all.
- **Caching multiplier effect**: The `@cache` decorator means the first call with each `TestingMode` performs the `_INLINE_CODE_MAP` lookup once; all subsequent calls with that mode are answered directly from the cache, which is itself a single lookup keyed on the argument.

**How this impacts real workloads:**

Based on the `function_references`, `get_async_inline_code()` is called during test instrumentation in hot paths like `test_async_bubble_sort_behavior_results()`, `test_async_function_performance_mode()`, and `test_async_function_error_handling()`. These test setup functions likely run many times during development and in CI/CD pipelines. The optimization means:

- **Test instrumentation is faster**: Setting up async decorators for behavior/performance testing completes 32% faster, reducing overall test suite setup time.
- **Scales with test volume**: The annotated tests show that the improvement compounds in loops: `test_mass_compilation_of_generated_codes_varied_modes` runs 38.6% faster (329μs → 237μs) when calling the function 1000 times.
- **Best for repeated mode access**: Tests that call the same mode multiple times benefit most from caching (e.g., `test_get_async_inline_code_called_multiple_times_performance` shows a 44.1% speedup for 100 calls).

The optimization trades a negligible increase in module initialization time and memory (storing three strings at module level) for a substantial per-call speedup, making it particularly effective for test instrumentation workflows that repeatedly access the same testing mode configurations.
1 parent 9f80ea6 commit 0da87cf

1 file changed

Lines changed: 199 additions & 5 deletions

File tree

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 199 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,208 @@
1212
from codeflash.code_utils.formatter import sort_imports
1313
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1414
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
15+
from functools import cache
1516

1617
if TYPE_CHECKING:
1718
from collections.abc import Iterable
1819

1920
from codeflash.models.models import CodePosition
2021

22+
_BEHAVIOR_ASYNC_INLINE_CODE = """import asyncio
23+
import gc
24+
import os
25+
import sqlite3
26+
from functools import wraps
27+
from pathlib import Path
28+
from tempfile import TemporaryDirectory
29+
30+
import dill as pickle
31+
32+
33+
def get_run_tmp_file(file_path):
34+
if not hasattr(get_run_tmp_file, "tmpdir"):
35+
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
36+
return Path(get_run_tmp_file.tmpdir.name) / file_path
37+
38+
39+
def extract_test_context_from_env():
40+
test_module = os.environ["CODEFLASH_TEST_MODULE"]
41+
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
42+
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
43+
if test_module and test_function:
44+
return (test_module, test_class if test_class else None, test_function)
45+
raise RuntimeError(
46+
"Test context environment variables not set - ensure tests are run through codeflash test runner"
47+
)
48+
49+
50+
def codeflash_behavior_async(func):
51+
@wraps(func)
52+
async def async_wrapper(*args, **kwargs):
53+
loop = asyncio.get_running_loop()
54+
function_name = func.__name__
55+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
56+
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
57+
test_module_name, test_class_name, test_name = extract_test_context_from_env()
58+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
59+
if not hasattr(async_wrapper, "index"):
60+
async_wrapper.index = {}
61+
if test_id in async_wrapper.index:
62+
async_wrapper.index[test_id] += 1
63+
else:
64+
async_wrapper.index[test_id] = 0
65+
codeflash_test_index = async_wrapper.index[test_id]
66+
invocation_id = f"{line_id}_{codeflash_test_index}"
67+
class_prefix = (test_class_name + ".") if test_class_name else ""
68+
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
69+
print(f"!$######{test_stdout_tag}######$!")
70+
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
71+
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
72+
codeflash_con = sqlite3.connect(db_path)
73+
codeflash_cur = codeflash_con.cursor()
74+
codeflash_cur.execute(
75+
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
76+
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
77+
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
78+
)
79+
exception = None
80+
counter = loop.time()
81+
gc.disable()
82+
try:
83+
ret = func(*args, **kwargs)
84+
counter = loop.time()
85+
return_value = await ret
86+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
87+
except Exception as e:
88+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
89+
exception = e
90+
finally:
91+
gc.enable()
92+
print(f"!######{test_stdout_tag}######!")
93+
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
94+
codeflash_cur.execute(
95+
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
96+
(
97+
test_module_name,
98+
test_class_name,
99+
test_name,
100+
function_name,
101+
loop_index,
102+
invocation_id,
103+
codeflash_duration,
104+
pickled_return_value,
105+
"function_call",
106+
),
107+
)
108+
codeflash_con.commit()
109+
codeflash_con.close()
110+
if exception:
111+
raise exception
112+
return return_value
113+
return async_wrapper
114+
"""
115+
116+
_PERFORMANCE_ASYNC_INLINE_CODE = """import asyncio
117+
import gc
118+
import os
119+
from functools import wraps
120+
121+
122+
def extract_test_context_from_env():
123+
test_module = os.environ["CODEFLASH_TEST_MODULE"]
124+
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
125+
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
126+
if test_module and test_function:
127+
return (test_module, test_class if test_class else None, test_function)
128+
raise RuntimeError(
129+
"Test context environment variables not set - ensure tests are run through codeflash test runner"
130+
)
131+
132+
133+
def codeflash_performance_async(func):
134+
@wraps(func)
135+
async def async_wrapper(*args, **kwargs):
136+
loop = asyncio.get_running_loop()
137+
function_name = func.__name__
138+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
139+
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
140+
test_module_name, test_class_name, test_name = extract_test_context_from_env()
141+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
142+
if not hasattr(async_wrapper, "index"):
143+
async_wrapper.index = {}
144+
if test_id in async_wrapper.index:
145+
async_wrapper.index[test_id] += 1
146+
else:
147+
async_wrapper.index[test_id] = 0
148+
codeflash_test_index = async_wrapper.index[test_id]
149+
invocation_id = f"{line_id}_{codeflash_test_index}"
150+
class_prefix = (test_class_name + ".") if test_class_name else ""
151+
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
152+
print(f"!$######{test_stdout_tag}######$!")
153+
exception = None
154+
counter = loop.time()
155+
gc.disable()
156+
try:
157+
ret = func(*args, **kwargs)
158+
counter = loop.time()
159+
return_value = await ret
160+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
161+
except Exception as e:
162+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
163+
exception = e
164+
finally:
165+
gc.enable()
166+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
167+
if exception:
168+
raise exception
169+
return return_value
170+
return async_wrapper
171+
"""
172+
173+
_CONCURRENCY_ASYNC_INLINE_CODE = """import asyncio
174+
import gc
175+
import os
176+
import time
177+
from functools import wraps
178+
179+
180+
def codeflash_concurrency_async(func):
181+
@wraps(func)
182+
async def async_wrapper(*args, **kwargs):
183+
function_name = func.__name__
184+
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
185+
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
186+
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
187+
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
188+
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
189+
gc.disable()
190+
try:
191+
seq_start = time.perf_counter_ns()
192+
for _ in range(concurrency_factor):
193+
result = await func(*args, **kwargs)
194+
sequential_time = time.perf_counter_ns() - seq_start
195+
finally:
196+
gc.enable()
197+
gc.disable()
198+
try:
199+
conc_start = time.perf_counter_ns()
200+
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
201+
await asyncio.gather(*tasks)
202+
concurrent_time = time.perf_counter_ns() - conc_start
203+
finally:
204+
gc.enable()
205+
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
206+
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
207+
return result
208+
return async_wrapper
209+
"""
210+
211+
_INLINE_CODE_MAP = {
212+
TestingMode.BEHAVIOR: _BEHAVIOR_ASYNC_INLINE_CODE,
213+
TestingMode.PERFORMANCE: _PERFORMANCE_ASYNC_INLINE_CODE,
214+
TestingMode.CONCURRENCY: _CONCURRENCY_ASYNC_INLINE_CODE,
215+
}
216+
21217

22218
@dataclass(frozen=True)
23219
class FunctionCallNodeArguments:
@@ -1692,12 +1888,10 @@ async def async_wrapper(*args, **kwargs):
16921888
"""
16931889

16941890

1891+
@cache
16951892
def get_async_inline_code(mode: TestingMode) -> str:
1696-
if mode == TestingMode.BEHAVIOR:
1697-
return get_behavior_async_inline_code()
1698-
if mode == TestingMode.CONCURRENCY:
1699-
return get_concurrency_async_inline_code()
1700-
return get_performance_async_inline_code()
1893+
# Return the inline code for the requested mode. Default to performance mode if not matched.
1894+
return _INLINE_CODE_MAP.get(mode, _PERFORMANCE_ASYNC_INLINE_CODE)
17011895

17021896

17031897
class AsyncInlineCodeInjector(cst.CSTTransformer):

0 commit comments

Comments
 (0)