@@ -1439,9 +1439,12 @@ def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode
14391439 self .added_decorator = False
14401440
14411441 # Choose decorator based on mode
1442- self .decorator_name = (
1443- "codeflash_behavior_async" if mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
1444- )
1442+ if mode == TestingMode .BEHAVIOR :
1443+ self .decorator_name = "codeflash_behavior_async"
1444+ elif mode == TestingMode .CONCURRENCY :
1445+ self .decorator_name = "codeflash_concurrency_async"
1446+ else :
1447+ self .decorator_name = "codeflash_performance_async"
14451448
14461449 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
14471450 # Track when we enter a class
@@ -1484,12 +1487,14 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca
14841487 "codeflash_trace_async" ,
14851488 "codeflash_behavior_async" ,
14861489 "codeflash_performance_async" ,
1490+ "codeflash_concurrency_async" ,
14871491 }
14881492 if isinstance (decorator_node , cst .Call ) and isinstance (decorator_node .func , cst .Name ):
14891493 return decorator_node .func .value in {
14901494 "codeflash_trace_async" ,
14911495 "codeflash_behavior_async" ,
14921496 "codeflash_performance_async" ,
1497+ "codeflash_concurrency_async" ,
14931498 }
14941499 return False
14951500
@@ -1501,6 +1506,14 @@ def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
15011506 self .mode = mode
15021507 self .has_import = False
15031508
1509+ def _get_decorator_name (self ) -> str :
1510+ """Get the decorator name based on the testing mode."""
1511+ if self .mode == TestingMode .BEHAVIOR :
1512+ return "codeflash_behavior_async"
1513+ if self .mode == TestingMode .CONCURRENCY :
1514+ return "codeflash_concurrency_async"
1515+ return "codeflash_performance_async"
1516+
15041517 def visit_ImportFrom (self , node : cst .ImportFrom ) -> None :
15051518 # Check if the async decorator import is already present
15061519 if (
@@ -1512,9 +1525,7 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
15121525 and node .module .attr .value == "codeflash_wrap_decorator"
15131526 and not isinstance (node .names , cst .ImportStar )
15141527 ):
1515- decorator_name = (
1516- "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
1517- )
1528+ decorator_name = self ._get_decorator_name ()
15181529 for import_alias in node .names :
15191530 if import_alias .name .value == decorator_name :
15201531 self .has_import = True
@@ -1525,9 +1536,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
15251536 return updated_node
15261537
15271538 # Choose import based on mode
1528- decorator_name = (
1529- "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
1530- )
1539+ decorator_name = self ._get_decorator_name ()
15311540
15321541 # Parse the import statement into a CST node
15331542 import_node = cst .parse_statement (f"from codeflash.code_utils.codeflash_wrap_decorator import { decorator_name } " )
0 commit comments