3333def _get_function_name (func : Any ) -> str :
3434 """Get the function name from FunctionToOptimize."""
3535 if hasattr (func , "function_name" ):
36- return func .function_name
36+ return str ( func .function_name )
3737 if hasattr (func , "name" ):
38- return func .name
38+ return str ( func .name )
3939 msg = f"Cannot get function name from { type (func )} "
4040 raise AttributeError (msg )
4141
@@ -80,7 +80,7 @@ def _is_test_annotation(stripped_line: str) -> bool:
8080 return bool (_TEST_ANNOTATION_RE .match (stripped_line ))
8181
8282
83- def _is_inside_lambda (node ) -> bool :
83+ def _is_inside_lambda (node : Any ) -> bool :
8484 """Check if a tree-sitter node is inside a lambda_expression."""
8585 current = node .parent
8686 while current is not None :
@@ -92,7 +92,7 @@ def _is_inside_lambda(node) -> bool:
9292 return False
9393
9494
95- def _is_inside_complex_expression (node ) -> bool :
95+ def _is_inside_complex_expression (node : Any ) -> bool :
9696 """Check if a tree-sitter node is inside a complex expression that shouldn't be instrumented directly.
9797
9898 This includes:
@@ -161,7 +161,7 @@ def wrap_target_calls_with_treesitter(
161161 tree = analyzer .parse (wrapper_bytes )
162162
163163 # Collect all matching calls with their metadata
164- calls = []
164+ calls : list [ dict [ str , Any ]] = []
165165 _collect_calls (tree .root_node , wrapper_bytes , body_bytes , prefix_len , func_name , analyzer , calls )
166166
167167 if not calls :
@@ -175,7 +175,7 @@ def wrap_target_calls_with_treesitter(
175175 offset += len (line .encode ("utf8" )) + 1 # +1 for \n from join
176176
177177 # Group non-lambda and non-complex-expression calls by their line index
178- calls_by_line : dict [int , list ] = {}
178+ calls_by_line : dict [int , list [ dict [ str , Any ]] ] = {}
179179 for call in calls :
180180 if call ["in_lambda" ] or call .get ("in_complex" , False ):
181181 logger .debug ("Skipping behavior instrumentation for call in lambda or complex expression" )
@@ -261,7 +261,15 @@ def wrap_target_calls_with_treesitter(
261261 return wrapped , call_counter
262262
263263
264- def _collect_calls (node , wrapper_bytes , body_bytes , prefix_len , func_name , analyzer , out ):
264+ def _collect_calls (
265+ node : Any ,
266+ wrapper_bytes : bytes ,
267+ body_bytes : bytes ,
268+ prefix_len : int ,
269+ func_name : str ,
270+ analyzer : JavaAnalyzer ,
271+ out : list [dict [str , Any ]],
272+ ) -> None :
265273 """Recursively collect method_invocation nodes matching func_name."""
266274 if node .type == "method_invocation" :
267275 name_node = node .child_by_field_name ("name" )
@@ -328,7 +336,7 @@ def _infer_array_cast_type(line: str) -> str | None:
328336def _get_qualified_name (func : Any ) -> str :
329337 """Get the qualified name from FunctionToOptimize."""
330338 if hasattr (func , "qualified_name" ):
331- return func .qualified_name
339+ return str ( func .qualified_name )
332340 # Build qualified name from function_name and parents
333341 if hasattr (func , "function_name" ):
334342 parts = []
@@ -699,7 +707,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
699707 analyzer = get_java_analyzer ()
700708 tree = analyzer .parse (source_bytes )
701709
702- def has_test_annotation (method_node ) -> bool :
710+ def has_test_annotation (method_node : Any ) -> bool :
703711 modifiers = None
704712 for child in method_node .children :
705713 if child .type == "modifiers" :
@@ -718,15 +726,15 @@ def has_test_annotation(method_node) -> bool:
718726 return True
719727 return False
720728
721- def collect_test_methods (node , out ) -> None :
729+ def collect_test_methods (node : Any , out : list [ tuple [ Any , Any ]] ) -> None :
722730 if node .type == "method_declaration" and has_test_annotation (node ):
723731 body_node = node .child_by_field_name ("body" )
724732 if body_node is not None :
725733 out .append ((node , body_node ))
726734 for child in node .children :
727735 collect_test_methods (child , out )
728736
729- def collect_target_calls (node , wrapper_bytes : bytes , func : str , out ) -> None :
737+ def collect_target_calls (node : Any , wrapper_bytes : bytes , func : str , out : list [ Any ] ) -> None :
730738 if node .type == "method_invocation" :
731739 name_node = node .child_by_field_name ("name" )
732740 if name_node and analyzer .get_node_text (name_node , wrapper_bytes ) == func :
@@ -753,13 +761,13 @@ def reindent_block(text: str, target_indent: str) -> str:
753761 reindented .append (f"{ target_indent } { line [min_leading :]} " )
754762 return "\n " .join (reindented )
755763
756- def find_top_level_statement (node , body_node ) :
764+ def find_top_level_statement (node : Any , body_node : Any ) -> Any :
757765 current = node
758766 while current is not None and current .parent is not None and current .parent != body_node :
759767 current = current .parent
760768 return current if current is not None and current .parent == body_node else None
761769
762- def split_var_declaration (stmt_node , source_bytes_ref : bytes ) -> tuple [str , str ] | None :
770+ def split_var_declaration (stmt_node : Any , source_bytes_ref : bytes ) -> tuple [str , str ] | None :
763771 """Split a local_variable_declaration into a hoisted declaration and an assignment.
764772
765773 When a target call is inside a variable declaration like:
@@ -831,7 +839,7 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
831839 wrapped_body = wrapped_method .child_by_field_name ("body" )
832840 if wrapped_body is None :
833841 return body_text , next_wrapper_id
834- calls = []
842+ calls : list [ Any ] = []
835843 collect_target_calls (wrapped_body , wrapper_bytes , func_name , calls )
836844
837845 indent = base_indent
@@ -930,14 +938,14 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
930938 result_parts .append (suffix )
931939 return "" .join (result_parts ), current_id
932940
933- result_parts : list [str ] = []
941+ multi_result_parts : list [str ] = []
934942 cursor = 0
935943 wrapper_id = next_wrapper_id
936944
937945 for stmt_start , stmt_end , stmt_ast_node in unique_ranges :
938946 prefix = body_text [cursor :stmt_start ]
939947 target_stmt = body_text [stmt_start :stmt_end ]
940- result_parts .append (prefix .rstrip (" \t " ))
948+ multi_result_parts .append (prefix .rstrip (" \t " ))
941949
942950 wrapper_id += 1
943951 current_id = wrapper_id
@@ -979,14 +987,14 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
979987 f"{ indent } }}" ,
980988 ]
981989
982- result_parts .append ("\n " + "\n " .join (setup_lines ))
983- result_parts .append ("\n " .join (timing_lines ))
990+ multi_result_parts .append ("\n " + "\n " .join (setup_lines ))
991+ multi_result_parts .append ("\n " .join (timing_lines ))
984992 cursor = stmt_end
985993
986- result_parts .append (body_text [cursor :])
987- return "" .join (result_parts ), wrapper_id
994+ multi_result_parts .append (body_text [cursor :])
995+ return "" .join (multi_result_parts ), wrapper_id
988996
989- test_methods = []
997+ test_methods : list [ tuple [ Any , Any ]] = []
990998 collect_test_methods (tree .root_node , test_methods )
991999 if not test_methods :
9921000 return source
@@ -1134,12 +1142,13 @@ def instrument_generated_java_test(
11341142 function_name ,
11351143 )
11361144 elif mode == "behavior" :
1137- _ , modified_code = instrument_existing_test (
1145+ _ , behavior_code = instrument_existing_test (
11381146 test_string = test_code ,
11391147 mode = mode ,
11401148 function_to_optimize = function_to_optimize ,
11411149 test_class_name = original_class_name ,
11421150 )
1151+ modified_code = behavior_code or test_code
11431152 else :
11441153 modified_code = test_code
11451154
0 commit comments