@@ -36,6 +36,7 @@ class ExpectCallMatch:
3636 assertion_chain : str
3737 has_trailing_semicolon : bool
3838 object_prefix : str = "" # Object prefix like "calc." or "this." or ""
39+ this_arg : str = "" # For .call() patterns: the thisArg value
3940
4041
4142@dataclass
@@ -49,6 +50,7 @@ class StandaloneCallMatch:
4950 prefix : str # "await " or ""
5051 object_prefix : str # Object prefix like "calc." or "this." or ""
5152 has_trailing_semicolon : bool
53+ this_arg : str = "" # For .call() patterns: the thisArg value
5254
5355
5456codeflash_import_pattern = re .compile (
@@ -96,6 +98,52 @@ def is_inside_string(code: str, pos: int) -> bool:
9698 return in_string
9799
98100
101+ def split_call_args (args_str : str ) -> tuple [str , str ]:
102+ """Split .call() arguments into (thisArg, remaining_args).
103+
104+ The first argument to .call() is thisArg. Remaining arguments are the
105+ actual function arguments. Handles nested parens, brackets, braces, and
106+ string literals when finding the first top-level comma.
107+
108+ Returns:
109+ Tuple of (thisArg, remaining_args_str). remaining_args_str may be empty.
110+
111+ """
112+ args_str = args_str .strip ()
113+ if not args_str :
114+ return "" , ""
115+
116+ depth = 0
117+ in_string = False
118+ string_char = None
119+ s = args_str
120+ s_len = len (s )
121+
122+ for i in range (s_len ):
123+ char = s [i ]
124+
125+ if char in "\" '`" and (i == 0 or s [i - 1 ] != "\\ " ):
126+ if not in_string :
127+ in_string = True
128+ string_char = char
129+ elif char == string_char :
130+ in_string = False
131+ string_char = None
132+ continue
133+
134+ if in_string :
135+ continue
136+
137+ if char in "([{" :
138+ depth += 1
139+ elif char in ")]}" :
140+ depth -= 1
141+ elif char == "," and depth == 0 :
142+ return s [:i ].strip (), s [i + 1 :].strip ()
143+
144+ return s .strip (), ""
145+
146+
99147class StandaloneCallTransformer :
100148 """Transforms standalone func(...) calls in JavaScript test code.
101149
@@ -127,36 +175,37 @@ def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str)
127175 self ._bracket_call_pattern = re .compile (
128176 rf"(\s*)(await\s+)?(\w+)\[['\"]({ re .escape (self .func_name )} )['\"]]\s*\("
129177 )
178+ # Pattern to match .call() invocation: func_name.call( or obj.func_name.call(
179+ # Captures: (whitespace)(await )?(object.)*func_name.call(
180+ self ._dot_call_pattern = re .compile (rf"(\s*)(await\s+)?((?:\w+\.)*){ re .escape (self .func_name )} \.call\s*\(" )
130181
131182 def transform (self , code : str ) -> str :
132183 """Transform all standalone calls in the code."""
133184 result : list [str ] = []
134185 pos = 0
135186
136187 while pos < len (code ):
137- # Try both dot notation and bracket notation patterns
188+ # Try all patterns: dot notation, bracket notation, and .call() notation
138189 dot_match = self ._call_pattern .search (code , pos )
139190 bracket_match = self ._bracket_call_pattern .search (code , pos )
140-
141- # Choose the first match (by position)
142- match = None
143- is_bracket_notation = False
144- if dot_match and bracket_match :
145- if dot_match .start () <= bracket_match .start ():
146- match = dot_match
147- else :
148- match = bracket_match
149- is_bracket_notation = True
150- elif dot_match :
151- match = dot_match
152- elif bracket_match :
153- match = bracket_match
154- is_bracket_notation = True
155-
156- if not match :
191+ call_match = self ._dot_call_pattern .search (code , pos )
192+
193+ # Choose the earliest match by position
194+ candidates : list [tuple [str , re .Match ]] = []
195+ if dot_match :
196+ candidates .append (("dot" , dot_match ))
197+ if bracket_match :
198+ candidates .append (("bracket" , bracket_match ))
199+ if call_match :
200+ candidates .append (("call" , call_match ))
201+
202+ if not candidates :
157203 result .append (code [pos :])
158204 break
159205
206+ candidates .sort (key = lambda x : x [1 ].start ())
207+ match_type , match = candidates [0 ]
208+
160209 match_start = match .start ()
161210
162211 # Check if this call is inside an expect() or already transformed
@@ -169,7 +218,9 @@ def transform(self, code: str) -> str:
169218 result .append (code [pos :match_start ])
170219
171220 # Try to parse the full standalone call
172- if is_bracket_notation :
221+ if match_type == "call" :
222+ standalone_match = self ._parse_dot_call_standalone (code , match )
223+ elif match_type == "bracket" :
173224 standalone_match = self ._parse_bracket_standalone_call (code , match )
174225 else :
175226 standalone_match = self ._parse_standalone_call (code , match )
@@ -182,7 +233,9 @@ def transform(self, code: str) -> str:
182233
183234 # Generate the transformed code
184235 self .invocation_counter += 1
185- transformed = self ._generate_transformed_call (standalone_match , is_bracket_notation )
236+ transformed = self ._generate_transformed_call (
237+ standalone_match , is_bracket_notation = (match_type == "bracket" ), is_dot_call = (match_type == "call" )
238+ )
186239 result .append (transformed )
187240 pos = standalone_match .end_pos
188241
@@ -394,12 +447,72 @@ def _parse_bracket_standalone_call(self, code: str, match: re.Match) -> Standalo
394447 has_trailing_semicolon = has_trailing_semicolon ,
395448 )
396449
397- def _generate_transformed_call (self , match : StandaloneCallMatch , is_bracket_notation : bool = False ) -> str :
450+ def _parse_dot_call_standalone (self , code : str , match : re .Match ) -> StandaloneCallMatch | None :
451+ """Parse a funcName.call(thisArg, args) or obj.funcName.call(thisArg, args) call."""
452+ leading_ws = match .group (1 )
453+ prefix = match .group (2 ) or "" # "await " or ""
454+ object_prefix = match .group (3 ) or "" # "obj." or ""
455+
456+ # Find the opening paren position
457+ match_text = match .group (0 )
458+ paren_offset = match_text .rfind ("(" )
459+ open_paren_pos = match .start () + paren_offset
460+
461+ # Find all arguments inside .call(...)
462+ all_args , close_pos = self ._find_balanced_parens (code , open_paren_pos )
463+ if all_args is None :
464+ return None
465+
466+ # Split into thisArg and remaining args
467+ this_arg , remaining_args = split_call_args (all_args )
468+ if not this_arg :
469+ return None # .call() with no arguments is invalid
470+
471+ # Check for trailing semicolon
472+ end_pos = close_pos
473+ while end_pos < len (code ) and code [end_pos ] in " \t " :
474+ end_pos += 1
475+ has_trailing_semicolon = end_pos < len (code ) and code [end_pos ] == ";"
476+ if has_trailing_semicolon :
477+ end_pos += 1
478+
479+ return StandaloneCallMatch (
480+ start_pos = match .start (),
481+ end_pos = end_pos ,
482+ leading_whitespace = leading_ws ,
483+ func_args = remaining_args ,
484+ prefix = prefix ,
485+ object_prefix = object_prefix ,
486+ has_trailing_semicolon = has_trailing_semicolon ,
487+ this_arg = this_arg ,
488+ )
489+
490+ def _generate_transformed_call (
491+ self , match : StandaloneCallMatch , is_bracket_notation : bool = False , is_dot_call : bool = False
492+ ) -> str :
398493 """Generate the transformed code for a standalone call."""
399494 line_id = str (self .invocation_counter )
400495 args_str = match .func_args .strip ()
401496 semicolon = ";" if match .has_trailing_semicolon else ""
402497
498+ # Handle .call() pattern: funcName.call(thisArg, args) -> codeflash.capture(..., funcName.bind(thisArg), args)
499+ if is_dot_call :
500+ if match .object_prefix :
501+ obj = match .object_prefix .rstrip ("." )
502+ func_ref = f"{ obj } .{ self .func_name } "
503+ else :
504+ func_ref = self .func_name
505+ bind_expr = f"{ func_ref } .bind({ match .this_arg } )"
506+ if args_str :
507+ return (
508+ f"{ match .leading_whitespace } { match .prefix } codeflash.{ self .capture_func } ('{ self .qualified_name } ', "
509+ f"'{ line_id } ', { bind_expr } , { args_str } ){ semicolon } "
510+ )
511+ return (
512+ f"{ match .leading_whitespace } { match .prefix } codeflash.{ self .capture_func } ('{ self .qualified_name } ', "
513+ f"'{ line_id } ', { bind_expr } ){ semicolon } "
514+ )
515+
403516 # Handle method calls on objects (e.g., calc.fibonacci, this.method, instance['method'])
404517 if match .object_prefix :
405518 # Remove trailing dot from object prefix for the bind call
@@ -481,14 +594,35 @@ def __init__(
481594 # Pattern to match start of expect((object.)*func_name(
482595 # Captures: (whitespace), (object prefix like calc. or this.)
483596 self ._expect_pattern = re .compile (rf"(\s*)expect\s*\(\s*((?:\w+\.)*){ re .escape (self .func_name )} \s*\(" )
597+ # Pattern to match expect((object.)*func_name.call(
598+ self ._expect_dot_call_pattern = re .compile (
599+ rf"(\s*)expect\s*\(\s*((?:\w+\.)*){ re .escape (self .func_name )} \.call\s*\("
600+ )
484601
485602 def transform (self , code : str ) -> str :
486603 """Transform all expect calls in the code."""
487604 result : list [str ] = []
488605 pos = 0
489606
490607 while pos < len (code ):
491- match = self ._expect_pattern .search (code , pos )
608+ expect_match = self ._expect_pattern .search (code , pos )
609+ call_match = self ._expect_dot_call_pattern .search (code , pos )
610+
611+ # Pick the earliest match
612+ match = None
613+ is_dot_call = False
614+ if expect_match and call_match :
615+ if call_match .start () <= expect_match .start ():
616+ match = call_match
617+ is_dot_call = True
618+ else :
619+ match = expect_match
620+ elif expect_match :
621+ match = expect_match
622+ elif call_match :
623+ match = call_match
624+ is_dot_call = True
625+
492626 if not match :
493627 result .append (code [pos :])
494628 break
@@ -503,18 +637,21 @@ def transform(self, code: str) -> str:
503637 result .append (code [pos : match .start ()])
504638
505639 # Try to parse the full expect call
506- expect_match = self ._parse_expect_call (code , match )
507- if expect_match is None :
640+ if is_dot_call :
641+ parsed_match = self ._parse_expect_dot_call (code , match )
642+ else :
643+ parsed_match = self ._parse_expect_call (code , match )
644+ if parsed_match is None :
508645 # Couldn't parse, skip this match
509646 result .append (code [match .start () : match .end ()])
510647 pos = match .end ()
511648 continue
512649
513650 # Generate the transformed code
514651 self .invocation_counter += 1
515- transformed = self ._generate_transformed_call (expect_match )
652+ transformed = self ._generate_transformed_call (parsed_match )
516653 result .append (transformed )
517- pos = expect_match .end_pos
654+ pos = parsed_match .end_pos
518655
519656 return "" .join (result )
520657
@@ -567,6 +704,53 @@ def _parse_expect_call(self, code: str, match: re.Match) -> ExpectCallMatch | No
567704 object_prefix = object_prefix ,
568705 )
569706
707+ def _parse_expect_dot_call (self , code : str , match : re .Match ) -> ExpectCallMatch | None :
708+ """Parse expect(funcName.call(thisArg, args)).assertion()."""
709+ leading_ws = match .group (1 )
710+ object_prefix = match .group (2 ) or ""
711+
712+ if "." not in self .qualified_name and object_prefix :
713+ return None
714+
715+ # Find arguments inside .call(...)
716+ args_start = match .end ()
717+ all_args , call_close_pos = self ._find_balanced_parens (code , args_start - 1 )
718+ if all_args is None :
719+ return None
720+
721+ # Split thisArg from remaining args
722+ this_arg , remaining_args = split_call_args (all_args )
723+ if not this_arg :
724+ return None
725+
726+ # Find closing ) of expect(
727+ expect_close_pos = call_close_pos
728+ while expect_close_pos < len (code ) and code [expect_close_pos ].isspace ():
729+ expect_close_pos += 1
730+ if expect_close_pos >= len (code ) or code [expect_close_pos ] != ")" :
731+ return None
732+ expect_close_pos += 1
733+
734+ # Parse assertion chain
735+ assertion_chain , chain_end_pos = self ._parse_assertion_chain (code , expect_close_pos )
736+ if assertion_chain is None :
737+ return None
738+
739+ has_trailing_semicolon = chain_end_pos < len (code ) and code [chain_end_pos ] == ";"
740+ if has_trailing_semicolon :
741+ chain_end_pos += 1
742+
743+ return ExpectCallMatch (
744+ start_pos = match .start (),
745+ end_pos = chain_end_pos ,
746+ leading_whitespace = leading_ws ,
747+ func_args = remaining_args ,
748+ assertion_chain = assertion_chain ,
749+ has_trailing_semicolon = has_trailing_semicolon ,
750+ object_prefix = object_prefix ,
751+ this_arg = this_arg ,
752+ )
753+
570754 def _find_balanced_parens (self , code : str , open_paren_pos : int ) -> tuple [str | None , int ]:
571755 """Find content within balanced parentheses.
572756
@@ -698,7 +882,14 @@ def _generate_transformed_call(self, match: ExpectCallMatch) -> str:
698882 args_str = match .func_args .strip ()
699883
700884 # Determine the function reference to use
701- if match .object_prefix :
885+ if match .this_arg :
886+ # .call() pattern: funcName.call(thisArg, ...) -> funcName.bind(thisArg)
887+ if match .object_prefix :
888+ obj = match .object_prefix .rstrip ("." )
889+ func_ref = f"{ obj } .{ self .func_name } .bind({ match .this_arg } )"
890+ else :
891+ func_ref = f"{ self .func_name } .bind({ match .this_arg } )"
892+ elif match .object_prefix :
702893 # Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
703894 obj = match .object_prefix .rstrip ("." )
704895 func_ref = f"{ obj } .{ self .func_name } .bind({ obj } )"
@@ -831,6 +1022,11 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
8311022 if re .search (default_import , code ):
8321023 return True
8331024
1025+ # Check for .call() pattern: funcName.call( or obj.funcName.call(
1026+ dot_call_pattern = rf"(?:\w+\.)*{ re .escape (func_name )} \.call\s*\("
1027+ if re .search (dot_call_pattern , code ):
1028+ return True
1029+
8341030 # Check for method calls: obj.funcName( or this.funcName(
8351031 # This handles class methods called on instances
8361032 method_call_pattern = rf"\w+\.{ re .escape (func_name )} \s*\("
0 commit comments