Skip to content

Commit a979c45

Browse files
Merge pull request #1597 from codeflash-ai/codeflash/optimize-pr1199-2026-02-20T10.16.02
⚡️ Speed up method `JavaAssertTransformer._build_target_call` by 1,947% in PR #1199 (`omni-java`)
2 parents adcc9b8 + 8d74b26 commit a979c45

2 files changed

Lines changed: 65 additions & 12 deletions

File tree

codeflash/languages/java/parser.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import logging
10+
from bisect import bisect_right
1011
from dataclasses import dataclass
1112
from typing import TYPE_CHECKING
1213

@@ -111,6 +112,13 @@ def __init__(self) -> None:
111112
"""Initialize the Java analyzer."""
112113
self._parser: Parser | None = None
113114

115+
# Caches for the last decoded source to avoid repeated decodes.
116+
self._cached_source_bytes: bytes | None = None
117+
self._cached_source_str: str | None = None
118+
# cumulative byte counts per character: cum_bytes[i] == total bytes for first i characters
119+
# length is number_of_chars + 1, cum_bytes[0] == 0
120+
self._cached_cum_bytes: list[int] | None = None
121+
114122
@property
115123
def parser(self) -> Parser:
116124
"""Get the parser, creating it lazily."""
@@ -678,6 +686,41 @@ def get_package_name(self, source: str) -> str | None:
678686

679687
return None
680688

689+
def _ensure_decoded(self, source: bytes) -> None:
690+
"""Ensure the provided source bytes are decoded and cumulative byte mapping is built.
691+
692+
Caches the decoded string and cumulative byte-lengths for the last-seen `source` bytes
693+
to make slicing by node byte offsets into string slices much cheaper.
694+
"""
695+
if source is self._cached_source_bytes:  # identity (not equality) check: only the very same bytes object counts as a cache hit
696+
return
697+
698+
decoded = source.decode("utf8")
699+
# Build cumulative bytes per character: cum[0] == 0, cum[i] == UTF-8 byte length of the first i chars.
700+
cum: list[int] = [0]
701+
# The cumulative mapping is built once per distinct source object; this is intended to be cheaper than
702+
# repeatedly decoding prefixes for many nodes.
703+
# A local alias for cum.append avoids re-resolving the bound method on every loop iteration.
704+
append = cum.append
705+
for ch in decoded:
706+
append(cum[-1] + len(ch.encode("utf8")))  # running total plus this character's encoded width in bytes
707+
708+
self._cached_source_bytes = source  # keeps a reference to the exact object the identity check above compares against
709+
self._cached_source_str = decoded
710+
self._cached_cum_bytes = cum
711+
712+
def byte_to_char_index(self, byte_offset: int, source: bytes) -> int:
713+
"""Convert a byte offset into a character index for the given source bytes.
714+
715+
This uses a cached cumulative byte-length mapping so repeated conversions are O(log n)
716+
(binary search) instead of re-decoding prefixes O(n).
717+
"""
718+
self._ensure_decoded(source)  # builds self._cached_cum_bytes for this source; returns immediately when it is already cached
719+
# cum is a non-decreasing list: find largest k where cum[k] <= byte_offset (an offset inside a multi-byte character maps to that character's index)
720+
cum = self._cached_cum_bytes # type: ignore[assignment]
721+
# bisect_right returns insertion point; subtract 1 to get character count. NOTE(review): a negative byte_offset yields -1 and an offset past the end yields the total character count -- confirm callers never pass out-of-range offsets
722+
return bisect_right(cum, byte_offset) - 1
723+
681724

682725
def get_java_analyzer() -> JavaAnalyzer:
683726
"""Get a JavaAnalyzer instance.

codeflash/languages/java/remove_asserts.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from codeflash.languages.java.parser import get_java_analyzer
2424

2525
if TYPE_CHECKING:
26+
from tree_sitter import Node
27+
2628
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2729
from codeflash.languages.java.parser import JavaAnalyzer
2830

@@ -548,7 +550,7 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
548550

549551
def _collect_target_invocations(
550552
self,
551-
node,
553+
node: Node,
552554
wrapper_bytes: bytes,
553555
content_bytes: bytes,
554556
base_offset: int,
@@ -601,32 +603,40 @@ def _collect_target_invocations(
601603
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
602604

603605
def _build_target_call(
604-
self, node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int
606+
self, node: Node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int
605607
) -> TargetCall:
606608
"""Build a TargetCall from a tree-sitter method_invocation node."""
607-
get_text = self.analyzer.get_node_text
608-
609609
object_node = node.child_by_field_name("object")
610610
args_node = node.child_by_field_name("arguments")
611-
args_text = get_text(args_node, wrapper_bytes) if args_node else ""
611+
612+
if args_node:
613+
args_text = wrapper_bytes[args_node.start_byte : args_node.end_byte].decode("utf8")
614+
else:
615+
args_text = ""
612616
# argument_list node includes parens, strip them
613-
if args_text.startswith("(") and args_text.endswith(")"):
617+
if args_text and args_text[0] == "(" and args_text[-1] == ")":
614618
args_text = args_text[1:-1]
615619

616-
# Byte offsets -> char offsets for correct Python string indexing
617-
start_char = len(content_bytes[:start_byte].decode("utf8"))
618-
end_char = len(content_bytes[:end_byte].decode("utf8"))
620+
# Byte offsets -> char offsets for correct Python string indexing using analyzer mapping
621+
start_char = self.analyzer.byte_to_char_index(start_byte, content_bytes)
622+
end_char = self.analyzer.byte_to_char_index(end_byte, content_bytes)
623+
624+
# Extract receiver and full call text from the wrapper bytes directly (fast for small wrappers)
625+
receiver_text = (
626+
wrapper_bytes[object_node.start_byte : object_node.end_byte].decode("utf8") if object_node else None
627+
)
628+
full_call_text = wrapper_bytes[node.start_byte : node.end_byte].decode("utf8")
619629

620630
return TargetCall(
621-
receiver=get_text(object_node, wrapper_bytes) if object_node else None,
631+
receiver=receiver_text,
622632
method_name=self.func_name,
623633
arguments=args_text,
624-
full_call=get_text(node, wrapper_bytes),
634+
full_call=full_call_text,
625635
start_pos=base_offset + start_char,
626636
end_pos=base_offset + end_char,
627637
)
628638

629-
def _find_top_level_arg_node(self, target_node, wrapper_bytes: bytes):
639+
def _find_top_level_arg_node(self, target_node: Node, wrapper_bytes: bytes) -> Node | None:
630640
"""Find the top-level argument expression containing a nested target call.
631641
632642
Walks up the AST from target_node to the wrapper _d() call's argument_list.

0 commit comments

Comments
 (0)