diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index 52d10726a..c52817e58 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -745,15 +745,10 @@ def _has_test_annotation(self, method_node: Node, source_bytes: bytes) -> bool: for child in method_node.children: if child.type == "modifiers": for mod_child in child.children: - if mod_child.type in ("marker_annotation", "annotation"): - name_node = mod_child.child_by_field_name("name") - if name_node is None: - # Fallback: search direct children for identifier - for ann_child in mod_child.children: - if ann_child.type == "identifier": - name_node = ann_child - break - if name_node and self.get_node_text(name_node, source_bytes) == "Test": + mod_type = mod_child.type + if mod_type == "marker_annotation" or mod_type == "annotation": + name_node = self._find_test_annotation_name(mod_child) + if name_node is not None and self.get_node_text(name_node, source_bytes) == "Test": return True return False @@ -861,6 +856,25 @@ def find_import_insertion_point(self, source: str) -> int: return last_line + def _find_test_annotation_name(self, annotation_node: Node) -> Node | None: + """Find the name node of an annotation. + + Args: + annotation_node: The annotation or marker_annotation node. + + Returns: + The name node if found, None otherwise. + """ + name_node = annotation_node.child_by_field_name("name") + if name_node is not None: + return name_node + # Fallback: search direct children for identifier + for ann_child in annotation_node.children: + if ann_child.type == "identifier": + return ann_child + return None + + def get_java_analyzer() -> JavaAnalyzer: """Get a JavaAnalyzer instance.