Skip to content

Commit 056cec7

Browse files
Merge pull request #1663 from codeflash-ai/fix/java-maven-test-execution-bugs
fix: resolve Maven test execution blockers for open-source Java repos
2 parents 919551e + cda56d1 commit 056cec7

7 files changed

Lines changed: 409 additions & 149 deletions

File tree

codeflash/languages/java/instrumentation.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,32 +1325,3 @@ def instrument_generated_java_test(
13251325

13261326
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
13271327
return modified_code
1328-
1329-
1330-
def _add_import(source: str, import_statement: str) -> str:
1331-
"""Add an import statement to the source.
1332-
1333-
Args:
1334-
source: The source code.
1335-
import_statement: The import to add.
1336-
1337-
Returns:
1338-
Source with import added.
1339-
1340-
"""
1341-
lines = source.splitlines(keepends=True)
1342-
insert_idx = 0
1343-
1344-
# Find the last import or package statement
1345-
for i, line in enumerate(lines):
1346-
stripped = line.strip()
1347-
if stripped.startswith(("import ", "package ")):
1348-
insert_idx = i + 1
1349-
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
1350-
# First non-import, non-comment line
1351-
if insert_idx == 0:
1352-
insert_idx = i
1353-
break
1354-
1355-
lines.insert(insert_idx, import_statement + "\n")
1356-
return "".join(lines)

codeflash/languages/java/remove_asserts.py

Lines changed: 226 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,14 @@ def __init__(
198198
# Precompile regex to find next special character (quotes, parens, braces).
199199
self._special_re = re.compile(r"[\"'{}()]")
200200

201+
# Precompile literal/cast regexes to avoid recompilation on each literal check.
202+
self._LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")
203+
self._INT_LITERAL_RE = re.compile(r"^-?\d+$")
204+
self._DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")
205+
self._FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")
206+
self._CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")
207+
self._cast_re = re.compile(r"^\((\w+)\)")
208+
201209
def transform(self, source: str) -> str:
202210
"""Remove assertions from source code, preserving target function calls.
203211
@@ -894,6 +902,143 @@ def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | N
894902

895903
return code[open_brace_pos + 1 : pos - 1], pos
896904

905+
def _infer_return_type(self, assertion: AssertionMatch) -> str:
906+
"""Infer the Java return type from the assertion context.
907+
908+
For assertEquals(expected, actual) patterns, the expected literal determines the type.
909+
For assertTrue/assertFalse, the result is boolean.
910+
Falls back to Object when the type cannot be determined.
911+
"""
912+
method = assertion.assertion_method
913+
914+
# assertTrue/assertFalse always deal with boolean values
915+
if method in {"assertTrue", "assertFalse"}:
916+
return "boolean"
917+
918+
# assertNull/assertNotNull — keep Object (reference type)
919+
if method in {"assertNull", "assertNotNull"}:
920+
return "Object"
921+
922+
# For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal
923+
if method in JUNIT5_VALUE_ASSERTIONS:
924+
return self._infer_type_from_assertion_args(assertion.original_text, method)
925+
926+
# For fluent assertions (assertThat), type inference is harder — keep Object
927+
return "Object"
928+
929+
# Regex patterns for Java literal type inference
930+
_LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")
931+
_INT_LITERAL_RE = re.compile(r"^-?\d+$")
932+
_DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")
933+
_FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")
934+
_CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")
935+
936+
def _infer_type_from_assertion_args(self, original_text: str, method: str) -> str:
937+
"""Infer the return type from assertEquals/assertNotEquals expected value."""
938+
# Extract the args portion from the assertion text
939+
# Pattern: assertXxx( args... )
940+
paren_idx = original_text.find("(")
941+
if paren_idx < 0:
942+
return "Object"
943+
944+
args_str = original_text[paren_idx + 1 :]
945+
# Remove trailing ");", whitespace
946+
args_str = args_str.rstrip()
947+
if args_str.endswith(");"):
948+
args_str = args_str[:-2]
949+
elif args_str.endswith(")"):
950+
args_str = args_str[:-1]
951+
952+
# Fast-path: only extract the first top-level argument instead of splitting all arguments.
953+
first_arg = self._extract_first_arg(args_str)
954+
if not first_arg:
955+
return "Object"
956+
957+
expected = first_arg.strip()
958+
959+
# JUnit 4 has assertEquals(String message, expected, actual) where the first arg is a message.
960+
# If the first arg is a string literal, check if there are 3+ args — if so, the real expected
961+
# value is the second argument, not the message string.
962+
if expected.startswith('"') and method in ("assertEquals", "assertNotEquals"):
963+
all_args = self._split_top_level_args(args_str)
964+
if len(all_args) >= 3:
965+
expected = all_args[1].strip()
966+
967+
return self._type_from_literal(expected)
968+
969+
def _type_from_literal(self, value: str) -> str:
970+
"""Determine the Java type of a literal value."""
971+
if value in ("true", "false"):
972+
return "boolean"
973+
if value == "null":
974+
return "Object"
975+
if self._FLOAT_LITERAL_RE.match(value):
976+
return "float"
977+
if self._DOUBLE_LITERAL_RE.match(value):
978+
return "double"
979+
if self._LONG_LITERAL_RE.match(value):
980+
return "long"
981+
if self._INT_LITERAL_RE.match(value):
982+
return "int"
983+
if self._CHAR_LITERAL_RE.match(value):
984+
return "char"
985+
if value.startswith('"'):
986+
return "String"
987+
# Cast expression like (byte)0, (short)1
988+
cast_match = self._cast_re.match(value)
989+
if cast_match:
990+
return cast_match.group(1)
991+
return "Object"
992+
993+
def _split_top_level_args(self, args_str: str) -> list[str]:
994+
"""Split assertion arguments at top-level commas, respecting parens/strings/generics."""
995+
# Fast-path: if there are no special delimiters that require parsing,
996+
# we can use a simple split which is much faster for common simple cases.
997+
if not self._special_re.search(args_str):
998+
# Preserve original behavior of returning a list with the single unstripped string
999+
# when there are no commas, otherwise split on commas.
1000+
if "," in args_str:
1001+
return args_str.split(",")
1002+
return [args_str]
1003+
1004+
args: list[str] = []
1005+
depth = 0
1006+
current: list[str] = []
1007+
i = 0
1008+
in_string = False
1009+
string_char = ""
1010+
1011+
while i < len(args_str):
1012+
ch = args_str[i]
1013+
1014+
if in_string:
1015+
current.append(ch)
1016+
if ch == "\\" and i + 1 < len(args_str):
1017+
i += 1
1018+
current.append(args_str[i])
1019+
elif ch == string_char:
1020+
in_string = False
1021+
elif ch in ('"', "'"):
1022+
in_string = True
1023+
string_char = ch
1024+
current.append(ch)
1025+
elif ch in ("(", "<", "[", "{"):
1026+
depth += 1
1027+
current.append(ch)
1028+
elif ch in (")", ">", "]", "}"):
1029+
depth -= 1
1030+
current.append(ch)
1031+
elif ch == "," and depth == 0:
1032+
args.append("".join(current))
1033+
current = []
1034+
else:
1035+
current.append(ch)
1036+
i += 1
1037+
1038+
if current:
1039+
args.append("".join(current))
1040+
return args
1041+
8971042
def _generate_replacement(self, assertion: AssertionMatch) -> str:
8981043
"""Generate replacement code for an assertion.
8991044
@@ -912,18 +1057,34 @@ def _generate_replacement(self, assertion: AssertionMatch) -> str:
9121057
if not assertion.target_calls:
9131058
return ""
9141059

1060+
# Infer the return type from assertion context to avoid Object→primitive cast errors
1061+
return_type = self._infer_return_type(assertion)
1062+
9151063
# Generate capture statements for each target call
916-
replacements = []
1064+
replacements: list[str] = []
9171065
# For the first replacement, use the full leading whitespace
9181066
# For subsequent ones, strip leading newlines to avoid extra blank lines
919-
base_indent = assertion.leading_whitespace.lstrip("\n\r")
920-
for i, call in enumerate(assertion.target_calls):
921-
self.invocation_counter += 1
922-
var_name = f"_cf_result{self.invocation_counter}"
923-
if i == 0:
924-
replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};")
925-
else:
926-
replacements.append(f"{base_indent}Object {var_name} = {call.full_call};")
1067+
leading_ws = assertion.leading_whitespace
1068+
base_indent = leading_ws.lstrip("\n\r")
1069+
1070+
# Use a local counter to minimize attribute write overhead in the loop.
1071+
inv = self.invocation_counter
1072+
1073+
calls = assertion.target_calls
1074+
# Handle first call explicitly to avoid a per-iteration branch
1075+
if calls:
1076+
inv += 1
1077+
var_name = "_cf_result" + str(inv)
1078+
replacements.append(f"{leading_ws}{return_type} {var_name} = {calls[0].full_call};")
1079+
1080+
# Handle remaining calls
1081+
for call in calls[1:]:
1082+
inv += 1
1083+
var_name = "_cf_result" + str(inv)
1084+
replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};")
1085+
1086+
# Write back the counter
1087+
self.invocation_counter = inv
9271088

9281089
return "\n".join(replacements)
9291090

@@ -942,8 +1103,10 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
9421103
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
9431104
9441105
"""
945-
self.invocation_counter += 1
946-
counter = self.invocation_counter
1106+
# Increment invocation counter once for this exception handling
1107+
inv = self.invocation_counter + 1
1108+
self.invocation_counter = inv
1109+
counter = inv
9471110
ws = assertion.leading_whitespace
9481111
base_indent = ws.lstrip("\n\r")
9491112

@@ -982,6 +1145,58 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
9821145
# Fallback: comment out the assertion
9831146
return f"{ws}// Removed assertThrows: could not extract callable"
9841147

1148+
def _extract_first_arg(self, args_str: str) -> str | None:
1149+
"""Extract the first top-level argument from args_str.
1150+
1151+
This is a lightweight alternative to splitting all top-level arguments;
1152+
it stops at the first top-level comma, respects nested delimiters and strings,
1153+
and avoids constructing the full argument list for better performance.
1154+
"""
1155+
n = len(args_str)
1156+
i = 0
1157+
1158+
# skip leading whitespace
1159+
while i < n and args_str[i].isspace():
1160+
i += 1
1161+
if i >= n:
1162+
return None
1163+
1164+
depth = 0
1165+
in_string = False
1166+
string_char = ""
1167+
cur: list[str] = []
1168+
1169+
while i < n:
1170+
ch = args_str[i]
1171+
1172+
if in_string:
1173+
cur.append(ch)
1174+
if ch == "\\" and i + 1 < n:
1175+
i += 1
1176+
cur.append(args_str[i])
1177+
elif ch == string_char:
1178+
in_string = False
1179+
elif ch in ('"', "'"):
1180+
in_string = True
1181+
string_char = ch
1182+
cur.append(ch)
1183+
elif ch in ("(", "<", "[", "{"):
1184+
depth += 1
1185+
cur.append(ch)
1186+
elif ch in (")", ">", "]", "}"):
1187+
depth -= 1
1188+
cur.append(ch)
1189+
elif ch == "," and depth == 0:
1190+
break
1191+
else:
1192+
cur.append(ch)
1193+
i += 1
1194+
1195+
# Trim trailing whitespace from the extracted argument
1196+
if not cur:
1197+
return None
1198+
return "".join(cur).rstrip()
1199+
9851200

9861201
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
9871202
"""Transform Java test code by removing assertions and capturing function calls.

codeflash/languages/java/test_runner.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@
4747
# Allows: letters, digits, underscores, dots, and dollar signs (inner classes)
4848
_VALID_JAVA_CLASS_NAME = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$.]*$")
4949

50+
# Skip validation/analysis plugins that reject generated instrumented files
51+
# (e.g. Apache Rat rejects missing license headers, Checkstyle rejects naming, etc.)
52+
_MAVEN_VALIDATION_SKIP_FLAGS = [
53+
"-Drat.skip=true",
54+
"-Dcheckstyle.skip=true",
55+
"-Dspotbugs.skip=true",
56+
"-Dpmd.skip=true",
57+
"-Denforcer.skip=true",
58+
"-Djapicmp.skip=true",
59+
]
60+
5061

5162
def _run_cmd_kill_pg_on_timeout(
5263
cmd: list[str],
@@ -85,9 +96,7 @@ def _run_cmd_kill_pg_on_timeout(
8596
# Windows does not have POSIX process groups / killpg. Fall back to
8697
# the standard subprocess.run() behaviour (kills parent only).
8798
try:
88-
return subprocess.run(
89-
cmd, cwd=cwd, env=env, capture_output=True, text=text, timeout=timeout, check=False
90-
)
99+
return subprocess.run(cmd, cwd=cwd, env=env, capture_output=True, text=text, timeout=timeout, check=False)
91100
except subprocess.TimeoutExpired:
92101
return subprocess.CompletedProcess(
93102
args=cmd, returncode=-2, stdout="", stderr=f"Process timed out after {timeout}s"
@@ -509,8 +518,9 @@ def run_behavioral_tests(
509518
add_jacoco_plugin_to_pom(pom_path)
510519
coverage_xml_path = get_jacoco_xml_path(project_root)
511520

512-
# Use a minimum timeout of 60s for Java builds (120s when coverage is enabled due to verify phase)
513-
min_timeout = 120 if enable_coverage else 60
521+
# Use a minimum timeout of 60s for Java builds (300s when coverage is enabled due to verify phase
522+
# which runs full compilation + instrumentation + test execution in multi-module projects)
523+
min_timeout = 300 if enable_coverage else 60
514524
effective_timeout = max(timeout or 300, min_timeout)
515525

516526
if enable_coverage:
@@ -591,6 +601,7 @@ def _compile_tests(
591601
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
592602

593603
cmd = [mvn, "test-compile", "-e", "-B"] # Show errors but not verbose output; -B for batch mode (no ANSI colors)
604+
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
594605

595606
if test_module:
596607
cmd.extend(["-pl", test_module, "-am"])
@@ -1526,6 +1537,7 @@ def _run_maven_tests(
15261537
# JaCoCo's report goal is bound to the verify phase to get post-test execution data
15271538
maven_goal = "verify" if enable_coverage else "test"
15281539
cmd = [mvn, maven_goal, "-fae", "-B"] # Fail at end to run all tests; -B for batch mode (no ANSI colors)
1540+
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
15291541

15301542
# Add --add-opens flags for Java 16+ module system compatibility.
15311543
# The codeflash-runtime Serializer uses Kryo which needs reflective access to
@@ -1562,7 +1574,16 @@ def _run_maven_tests(
15621574
# -am = also make dependencies
15631575
# -DfailIfNoTests=false allows dependency modules without tests to pass
15641576
# -DskipTests=false overrides any skipTests=true in pom.xml
1565-
cmd.extend(["-pl", test_module, "-am", "-DfailIfNoTests=false", "-DskipTests=false"])
1577+
cmd.extend(
1578+
[
1579+
"-pl",
1580+
test_module,
1581+
"-am",
1582+
"-DfailIfNoTests=false",
1583+
"-Dsurefire.failIfNoSpecifiedTests=false",
1584+
"-DskipTests=false",
1585+
]
1586+
)
15661587

15671588
if test_filter:
15681589
# Validate test filter to prevent command injection

codeflash/verification/parse_test_output.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,13 @@ def parse_test_xml(
786786
if class_name is not None and class_name.startswith(test_module_path):
787787
test_class = class_name[len(test_module_path) + 1 :] # +1 for the dot, gets Unittest class name
788788

789-
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
789+
loop_index = 1
790+
if testcase.name and "[" in testcase.name:
791+
bracket_content = testcase.name.rsplit("[", 1)[-1].rstrip("]").strip()
792+
try:
793+
loop_index = int(bracket_content)
794+
except ValueError:
795+
loop_index = 1
790796

791797
timed_out = False
792798
if len(testcase.result) > 1:

0 commit comments

Comments
 (0)