|
23 | 23 | from codeflash.languages.java.parser import get_java_analyzer |
24 | 24 |
|
25 | 25 | if TYPE_CHECKING: |
| 26 | + from tree_sitter import Node |
| 27 | + |
26 | 28 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
27 | 29 | from codeflash.languages.java.parser import JavaAnalyzer |
28 | 30 |
|
@@ -548,7 +550,7 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa |
548 | 550 |
|
549 | 551 | def _collect_target_invocations( |
550 | 552 | self, |
551 | | - node, |
| 553 | + node: Node, |
552 | 554 | wrapper_bytes: bytes, |
553 | 555 | content_bytes: bytes, |
554 | 556 | base_offset: int, |
@@ -601,32 +603,40 @@ def _collect_target_invocations( |
601 | 603 | self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level) |
602 | 604 |
|
603 | 605 | def _build_target_call( |
604 | | - self, node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int |
| 606 | + self, node: Node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int |
605 | 607 | ) -> TargetCall: |
606 | 608 | """Build a TargetCall from a tree-sitter method_invocation node.""" |
607 | | - get_text = self.analyzer.get_node_text |
608 | | - |
609 | 609 | object_node = node.child_by_field_name("object") |
610 | 610 | args_node = node.child_by_field_name("arguments") |
611 | | - args_text = get_text(args_node, wrapper_bytes) if args_node else "" |
| 611 | + |
| 612 | + if args_node: |
| 613 | + args_text = wrapper_bytes[args_node.start_byte : args_node.end_byte].decode("utf8") |
| 614 | + else: |
| 615 | + args_text = "" |
612 | 616 | # argument_list node includes parens, strip them |
613 | | - if args_text.startswith("(") and args_text.endswith(")"): |
| 617 | + if args_text and args_text[0] == "(" and args_text[-1] == ")": |
614 | 618 | args_text = args_text[1:-1] |
615 | 619 |
|
616 | | - # Byte offsets -> char offsets for correct Python string indexing |
617 | | - start_char = len(content_bytes[:start_byte].decode("utf8")) |
618 | | - end_char = len(content_bytes[:end_byte].decode("utf8")) |
| 620 | + # Byte offsets -> char offsets for correct Python string indexing using analyzer mapping |
| 621 | + start_char = self.analyzer.byte_to_char_index(start_byte, content_bytes) |
| 622 | + end_char = self.analyzer.byte_to_char_index(end_byte, content_bytes) |
| 623 | + |
| 624 | + # Extract receiver and full call text from the wrapper bytes directly (fast for small wrappers) |
| 625 | + receiver_text = ( |
| 626 | + wrapper_bytes[object_node.start_byte : object_node.end_byte].decode("utf8") if object_node else None |
| 627 | + ) |
| 628 | + full_call_text = wrapper_bytes[node.start_byte : node.end_byte].decode("utf8") |
619 | 629 |
|
620 | 630 | return TargetCall( |
621 | | - receiver=get_text(object_node, wrapper_bytes) if object_node else None, |
| 631 | + receiver=receiver_text, |
622 | 632 | method_name=self.func_name, |
623 | 633 | arguments=args_text, |
624 | | - full_call=get_text(node, wrapper_bytes), |
| 634 | + full_call=full_call_text, |
625 | 635 | start_pos=base_offset + start_char, |
626 | 636 | end_pos=base_offset + end_char, |
627 | 637 | ) |
628 | 638 |
|
629 | | - def _find_top_level_arg_node(self, target_node, wrapper_bytes: bytes): |
| 639 | + def _find_top_level_arg_node(self, target_node: Node, wrapper_bytes: bytes) -> Node | None: |
630 | 640 | """Find the top-level argument expression containing a nested target call. |
631 | 641 |
|
632 | 642 | Walks up the AST from target_node to the wrapper _d() call's argument_list. |
|
0 commit comments