Skip to content

Commit 8028174

Browse files
committed
feat: Enhance wrap_target_calls_with_treesitter and _add_behavior_instrumentation to support return type handling
1 parent c8d4fd3 commit 8028174

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

codeflash/languages/java/instrumentation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def wrap_target_calls_with_treesitter(
279279
class_name: str = "",
280280
test_method_name: str = "",
281281
is_void: bool = False,
282+
return_type: str | None = None,
282283
) -> tuple[list[str], int]:
283284
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
284285
@@ -348,6 +349,8 @@ def wrap_target_calls_with_treesitter(
348349
call_counter += 1
349350
var_name = f"_cf_result{iter_id}_{call_counter}"
350351
cast_type = _infer_array_cast_type(body_line)
352+
if not cast_type and return_type and return_type not in ("void", "Object"):
353+
cast_type = return_type
351354
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
352355

353356
# For void functions, we can't assign the return value to a variable
@@ -704,7 +707,8 @@ def instrument_existing_test(
704707
# replacing substrings of other identifiers.
705708
modified_source = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, source)
706709

707-
is_void = getattr(function_to_optimize, "return_type", None) == "void"
710+
return_type = getattr(function_to_optimize, "return_type", None)
711+
is_void = return_type == "void"
708712

709713
# Add timing instrumentation to test methods
710714
# Use original class name (without suffix) in timing markers for consistency with Python
@@ -716,14 +720,14 @@ def instrument_existing_test(
716720
)
717721
else:
718722
# Behavior mode: add timing instrumentation that also writes to SQLite
719-
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name, is_void=is_void)
723+
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name, is_void=is_void, return_type=return_type)
720724

721725
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
722726
# Why return True here?
723727
return True, modified_source
724728

725729

726-
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, is_void: bool = False) -> str:
730+
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, is_void: bool = False, return_type: str | None = None) -> str:
727731
"""Add behavior instrumentation to test methods.
728732
729733
For behavior mode, this adds:
@@ -865,6 +869,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str,
865869
class_name=class_name,
866870
test_method_name=test_method_name,
867871
is_void=is_void,
872+
return_type=return_type,
868873
)
869874

870875
# Add behavior instrumentation setup code (shared variables for all calls in the method)

0 commit comments

Comments
 (0)