Skip to content

Commit d7a4478

Browse files
fix: resolve mypy type errors in Java config and instrumentation
1 parent 1dc69c6 commit d7a4478

2 files changed

Lines changed: 32 additions & 23 deletions

File tree

codeflash/languages/java/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
183183
has_junit4 = False
184184
has_testng = False
185185

186-
def check_dependencies(deps_element, ns):
186+
def check_dependencies(deps_element: ET.Element | None, ns: dict[str, str]) -> None:
187187
"""Check dependencies element for test frameworks."""
188188
nonlocal has_junit5, has_junit4, has_testng
189189

codeflash/languages/java/instrumentation.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
def _get_function_name(func: Any) -> str:
3434
"""Get the function name from FunctionToOptimize."""
3535
if hasattr(func, "function_name"):
36-
return func.function_name
36+
return str(func.function_name)
3737
if hasattr(func, "name"):
38-
return func.name
38+
return str(func.name)
3939
msg = f"Cannot get function name from {type(func)}"
4040
raise AttributeError(msg)
4141

@@ -80,7 +80,7 @@ def _is_test_annotation(stripped_line: str) -> bool:
8080
return bool(_TEST_ANNOTATION_RE.match(stripped_line))
8181

8282

83-
def _is_inside_lambda(node) -> bool:
83+
def _is_inside_lambda(node: Any) -> bool:
8484
"""Check if a tree-sitter node is inside a lambda_expression."""
8585
current = node.parent
8686
while current is not None:
@@ -92,7 +92,7 @@ def _is_inside_lambda(node) -> bool:
9292
return False
9393

9494

95-
def _is_inside_complex_expression(node) -> bool:
95+
def _is_inside_complex_expression(node: Any) -> bool:
9696
"""Check if a tree-sitter node is inside a complex expression that shouldn't be instrumented directly.
9797
9898
This includes:
@@ -161,7 +161,7 @@ def wrap_target_calls_with_treesitter(
161161
tree = analyzer.parse(wrapper_bytes)
162162

163163
# Collect all matching calls with their metadata
164-
calls = []
164+
calls: list[dict[str, Any]] = []
165165
_collect_calls(tree.root_node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, calls)
166166

167167
if not calls:
@@ -175,7 +175,7 @@ def wrap_target_calls_with_treesitter(
175175
offset += len(line.encode("utf8")) + 1 # +1 for \n from join
176176

177177
# Group non-lambda and non-complex-expression calls by their line index
178-
calls_by_line: dict[int, list] = {}
178+
calls_by_line: dict[int, list[dict[str, Any]]] = {}
179179
for call in calls:
180180
if call["in_lambda"] or call.get("in_complex", False):
181181
logger.debug("Skipping behavior instrumentation for call in lambda or complex expression")
@@ -261,7 +261,15 @@ def wrap_target_calls_with_treesitter(
261261
return wrapped, call_counter
262262

263263

264-
def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, out):
264+
def _collect_calls(
265+
node: Any,
266+
wrapper_bytes: bytes,
267+
body_bytes: bytes,
268+
prefix_len: int,
269+
func_name: str,
270+
analyzer: JavaAnalyzer,
271+
out: list[dict[str, Any]],
272+
) -> None:
265273
"""Recursively collect method_invocation nodes matching func_name."""
266274
if node.type == "method_invocation":
267275
name_node = node.child_by_field_name("name")
@@ -328,7 +336,7 @@ def _infer_array_cast_type(line: str) -> str | None:
328336
def _get_qualified_name(func: Any) -> str:
329337
"""Get the qualified name from FunctionToOptimize."""
330338
if hasattr(func, "qualified_name"):
331-
return func.qualified_name
339+
return str(func.qualified_name)
332340
# Build qualified name from function_name and parents
333341
if hasattr(func, "function_name"):
334342
parts = []
@@ -699,7 +707,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
699707
analyzer = get_java_analyzer()
700708
tree = analyzer.parse(source_bytes)
701709

702-
def has_test_annotation(method_node) -> bool:
710+
def has_test_annotation(method_node: Any) -> bool:
703711
modifiers = None
704712
for child in method_node.children:
705713
if child.type == "modifiers":
@@ -718,15 +726,15 @@ def has_test_annotation(method_node) -> bool:
718726
return True
719727
return False
720728

721-
def collect_test_methods(node, out) -> None:
729+
def collect_test_methods(node: Any, out: list[tuple[Any, Any]]) -> None:
722730
if node.type == "method_declaration" and has_test_annotation(node):
723731
body_node = node.child_by_field_name("body")
724732
if body_node is not None:
725733
out.append((node, body_node))
726734
for child in node.children:
727735
collect_test_methods(child, out)
728736

729-
def collect_target_calls(node, wrapper_bytes: bytes, func: str, out) -> None:
737+
def collect_target_calls(node: Any, wrapper_bytes: bytes, func: str, out: list[Any]) -> None:
730738
if node.type == "method_invocation":
731739
name_node = node.child_by_field_name("name")
732740
if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func:
@@ -753,13 +761,13 @@ def reindent_block(text: str, target_indent: str) -> str:
753761
reindented.append(f"{target_indent}{line[min_leading:]}")
754762
return "\n".join(reindented)
755763

756-
def find_top_level_statement(node, body_node):
764+
def find_top_level_statement(node: Any, body_node: Any) -> Any:
757765
current = node
758766
while current is not None and current.parent is not None and current.parent != body_node:
759767
current = current.parent
760768
return current if current is not None and current.parent == body_node else None
761769

762-
def split_var_declaration(stmt_node, source_bytes_ref: bytes) -> tuple[str, str] | None:
770+
def split_var_declaration(stmt_node: Any, source_bytes_ref: bytes) -> tuple[str, str] | None:
763771
"""Split a local_variable_declaration into a hoisted declaration and an assignment.
764772
765773
When a target call is inside a variable declaration like:
@@ -831,7 +839,7 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
831839
wrapped_body = wrapped_method.child_by_field_name("body")
832840
if wrapped_body is None:
833841
return body_text, next_wrapper_id
834-
calls = []
842+
calls: list[Any] = []
835843
collect_target_calls(wrapped_body, wrapper_bytes, func_name, calls)
836844

837845
indent = base_indent
@@ -930,14 +938,14 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
930938
result_parts.append(suffix)
931939
return "".join(result_parts), current_id
932940

933-
result_parts: list[str] = []
941+
multi_result_parts: list[str] = []
934942
cursor = 0
935943
wrapper_id = next_wrapper_id
936944

937945
for stmt_start, stmt_end, stmt_ast_node in unique_ranges:
938946
prefix = body_text[cursor:stmt_start]
939947
target_stmt = body_text[stmt_start:stmt_end]
940-
result_parts.append(prefix.rstrip(" \t"))
948+
multi_result_parts.append(prefix.rstrip(" \t"))
941949

942950
wrapper_id += 1
943951
current_id = wrapper_id
@@ -979,14 +987,14 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s
979987
f"{indent}}}",
980988
]
981989

982-
result_parts.append("\n" + "\n".join(setup_lines))
983-
result_parts.append("\n".join(timing_lines))
990+
multi_result_parts.append("\n" + "\n".join(setup_lines))
991+
multi_result_parts.append("\n".join(timing_lines))
984992
cursor = stmt_end
985993

986-
result_parts.append(body_text[cursor:])
987-
return "".join(result_parts), wrapper_id
994+
multi_result_parts.append(body_text[cursor:])
995+
return "".join(multi_result_parts), wrapper_id
988996

989-
test_methods = []
997+
test_methods: list[tuple[Any, Any]] = []
990998
collect_test_methods(tree.root_node, test_methods)
991999
if not test_methods:
9921000
return source
@@ -1134,12 +1142,13 @@ def instrument_generated_java_test(
11341142
function_name,
11351143
)
11361144
elif mode == "behavior":
1137-
_, modified_code = instrument_existing_test(
1145+
_, behavior_code = instrument_existing_test(
11381146
test_string=test_code,
11391147
mode=mode,
11401148
function_to_optimize=function_to_optimize,
11411149
test_class_name=original_class_name,
11421150
)
1151+
modified_code = behavior_code or test_code
11431152
else:
11441153
modified_code = test_code
11451154

0 commit comments

Comments
 (0)