Skip to content

Commit 7cf7bcd

Browse files
authored
Merge pull request #1472 from codeflash-ai/fix/java-jpms-and-js-import-guard
fix(java): fix Java instrumentation, JPMS, and timeout bugs for QuestDB
2 parents 292edae + 5751d3b commit 7cf7bcd

5 files changed

Lines changed: 185 additions & 65 deletions

File tree

codeflash/languages/java/instrumentation.py

Lines changed: 138 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import logging
1818
import re
19-
from functools import lru_cache
2019
from typing import TYPE_CHECKING
2120

2221
if TYPE_CHECKING:
@@ -43,6 +42,102 @@ def _get_function_name(func: Any) -> str:
4342
# Pattern to detect primitive array types in assertions
4443
_PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]")
4544

45+
# Pattern to match @Test annotation exactly (not @TestOnly, @TestFactory, etc.)
46+
_TEST_ANNOTATION_RE = re.compile(r"^@Test(?:\s*\(.*\))?(?:\s.*)?$")
47+
48+
49+
def _is_test_annotation(stripped_line: str) -> bool:
50+
"""Check if a stripped line is an @Test annotation (not @TestOnly, @TestFactory, etc.).
51+
52+
Matches:
53+
@Test
54+
@Test(expected = ...)
55+
@Test(timeout = 5000)
56+
Does NOT match:
57+
@TestOnly
58+
@TestFactory
59+
@TestTemplate
60+
"""
61+
return bool(_TEST_ANNOTATION_RE.match(stripped_line))
62+
63+
64+
def _find_balanced_end(text: str, start: int) -> int:
65+
"""Find the position after the closing paren that balances the opening paren at start.
66+
67+
Args:
68+
text: The source text.
69+
start: Index of the opening parenthesis '('.
70+
71+
Returns:
72+
Index one past the matching closing ')', or -1 if not found.
73+
74+
"""
75+
if start >= len(text) or text[start] != "(":
76+
return -1
77+
depth = 1
78+
pos = start + 1
79+
in_string = False
80+
string_char = None
81+
in_char = False
82+
while pos < len(text) and depth > 0:
83+
ch = text[pos]
84+
prev = text[pos - 1] if pos > 0 else ""
85+
if ch == "'" and not in_string and prev != "\\":
86+
in_char = not in_char
87+
elif ch == '"' and not in_char and prev != "\\":
88+
if not in_string:
89+
in_string = True
90+
string_char = ch
91+
elif ch == string_char:
92+
in_string = False
93+
string_char = None
94+
elif not in_string and not in_char:
95+
if ch == "(":
96+
depth += 1
97+
elif ch == ")":
98+
depth -= 1
99+
pos += 1
100+
return pos if depth == 0 else -1
101+
102+
103+
def _find_method_calls_balanced(line: str, func_name: str):
104+
"""Find method calls to func_name with properly balanced parentheses.
105+
106+
Handles nested parentheses in arguments correctly, unlike a pure regex approach.
107+
Returns a list of (start, end, full_call) tuples where start/end are positions
108+
in the line and full_call is the matched text (receiver.funcName(args)).
109+
110+
Args:
111+
line: A single line of Java source code.
112+
func_name: The method name to look for.
113+
114+
Returns:
115+
List of (start_pos, end_pos, full_call_text) tuples.
116+
117+
"""
118+
# First find all occurrences of .funcName( in the line using regex
119+
# to locate the method name, then use balanced paren finding for args
120+
prefix_pattern = re.compile(
121+
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*{re.escape(func_name)}\s*\("
122+
)
123+
results = []
124+
search_start = 0
125+
while search_start < len(line):
126+
m = prefix_pattern.search(line, search_start)
127+
if not m:
128+
break
129+
# m.end() - 1 is the position of the opening paren
130+
open_paren_pos = m.end() - 1
131+
close_pos = _find_balanced_end(line, open_paren_pos)
132+
if close_pos == -1:
133+
# Unbalanced parens - skip this match
134+
search_start = m.end()
135+
continue
136+
full_call = line[m.start():close_pos]
137+
results.append((m.start(), close_pos, full_call))
138+
search_start = close_pos
139+
return results
140+
46141

47142
def _infer_array_cast_type(line: str) -> str | None:
48143
"""Infer the array cast type needed for assertion methods.
@@ -182,11 +277,13 @@ def instrument_existing_test(
182277
else:
183278
new_class_name = f"{original_class_name}__perfonlyinstrumented"
184279

185-
# Rename the class declaration in the source
186-
# Pattern: "public class ClassName" or "class ClassName"
187-
pattern = rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b"
188-
replacement = rf"\1class {new_class_name}"
189-
modified_source = re.sub(pattern, replacement, source)
280+
# Rename all references to the original class name in the source.
281+
# This includes the class declaration, return types, constructor calls,
282+
# variable declarations, etc. We use word-boundary matching to avoid
283+
# replacing substrings of other identifiers.
284+
modified_source = re.sub(
285+
rf"\b{re.escape(original_class_name)}\b", new_class_name, source
286+
)
190287

191288
# Add timing instrumentation to test methods
192289
# Use original class name (without suffix) in timing markers for consistency with Python
@@ -277,15 +374,12 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
277374
iteration_counter = 0
278375
helper_added = False
279376

280-
# Pre-compile the regex pattern once
281-
method_call_pattern = _get_method_call_pattern(func_name)
282-
283377
while i < len(lines):
284378
line = lines[i]
285379
stripped = line.strip()
286380

287-
# Look for @Test annotation
288-
if stripped.startswith("@Test"):
381+
# Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
382+
if _is_test_annotation(stripped):
289383
if not helper_added:
290384
helper_added = True
291385
result.append(line)
@@ -342,27 +436,20 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
342436
call_counter = 0
343437
wrapped_body_lines = []
344438

345-
# Use regex to find method calls with the target function
346-
# Pattern matches: receiver.funcName(args) where receiver can be:
347-
# - identifier (counter, calc, etc.)
348-
# - new ClassName()
349-
# - new ClassName(args)
350-
# - this
351-
method_call_pattern = re.compile(
352-
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
353-
)
354-
355439
# Track lambda block nesting depth to avoid wrapping calls inside lambda bodies.
356440
# assertThrows/assertDoesNotThrow expect an Executable (void functional interface),
357441
# and wrapping the call in a variable assignment would turn the void-compatible
358442
# lambda into a value-returning lambda, causing a compilation error.
359-
# Handles both expression lambdas: () -> func()
360-
# and block lambdas: () -> { func(); }
443+
# Also, variables declared outside lambdas cannot be reassigned inside them
444+
# (Java requires effectively final variables in lambda captures).
445+
# Handles both no-arg lambdas: () -> { func(); }
446+
# and parameterized lambdas: (a, b, c) -> { func(); }
361447
lambda_brace_depth = 0
362448

363449
for body_line in body_lines:
364-
# Detect new block lambda openings: () -> {
365-
is_lambda_open = bool(re.search(r"\(\s*\)\s*->\s*\{", body_line))
450+
# Detect block lambda openings: (...) -> { or () -> {
451+
# Matches both () -> { and (a, b, c) -> {
452+
is_lambda_open = bool(re.search(r"->\s*\{", body_line))
366453

367454
# Update lambda brace depth tracking for block lambdas
368455
if is_lambda_open or lambda_brace_depth > 0:
@@ -376,7 +463,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
376463
# Ensure depth doesn't go below 0
377464
lambda_brace_depth = max(0, lambda_brace_depth)
378465

379-
inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"\(\s*\)\s*->", body_line))
466+
inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line))
380467

381468
# Check if this line contains a call to the target function
382469
if func_name in body_line and "(" in body_line:
@@ -388,30 +475,41 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
388475
line_indent = len(body_line) - len(body_line.lstrip())
389476
line_indent_str = " " * line_indent
390477

391-
# Find all matches in the line
392-
matches = list(method_call_pattern.finditer(body_line))
478+
# Find all matches using balanced parenthesis matching
479+
# This correctly handles nested parens like:
480+
# obj.func(a, Rows.toRowID(frame.getIndex(), row))
481+
matches = _find_method_calls_balanced(body_line, func_name)
393482
if matches:
394483
# Process matches in reverse order to maintain correct positions
395484
new_line = body_line
396-
for match in reversed(matches):
485+
for start_pos, end_pos, full_call in reversed(matches):
397486
call_counter += 1
398487
var_name = f"_cf_result{iter_id}_{call_counter}"
399-
full_call = match.group(0) # e.g., "new StringUtils().reverse(\"hello\")"
400488

401489
# Check if we need to cast the result for assertions with primitive arrays
402490
# This handles assertArrayEquals(int[], int[]) etc.
403491
cast_type = _infer_array_cast_type(body_line)
404492
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
405493

406494
# Replace this occurrence with the variable (with cast if needed)
407-
new_line = new_line[: match.start()] + var_with_cast + new_line[match.end() :]
495+
new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:]
408496

409497
# Use 'var' instead of 'Object' to preserve the exact return type.
410498
# This avoids boxing mismatches (e.g., assertEquals(int, Object) where
411499
# Object is boxed Long but expected is boxed Integer). Requires Java 10+.
412500
capture_line = f"{line_indent_str}var {var_name} = {full_call};"
413501
wrapped_body_lines.append(capture_line)
414502

503+
# Immediately serialize the captured result while the variable
504+
# is still in scope. This is necessary because the variable may
505+
# be declared inside a nested block (while/for/if/try) and would
506+
# be out of scope at the end of the method body.
507+
serialize_line = (
508+
f"{line_indent_str}_cf_serializedResult{iter_id} = "
509+
f"com.codeflash.Serializer.serialize((Object) {var_name});"
510+
)
511+
wrapped_body_lines.append(serialize_line)
512+
415513
# Check if the line is now just a variable reference (invalid statement)
416514
# This happens when the original line was just a void method call
417515
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
@@ -423,15 +521,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
423521
else:
424522
wrapped_body_lines.append(body_line)
425523

426-
# Build the serialized return value expression
427-
# If we captured any calls, serialize the last one via Kryo; otherwise null bytes
428-
# The (Object) cast ensures primitives get autoboxed before being passed to the method.
429-
if call_counter > 0:
430-
result_var = f"_cf_result{iter_id}_{call_counter}"
431-
serialize_expr = f"com.codeflash.Serializer.serialize((Object) {result_var})"
432-
else:
433-
serialize_expr = "null"
434-
435524
# Add behavior instrumentation code
436525
behavior_start_code = [
437526
f"{indent}// Codeflash behavior instrumentation",
@@ -450,13 +539,13 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
450539
]
451540
result.extend(behavior_start_code)
452541

453-
# Add the wrapped body lines with extra indentation
542+
# Add the wrapped body lines with extra indentation.
543+
# Serialization of captured results is already done inline (immediately
544+
# after each capture) so the _cf_serializedResult variable is always
545+
# assigned while the captured variable is still in scope.
454546
for bl in wrapped_body_lines:
455547
result.append(" " + bl)
456548

457-
# Add serialization after the body (before finally)
458-
result.append(f"{indent} _cf_serializedResult{iter_id} = {serialize_expr};")
459-
460549
# Add finally block with SQLite write
461550
method_close_indent = " " * base_indent
462551
behavior_end_code = [
@@ -543,8 +632,8 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
543632
line = lines[i]
544633
stripped = line.strip()
545634

546-
# Look for @Test annotation
547-
if stripped.startswith("@Test"):
635+
# Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
636+
if _is_test_annotation(stripped):
548637
result.append(line)
549638
i += 1
550639

@@ -751,9 +840,10 @@ def instrument_generated_java_test(
751840
else:
752841
new_class_name = f"{original_class_name}__perfonlyinstrumented"
753842

754-
# Rename the class in the source
843+
# Rename all references to the original class name in the source.
844+
# This includes the class declaration, return types, constructor calls, etc.
755845
modified_code = re.sub(
756-
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b", rf"\1class {new_class_name}", test_code
846+
rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code
757847
)
758848

759849
# For performance mode, add timing instrumentation
@@ -798,9 +888,3 @@ def _add_import(source: str, import_statement: str) -> str:
798888
return "".join(lines)
799889

800890

801-
@lru_cache(maxsize=128)
802-
def _get_method_call_pattern(func_name: str):
803-
"""Cache compiled regex patterns for method call matching."""
804-
return re.compile(
805-
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
806-
)

codeflash/languages/java/line_profiler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import json
1111
import logging
12+
import re
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING
1415

@@ -89,17 +90,19 @@ def instrument_source(
8990
end_idx = func.ending_line
9091
lines = lines[:start_idx] + func_lines + lines[end_idx:]
9192

92-
instrumented_source = "".join(lines)
93-
9493
# Add profiler class and initialization
9594
profiler_class_code = self._generate_profiler_class()
9695

9796
# Insert profiler class before the package's first class
98-
# Find the first class declaration
97+
# Find the first class/interface/enum/record declaration
98+
# Must handle any combination of modifiers: public final class, abstract class, etc.
99+
class_pattern = re.compile(
100+
r"^(?:(?:public|private|protected|final|abstract|static|sealed|non-sealed)\s+)*"
101+
r"(?:class|interface|enum|record)\s+"
102+
)
99103
import_end_idx = 0
100104
for i, line in enumerate(lines):
101-
stripped = line.strip()
102-
if stripped.startswith("public class ") or stripped.startswith("class "):
105+
if class_pattern.match(line.strip()):
103106
import_end_idx = i
104107
break
105108

codeflash/languages/java/test_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def _run_benchmarking_tests_maven(
655655
loop_count = 0
656656
last_result = None
657657

658-
per_loop_timeout = timeout or max(120, 60 + inner_iterations)
658+
per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)
659659

660660
logger.debug("Using Maven-based benchmarking (fallback mode)")
661661

0 commit comments

Comments
 (0)