Skip to content

Commit 41814cd

Browse files
HeshamHM28claude
andcommitted
feat: support void method optimization in Java pipeline
Discover void methods, instrument them by serializing the receiver instead of a return value, and treat all-null comparisons as equivalent. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2d2bdc7 commit 41814cd

6 files changed

Lines changed: 347 additions & 32 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
195195

196196
try:
197197
lang_support = get_language_support(file_path)
198-
criteria = FunctionFilterCriteria(require_return=True)
198+
require_return = lang_support.language != Language.JAVA
199+
criteria = FunctionFilterCriteria(require_return=require_return)
199200
functions[file_path] = lang_support.discover_functions(file_path, criteria)
200201
except Exception as e:
201202
logger.debug(f"Failed to discover functions in {file_path}: {e}")
@@ -454,7 +455,8 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
454455
from codeflash.languages.base import FunctionFilterCriteria
455456

456457
lang_support = get_language_support(file_path)
457-
criteria = FunctionFilterCriteria(require_return=True)
458+
require_return = lang_support.language != Language.JAVA
459+
criteria = FunctionFilterCriteria(require_return=require_return)
458460
source = file_path.read_text(encoding="utf-8")
459461
return {file_path: lang_support.discover_functions(source, file_path, criteria)}
460462
except Exception as e:

codeflash/languages/java/comparator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,16 @@ def compare_test_results(
299299
skipped_deser_errors = comparison.get("skippedDeserializationErrors", 0)
300300

301301
if actual_comparisons == 0:
302+
if skipped_placeholders > 0 and skipped_deser_errors == 0 and not comparison.get("diffs"):
303+
# For void methods, all return values are null → all are "placeholder" skips.
304+
# If no diffs and no deser errors, treat as equivalent (pass/fail verification).
305+
logger.info(
306+
"Java comparison: void method — all return values null, treating as equivalent "
307+
"(total=%s, skipped_placeholders=%s)",
308+
comparison.get("totalInvocations", 0),
309+
skipped_placeholders,
310+
)
311+
return True, []
302312
logger.warning(
303313
"Java comparison: no actual comparisons performed "
304314
"(total=%s, skipped_placeholders=%s, skipped_deser_errors=%s). "

codeflash/languages/java/instrumentation.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -337,54 +337,94 @@ def wrap_target_calls_with_treesitter(
337337
orig_line = body_lines[line_idx]
338338
line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip()))
339339

340+
is_void = target_return_type == "void"
340341
var_name = f"_cf_result{iter_id}_{call_counter}"
342+
receiver = call.get("receiver", "this")
341343
cast_type = _infer_array_cast_type(orig_line)
342-
if not cast_type and target_return_type and target_return_type != "void":
344+
if not cast_type and target_return_type and not is_void:
343345
cast_type = target_return_type
344346
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
345347

346-
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
347-
capture_stmt_assign = f"{var_name} = {call['full_call']};"
348-
if precise_call_timing:
349-
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
350-
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
351-
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
348+
if is_void:
349+
bare_call_stmt = f"{call['full_call']};"
350+
if precise_call_timing:
351+
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {receiver});"
352+
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
353+
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
354+
else:
355+
serialize_stmt = (
356+
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {receiver});"
357+
)
358+
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
359+
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
352360
else:
353-
serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
354-
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
355-
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
361+
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
362+
capture_stmt_assign = f"{var_name} = {call['full_call']};"
363+
if precise_call_timing:
364+
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
365+
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
366+
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
367+
else:
368+
serialize_stmt = (
369+
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
370+
)
371+
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
372+
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
356373

357374
if call["parent_type"] == "expression_statement":
358375
es_start = call["_es_start_char"]
359376
es_end = call["_es_end_char"]
360377
if precise_call_timing:
361378
# No indent on first line — body_text[:es_start] already has leading whitespace.
362379
# Subsequent lines get line_indent_str.
363-
var_decls = [
364-
f"Object {var_name} = null;",
365-
f"long _cf_end{iter_id}_{call_counter} = -1;",
366-
f"long _cf_start{iter_id}_{call_counter} = 0;",
367-
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
368-
]
380+
if is_void:
381+
var_decls = [
382+
f"long _cf_end{iter_id}_{call_counter} = -1;",
383+
f"long _cf_start{iter_id}_{call_counter} = 0;",
384+
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
385+
]
386+
else:
387+
var_decls = [
388+
f"Object {var_name} = null;",
389+
f"long _cf_end{iter_id}_{call_counter} = -1;",
390+
f"long _cf_start{iter_id}_{call_counter} = 0;",
391+
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
392+
]
369393
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
370-
try_block = [
371-
"try {",
372-
f" {start_stmt}",
373-
f" {capture_stmt_assign}",
374-
f" {end_stmt}",
375-
f" {serialize_stmt}",
376-
]
394+
if is_void:
395+
try_block = [
396+
"try {",
397+
f" {start_stmt}",
398+
f" {bare_call_stmt}",
399+
f" {end_stmt}",
400+
f" {serialize_stmt}",
401+
]
402+
else:
403+
try_block = [
404+
"try {",
405+
f" {start_stmt}",
406+
f" {capture_stmt_assign}",
407+
f" {end_stmt}",
408+
f" {serialize_stmt}",
409+
]
377410
finally_block = _generate_sqlite_write_code(
378411
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id
379412
)
380413
all_lines = [*var_decls, start_marker, *try_block, *finally_block]
381414
replacement = (
382415
all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:])
383416
)
417+
elif is_void:
418+
replacement = f"{bare_call_stmt} {serialize_stmt}"
384419
else:
385420
replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
386421
body_text = body_text[:es_start] + replacement + body_text[es_end:]
387422
else:
423+
if is_void:
424+
# Void calls cannot be embedded in expressions in valid Java — skip instrumentation
425+
logger.warning("Skipping instrumentation of embedded void call: %s", call["full_call"])
426+
continue
427+
388428
# Embedded call: replace call with variable, then insert capture lines before the line
389429
call_start = call["_call_start_char"]
390430
call_end = call["_call_end_char"]
@@ -451,6 +491,8 @@ def _collect_calls(
451491
if parent_type == "expression_statement":
452492
es_start = parent.start_byte - prefix_len
453493
es_end = parent.end_byte - prefix_len
494+
object_node = node.child_by_field_name("object")
495+
receiver = analyzer.get_node_text(object_node, wrapper_bytes) if object_node else "this"
454496
out.append(
455497
{
456498
"start_byte": start,
@@ -461,6 +503,7 @@ def _collect_calls(
461503
"in_complex": _is_inside_complex_expression(node),
462504
"es_start_byte": es_start,
463505
"es_end_byte": es_end,
506+
"receiver": receiver,
464507
}
465508
)
466509
for child in node.children:

codeflash/languages/java/remove_asserts.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,15 @@ def __init__(
189189
qualified_name: str | None = None,
190190
analyzer: JavaAnalyzer | None = None,
191191
mode: str = "capture",
192+
target_return_type: str = "",
192193
) -> None:
193194
self.analyzer = analyzer or get_java_analyzer()
194195
self.func_name = function_name
195196
self.qualified_name = qualified_name or function_name
196197
self.invocation_counter = 0
197198
self._detected_framework: str | None = None
198199
self.mode = mode # "capture" (default, instrumentation) or "strip" (clean display)
200+
self.target_return_type = target_return_type
199201

200202
# Precompile the assignment-detection regex to avoid recompiling on each call.
201203
self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")
@@ -1062,7 +1064,7 @@ def _generate_replacement(self, assertion: AssertionMatch) -> str:
10621064
if not assertion.target_calls:
10631065
return ""
10641066

1065-
if self.mode == "strip":
1067+
if self.mode == "strip" or self.target_return_type == "void":
10661068
return self._generate_strip_replacement(assertion)
10671069

10681070
# Infer the return type from assertion context to avoid Object→primitive cast errors
@@ -1244,7 +1246,9 @@ def _extract_first_arg(self, args_str: str) -> str | None:
12441246
return "".join(cur).rstrip()
12451247

12461248

1247-
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
1249+
def transform_java_assertions(
1250+
source: str, function_name: str, qualified_name: str | None = None, target_return_type: str = ""
1251+
) -> str:
12481252
"""Transform Java test code by removing assertions and capturing function calls.
12491253
12501254
This is the main entry point for Java assertion transformation.
@@ -1253,12 +1257,15 @@ def transform_java_assertions(source: str, function_name: str, qualified_name: s
12531257
source: The Java test source code.
12541258
function_name: Name of the function being tested.
12551259
qualified_name: Optional fully qualified name of the function.
1260+
target_return_type: Return type of the target function (e.g., "void", "int").
12561261
12571262
Returns:
12581263
Transformed source code with assertions replaced by capture statements.
12591264
12601265
"""
1261-
transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name)
1266+
transformer = JavaAssertTransformer(
1267+
function_name=function_name, qualified_name=qualified_name, target_return_type=target_return_type
1268+
)
12621269
return transformer.transform(source)
12631270

12641271

tests/test_languages/fixtures/java_tracer_e2e/pom.xml

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,18 @@
1111
<maven.compiler.source>11</maven.compiler.source>
1212
<maven.compiler.target>11</maven.compiler.target>
1313
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
14-
</properties>
14+
<checkstyle.skip>true</checkstyle.skip>
15+
<disable.checks>true</disable.checks>
16+
<spotbugs.skip>true</spotbugs.skip>
17+
<pmd.skip>true</pmd.skip>
18+
<rat.skip>true</rat.skip>
19+
<enforcer.skip>true</enforcer.skip>
20+
<japicmp.skip>true</japicmp.skip>
21+
<checkstyle.failOnViolation>false</checkstyle.failOnViolation>
22+
<checkstyle.failsOnError>false</checkstyle.failsOnError>
23+
<maven-checkstyle-plugin.failsOnError>false</maven-checkstyle-plugin.failsOnError>
24+
<maven-checkstyle-plugin.failOnViolation>false</maven-checkstyle-plugin.failOnViolation>
25+
</properties>
1526

1627
<dependencies>
1728
<dependency>
@@ -62,6 +73,26 @@
6273
</execution>
6374
</executions>
6475
</plugin>
65-
</plugins>
76+
<!-- codeflash-validation-skip -->
77+
<plugin>
78+
<groupId>org.apache.maven.plugins</groupId>
79+
<artifactId>maven-checkstyle-plugin</artifactId>
80+
<configuration>
81+
<skip>true</skip>
82+
<failOnViolation>false</failOnViolation>
83+
<failsOnError>false</failsOnError>
84+
</configuration>
85+
</plugin>
86+
<plugin>
87+
<groupId>com.github.spotbugs</groupId>
88+
<artifactId>spotbugs-maven-plugin</artifactId>
89+
<configuration><skip>true</skip></configuration>
90+
</plugin>
91+
<plugin>
92+
<groupId>org.apache.maven.plugins</groupId>
93+
<artifactId>maven-pmd-plugin</artifactId>
94+
<configuration><skip>true</skip></configuration>
95+
</plugin>
96+
</plugins>
6697
</build>
6798
</project>

0 commit comments

Comments
 (0)