@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684 )
685685
686686
687- def instrument_source_module_with_async_decorators (
688- source_path : Path , function_to_optimize : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
689- ) -> tuple [bool , str | None ]:
690- if not function_to_optimize .is_async :
691- return False , None
692-
693- try :
694- with source_path .open (encoding = "utf8" ) as f :
695- source_code = f .read ()
696-
697- modified_code , decorator_added = add_async_decorator_to_function (source_code , function_to_optimize , mode )
698-
699- if decorator_added :
700- return True , modified_code
701-
702- except Exception :
703- return False , None
704- else :
705- return False , None
706-
707-
708687def inject_async_profiling_into_existing_test (
709688 test_path : Path ,
710689 call_positions : list [CodePosition ],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267
12891268
12901269def add_async_decorator_to_function (
1291- source_code : str , function : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
1292- ) -> tuple [ str , bool ] :
1293- """Add async decorator to an async function definition.
1270+ source_path : Path , function : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
1271+ ) -> bool :
1272+ """Add async decorator to an async function definition and write back to file .
12941273
12951274 Args:
12961275 ----
1297- source_code: The source code to modify.
1276+ source_path: Path to the source file to modify in-place .
12981277 function: The FunctionToOptimize object representing the target async function.
12991278 mode: The testing mode to determine which decorator to apply.
13001279
13011280 Returns:
13021281 -------
1303- Tuple of (modified_source_code, was_decorator_added) .
1282+ Boolean indicating whether the decorator was successfully added .
13041283
13051284 """
13061285 if not function .is_async :
1307- return source_code , False
1286+ return False
13081287
13091288 try :
1289+ # Read source code
1290+ with source_path .open (encoding = "utf8" ) as f :
1291+ source_code = f .read ()
1292+
13101293 module = cst .parse_module (source_code )
13111294
13121295 # Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301 import_transformer = AsyncDecoratorImportAdder (mode )
13191302 module = module .visit (import_transformer )
13201303
1321- return sort_imports (code = module .code , float_to_top = True ), decorator_transformer . added_decorator
1304+ modified_code = sort_imports (code = module .code , float_to_top = True )
13221305 except Exception as e :
13231306 logger .exception (f"Error adding async decorator to function { function .qualified_name } : { e } " )
1324- return source_code , False
1307+ return False
1308+ else :
1309+ if decorator_transformer .added_decorator :
1310+ with source_path .open ("w" , encoding = "utf8" ) as f :
1311+ f .write (modified_code )
1312+ logger .debug (f"Applied async { mode .value } instrumentation to { source_path } " )
1313+ return True
1314+ return False
13251315
13261316
13271317def create_instrumented_source_module_path (source_path : Path , temp_dir : Path ) -> Path :
0 commit comments