Skip to content

Commit 3aa1c0e

Browse files
feat: add .call() pattern support to JS test instrumentation
Instrument funcName.call(thisArg, args) patterns in both standalone and expect-wrapped contexts, transforming them to codeflash.capture() with func.bind(thisArg). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8cb7209 commit 3aa1c0e

2 files changed

Lines changed: 605 additions & 28 deletions

File tree

codeflash/languages/javascript/instrument.py

Lines changed: 223 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class ExpectCallMatch:
3636
assertion_chain: str
3737
has_trailing_semicolon: bool
3838
object_prefix: str = "" # Object prefix like "calc." or "this." or ""
39+
this_arg: str = "" # For .call() patterns: the thisArg value
3940

4041

4142
@dataclass
@@ -49,6 +50,7 @@ class StandaloneCallMatch:
4950
prefix: str # "await " or ""
5051
object_prefix: str # Object prefix like "calc." or "this." or ""
5152
has_trailing_semicolon: bool
53+
this_arg: str = "" # For .call() patterns: the thisArg value
5254

5355

5456
codeflash_import_pattern = re.compile(
@@ -96,6 +98,52 @@ def is_inside_string(code: str, pos: int) -> bool:
9698
return in_string
9799

98100

101+
def split_call_args(args_str: str) -> tuple[str, str]:
102+
"""Split .call() arguments into (thisArg, remaining_args).
103+
104+
The first argument to .call() is thisArg. Remaining arguments are the
105+
actual function arguments. Handles nested parens, brackets, braces, and
106+
string literals when finding the first top-level comma.
107+
108+
Returns:
109+
Tuple of (thisArg, remaining_args_str). remaining_args_str may be empty.
110+
111+
"""
112+
args_str = args_str.strip()
113+
if not args_str:
114+
return "", ""
115+
116+
depth = 0
117+
in_string = False
118+
string_char = None
119+
s = args_str
120+
s_len = len(s)
121+
122+
for i in range(s_len):
123+
char = s[i]
124+
125+
if char in "\"'`" and (i == 0 or s[i - 1] != "\\"):
126+
if not in_string:
127+
in_string = True
128+
string_char = char
129+
elif char == string_char:
130+
in_string = False
131+
string_char = None
132+
continue
133+
134+
if in_string:
135+
continue
136+
137+
if char in "([{":
138+
depth += 1
139+
elif char in ")]}":
140+
depth -= 1
141+
elif char == "," and depth == 0:
142+
return s[:i].strip(), s[i + 1 :].strip()
143+
144+
return s.strip(), ""
145+
146+
99147
class StandaloneCallTransformer:
100148
"""Transforms standalone func(...) calls in JavaScript test code.
101149
@@ -127,36 +175,37 @@ def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str)
127175
self._bracket_call_pattern = re.compile(
128176
rf"(\s*)(await\s+)?(\w+)\[['\"]({re.escape(self.func_name)})['\"]]\s*\("
129177
)
178+
# Pattern to match .call() invocation: func_name.call( or obj.func_name.call(
179+
# Captures: (whitespace)(await )?(object.)*func_name.call(
180+
self._dot_call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(self.func_name)}\.call\s*\(")
130181

131182
def transform(self, code: str) -> str:
132183
"""Transform all standalone calls in the code."""
133184
result: list[str] = []
134185
pos = 0
135186

136187
while pos < len(code):
137-
# Try both dot notation and bracket notation patterns
188+
# Try all patterns: dot notation, bracket notation, and .call() notation
138189
dot_match = self._call_pattern.search(code, pos)
139190
bracket_match = self._bracket_call_pattern.search(code, pos)
140-
141-
# Choose the first match (by position)
142-
match = None
143-
is_bracket_notation = False
144-
if dot_match and bracket_match:
145-
if dot_match.start() <= bracket_match.start():
146-
match = dot_match
147-
else:
148-
match = bracket_match
149-
is_bracket_notation = True
150-
elif dot_match:
151-
match = dot_match
152-
elif bracket_match:
153-
match = bracket_match
154-
is_bracket_notation = True
155-
156-
if not match:
191+
call_match = self._dot_call_pattern.search(code, pos)
192+
193+
# Choose the earliest match by position
194+
candidates: list[tuple[str, re.Match]] = []
195+
if dot_match:
196+
candidates.append(("dot", dot_match))
197+
if bracket_match:
198+
candidates.append(("bracket", bracket_match))
199+
if call_match:
200+
candidates.append(("call", call_match))
201+
202+
if not candidates:
157203
result.append(code[pos:])
158204
break
159205

206+
candidates.sort(key=lambda x: x[1].start())
207+
match_type, match = candidates[0]
208+
160209
match_start = match.start()
161210

162211
# Check if this call is inside an expect() or already transformed
@@ -169,7 +218,9 @@ def transform(self, code: str) -> str:
169218
result.append(code[pos:match_start])
170219

171220
# Try to parse the full standalone call
172-
if is_bracket_notation:
221+
if match_type == "call":
222+
standalone_match = self._parse_dot_call_standalone(code, match)
223+
elif match_type == "bracket":
173224
standalone_match = self._parse_bracket_standalone_call(code, match)
174225
else:
175226
standalone_match = self._parse_standalone_call(code, match)
@@ -182,7 +233,9 @@ def transform(self, code: str) -> str:
182233

183234
# Generate the transformed code
184235
self.invocation_counter += 1
185-
transformed = self._generate_transformed_call(standalone_match, is_bracket_notation)
236+
transformed = self._generate_transformed_call(
237+
standalone_match, is_bracket_notation=(match_type == "bracket"), is_dot_call=(match_type == "call")
238+
)
186239
result.append(transformed)
187240
pos = standalone_match.end_pos
188241

@@ -394,12 +447,72 @@ def _parse_bracket_standalone_call(self, code: str, match: re.Match) -> Standalo
394447
has_trailing_semicolon=has_trailing_semicolon,
395448
)
396449

397-
def _generate_transformed_call(self, match: StandaloneCallMatch, is_bracket_notation: bool = False) -> str:
450+
def _parse_dot_call_standalone(self, code: str, match: re.Match) -> StandaloneCallMatch | None:
451+
"""Parse a funcName.call(thisArg, args) or obj.funcName.call(thisArg, args) call."""
452+
leading_ws = match.group(1)
453+
prefix = match.group(2) or "" # "await " or ""
454+
object_prefix = match.group(3) or "" # "obj." or ""
455+
456+
# Find the opening paren position
457+
match_text = match.group(0)
458+
paren_offset = match_text.rfind("(")
459+
open_paren_pos = match.start() + paren_offset
460+
461+
# Find all arguments inside .call(...)
462+
all_args, close_pos = self._find_balanced_parens(code, open_paren_pos)
463+
if all_args is None:
464+
return None
465+
466+
# Split into thisArg and remaining args
467+
this_arg, remaining_args = split_call_args(all_args)
468+
if not this_arg:
469+
return None # .call() with no arguments is invalid
470+
471+
# Check for trailing semicolon
472+
end_pos = close_pos
473+
while end_pos < len(code) and code[end_pos] in " \t":
474+
end_pos += 1
475+
has_trailing_semicolon = end_pos < len(code) and code[end_pos] == ";"
476+
if has_trailing_semicolon:
477+
end_pos += 1
478+
479+
return StandaloneCallMatch(
480+
start_pos=match.start(),
481+
end_pos=end_pos,
482+
leading_whitespace=leading_ws,
483+
func_args=remaining_args,
484+
prefix=prefix,
485+
object_prefix=object_prefix,
486+
has_trailing_semicolon=has_trailing_semicolon,
487+
this_arg=this_arg,
488+
)
489+
490+
def _generate_transformed_call(
491+
self, match: StandaloneCallMatch, is_bracket_notation: bool = False, is_dot_call: bool = False
492+
) -> str:
398493
"""Generate the transformed code for a standalone call."""
399494
line_id = str(self.invocation_counter)
400495
args_str = match.func_args.strip()
401496
semicolon = ";" if match.has_trailing_semicolon else ""
402497

498+
# Handle .call() pattern: funcName.call(thisArg, args) -> codeflash.capture(..., funcName.bind(thisArg), args)
499+
if is_dot_call:
500+
if match.object_prefix:
501+
obj = match.object_prefix.rstrip(".")
502+
func_ref = f"{obj}.{self.func_name}"
503+
else:
504+
func_ref = self.func_name
505+
bind_expr = f"{func_ref}.bind({match.this_arg})"
506+
if args_str:
507+
return (
508+
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
509+
f"'{line_id}', {bind_expr}, {args_str}){semicolon}"
510+
)
511+
return (
512+
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
513+
f"'{line_id}', {bind_expr}){semicolon}"
514+
)
515+
403516
# Handle method calls on objects (e.g., calc.fibonacci, this.method, instance['method'])
404517
if match.object_prefix:
405518
# Remove trailing dot from object prefix for the bind call
@@ -481,14 +594,35 @@ def __init__(
481594
# Pattern to match start of expect((object.)*func_name(
482595
# Captures: (whitespace), (object prefix like calc. or this.)
483596
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\s*\(")
597+
# Pattern to match expect((object.)*func_name.call(
598+
self._expect_dot_call_pattern = re.compile(
599+
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\.call\s*\("
600+
)
484601

485602
def transform(self, code: str) -> str:
486603
"""Transform all expect calls in the code."""
487604
result: list[str] = []
488605
pos = 0
489606

490607
while pos < len(code):
491-
match = self._expect_pattern.search(code, pos)
608+
expect_match = self._expect_pattern.search(code, pos)
609+
call_match = self._expect_dot_call_pattern.search(code, pos)
610+
611+
# Pick the earliest match
612+
match = None
613+
is_dot_call = False
614+
if expect_match and call_match:
615+
if call_match.start() <= expect_match.start():
616+
match = call_match
617+
is_dot_call = True
618+
else:
619+
match = expect_match
620+
elif expect_match:
621+
match = expect_match
622+
elif call_match:
623+
match = call_match
624+
is_dot_call = True
625+
492626
if not match:
493627
result.append(code[pos:])
494628
break
@@ -503,18 +637,21 @@ def transform(self, code: str) -> str:
503637
result.append(code[pos : match.start()])
504638

505639
# Try to parse the full expect call
506-
expect_match = self._parse_expect_call(code, match)
507-
if expect_match is None:
640+
if is_dot_call:
641+
parsed_match = self._parse_expect_dot_call(code, match)
642+
else:
643+
parsed_match = self._parse_expect_call(code, match)
644+
if parsed_match is None:
508645
# Couldn't parse, skip this match
509646
result.append(code[match.start() : match.end()])
510647
pos = match.end()
511648
continue
512649

513650
# Generate the transformed code
514651
self.invocation_counter += 1
515-
transformed = self._generate_transformed_call(expect_match)
652+
transformed = self._generate_transformed_call(parsed_match)
516653
result.append(transformed)
517-
pos = expect_match.end_pos
654+
pos = parsed_match.end_pos
518655

519656
return "".join(result)
520657

@@ -567,6 +704,53 @@ def _parse_expect_call(self, code: str, match: re.Match) -> ExpectCallMatch | No
567704
object_prefix=object_prefix,
568705
)
569706

707+
def _parse_expect_dot_call(self, code: str, match: re.Match) -> ExpectCallMatch | None:
708+
"""Parse expect(funcName.call(thisArg, args)).assertion()."""
709+
leading_ws = match.group(1)
710+
object_prefix = match.group(2) or ""
711+
712+
if "." not in self.qualified_name and object_prefix:
713+
return None
714+
715+
# Find arguments inside .call(...)
716+
args_start = match.end()
717+
all_args, call_close_pos = self._find_balanced_parens(code, args_start - 1)
718+
if all_args is None:
719+
return None
720+
721+
# Split thisArg from remaining args
722+
this_arg, remaining_args = split_call_args(all_args)
723+
if not this_arg:
724+
return None
725+
726+
# Find closing ) of expect(
727+
expect_close_pos = call_close_pos
728+
while expect_close_pos < len(code) and code[expect_close_pos].isspace():
729+
expect_close_pos += 1
730+
if expect_close_pos >= len(code) or code[expect_close_pos] != ")":
731+
return None
732+
expect_close_pos += 1
733+
734+
# Parse assertion chain
735+
assertion_chain, chain_end_pos = self._parse_assertion_chain(code, expect_close_pos)
736+
if assertion_chain is None:
737+
return None
738+
739+
has_trailing_semicolon = chain_end_pos < len(code) and code[chain_end_pos] == ";"
740+
if has_trailing_semicolon:
741+
chain_end_pos += 1
742+
743+
return ExpectCallMatch(
744+
start_pos=match.start(),
745+
end_pos=chain_end_pos,
746+
leading_whitespace=leading_ws,
747+
func_args=remaining_args,
748+
assertion_chain=assertion_chain,
749+
has_trailing_semicolon=has_trailing_semicolon,
750+
object_prefix=object_prefix,
751+
this_arg=this_arg,
752+
)
753+
570754
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
571755
"""Find content within balanced parentheses.
572756
@@ -698,7 +882,14 @@ def _generate_transformed_call(self, match: ExpectCallMatch) -> str:
698882
args_str = match.func_args.strip()
699883

700884
# Determine the function reference to use
701-
if match.object_prefix:
885+
if match.this_arg:
886+
# .call() pattern: funcName.call(thisArg, ...) -> funcName.bind(thisArg)
887+
if match.object_prefix:
888+
obj = match.object_prefix.rstrip(".")
889+
func_ref = f"{obj}.{self.func_name}.bind({match.this_arg})"
890+
else:
891+
func_ref = f"{self.func_name}.bind({match.this_arg})"
892+
elif match.object_prefix:
702893
# Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
703894
obj = match.object_prefix.rstrip(".")
704895
func_ref = f"{obj}.{self.func_name}.bind({obj})"
@@ -831,6 +1022,11 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
8311022
if re.search(default_import, code):
8321023
return True
8331024

1025+
# Check for .call() pattern: funcName.call( or obj.funcName.call(
1026+
dot_call_pattern = rf"(?:\w+\.)*{re.escape(func_name)}\.call\s*\("
1027+
if re.search(dot_call_pattern, code):
1028+
return True
1029+
8341030
# Check for method calls: obj.funcName( or this.funcName(
8351031
# This handles class methods called on instances
8361032
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("

0 commit comments

Comments
 (0)