Commit 171da45

Optimize JavaAssertTransformer._infer_return_type
Runtime improvement (primary): the optimized version cuts the measured wall-clock time from ~11.9 ms to ~5.23 ms (≈127% speedup). Most of the previous time was spent parsing the entire argument list for JUnit value assertions; the profiler shows _split_top_level_args accounted for the dominant portion of runtime.

What changed (specific optimizations):
- Introduced _extract_first_arg that scans args_str once and stops as soon as the first top-level comma is encountered instead of calling _split_top_level_args to produce the full list.
- The new routine keeps parsing state inline (depth, in_string, escape handling) and builds only the first-argument string (one small list buffer) rather than accumulating all arguments into a list of substrings.
- Early-trimming and early-return avoid unnecessary work when the first argument is empty or when there are no commas.

Why this is faster (mechanics):
- Less work: in common cases we only need the first top-level argument to infer the expected type. Splitting all top-level arguments does O(n) work and allocates O(m) substrings for the entire argument list; extracting only the first arg is usually much cheaper (O(k) where k is length up to first top-level comma).
- Fewer allocations: avoids creating many intermediate strings and list entries, which reduces Python object overhead and GC pressure.
- Better branch locality: the loop exits earlier in the typical case (simple literals), so average time per call drops significantly. This shows up strongly in the large-loop and many-arg tests.

Behavioral impact and trade-offs:
- Semantics are preserved for the intended use: the function only needs the first argument to infer the return type, so replacing a full-split with a single-arg extractor keeps correctness for all existing tests.
- Microbenchmarks for very trivial cases (e.g., assertTrue/assertFalse) show tiny per-call regressions (a few tens of ns) in some test samples; this is a reasonable trade-off for the substantial end-to-end runtime improvement, especially since the optimized code targets the hot path (value-assertion type inference) where gains are largest.

When this helps most:
- Calls with long argument lists or many nested/comma-containing constructs (nested generics, long sequences of arguments); see the huge improvements in tests like large number of args and nested generics.
- Hot loops and repeated inference (many_inferences_loop_stress, repeated_inference): fewer allocations and earlier exits compound into large throughput gains.

In short: the optimization reduces unnecessary parsing and allocations by only extracting what is required (the first top-level argument), which directly reduced CPU time and memory churn and produced the measured ~2x runtime improvement while keeping behavior for the intended use-cases.
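The early-exit argument above can be illustrated with a deliberately simplified scanner. This is a sketch, not the code from this commit: `scan_top_level` and its `stop_at_first` flag are hypothetical names, and the scanner tracks bracket depth only (it ignores string literals). It counts how many characters are examined when splitting everything versus stopping at the first top-level comma.

```python
from __future__ import annotations

OPENERS = "(<[{"
CLOSERS = ")>]}"


def scan_top_level(args: str, stop_at_first: bool) -> tuple[list[str], int]:
    """Split on top-level commas; return (parts, characters_examined).

    Hypothetical helper for illustration only: it tracks bracket depth
    and does not handle string literals or escapes.
    """
    parts: list[str] = []
    start = 0
    depth = 0
    examined = 0
    for i, ch in enumerate(args):
        examined += 1
        if ch in OPENERS:
            depth += 1
        elif ch in CLOSERS:
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(args[start:i].strip())
            start = i + 1
            if stop_at_first:
                # Early exit: the rest of the string is never looked at.
                return parts, examined
    parts.append(args[start:].strip())
    return parts, examined


# A long argument list whose first argument is a short literal.
args = "1, " + ", ".join(f"List.of({n}, {n + 1})" for n in range(100))

all_parts, full_cost = scan_top_level(args, stop_at_first=False)
first_only, early_cost = scan_top_level(args, stop_at_first=True)

print(f"full split: {len(all_parts)} parts, {full_cost} characters examined")
print(f"first arg only: {first_only[0]!r}, {early_cost} characters examined")
```

Under these assumptions the full split walks the whole string and allocates one substring per argument, while the early-exit variant stops right after the first top-level comma.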
1 parent 342a9c5 commit 171da45

1 file changed

Lines changed: 56 additions & 4 deletions

codeflash/languages/java/remove_asserts.py

@@ -941,15 +941,15 @@ def _infer_type_from_assertion_args(self, original_text: str, method: str) -> st
         elif args_str.endswith(")"):
             args_str = args_str[:-1]
 
-        # Split top-level args (respecting parens, strings, generics)
-        args = self._split_top_level_args(args_str)
-        if not args:
+        # Fast-path: only extract the first top-level argument instead of splitting all arguments.
+        first_arg = self._extract_first_arg(args_str)
+        if not first_arg:
             return "Object"
 
         # assertEquals has (expected, actual) or (expected, actual, message/delta)
         # Some overloads have (message, expected, actual) in JUnit 4 but JUnit 5 uses (expected, actual[, message])
         # Try the first argument as the expected value
-        expected = args[0].strip()
+        expected = first_arg.strip()
 
         return self._type_from_literal(expected)
 
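The hunk above only needs the first argument because the expected value alone drives the inferred type. A minimal sketch of that flow follows. The body of `_type_from_literal` is not part of this diff, so `type_from_literal` below is a hypothetical stand-in with made-up rules, and `infer_from_first_arg` is an illustrative name; the real method may classify literals differently.

```python
from __future__ import annotations

import re


def type_from_literal(expr: str) -> str:
    """Hypothetical stand-in for _type_from_literal (real rules are not shown in the diff)."""
    if expr in ("true", "false"):
        return "boolean"
    if expr.startswith('"'):
        return "String"
    if expr.startswith("'"):
        return "char"
    if re.fullmatch(r"-?\d+[lL]", expr):
        return "long"
    if re.fullmatch(r"-?\d+", expr):
        return "int"
    if re.fullmatch(r"-?\d+\.\d+", expr):
        return "double"
    # Anything that is not a recognizable literal falls back to Object.
    return "Object"


def infer_from_first_arg(first_arg: str | None) -> str:
    """Mirror the control flow of the hunk: empty or missing argument means Object."""
    if not first_arg:
        return "Object"
    expected = first_arg.strip()
    return type_from_literal(expected)


for sample in (None, " 42 ", '"hi"', "3.14", "10L", "compute(x)"):
    print(repr(sample), "->", infer_from_first_arg(sample))
```

The `if not first_arg` check treats both `None` and the empty string as "no usable argument", which matches how the new code handles the `str | None` return value of `_extract_first_arg`.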
@@ -1108,6 +1108,58 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
         # Fallback: comment out the assertion
         return f"{ws}// Removed assertThrows: could not extract callable"
 
+    def _extract_first_arg(self, args_str: str) -> str | None:
+        """Extract the first top-level argument from args_str.
+
+        This is a lightweight alternative to splitting all top-level arguments;
+        it stops at the first top-level comma, respects nested delimiters and strings,
+        and avoids constructing the full argument list for better performance.
+        """
+        n = len(args_str)
+        i = 0
+
+        # skip leading whitespace
+        while i < n and args_str[i].isspace():
+            i += 1
+        if i >= n:
+            return None
+
+        depth = 0
+        in_string = False
+        string_char = ""
+        cur: list[str] = []
+
+        while i < n:
+            ch = args_str[i]
+
+            if in_string:
+                cur.append(ch)
+                if ch == "\\" and i + 1 < n:
+                    i += 1
+                    cur.append(args_str[i])
+                elif ch == string_char:
+                    in_string = False
+            elif ch in ('"', "'"):
+                in_string = True
+                string_char = ch
+                cur.append(ch)
+            elif ch in ("(", "<", "[", "{"):
+                depth += 1
+                cur.append(ch)
+            elif ch in (")", ">", "]", "}"):
+                depth -= 1
+                cur.append(ch)
+            elif ch == "," and depth == 0:
+                break
+            else:
+                cur.append(ch)
+            i += 1
+
+        # Trim trailing whitespace from the extracted argument
+        if not cur:
+            return None
+        return "".join(cur).rstrip()
+
 
 def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
     """Transform Java test code by removing assertions and capturing function calls.

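To try the new helper outside the class, the logic added in this hunk can be copied into a free function. The sketch below does that under the name `extract_first_arg` (a name chosen here for illustration; the commit defines it as the method `_extract_first_arg`) and runs it on a few argument strings, assuming the closing parenthesis has already been stripped as in the first hunk.

```python
from __future__ import annotations


def extract_first_arg(args_str: str) -> str | None:
    """Standalone copy of the method body added in this commit."""
    n = len(args_str)
    i = 0
    # Skip leading whitespace.
    while i < n and args_str[i].isspace():
        i += 1
    if i >= n:
        return None

    depth = 0
    in_string = False
    string_char = ""
    cur: list[str] = []

    while i < n:
        ch = args_str[i]
        if in_string:
            cur.append(ch)
            if ch == "\\" and i + 1 < n:
                # Keep the escaped character and skip over it.
                i += 1
                cur.append(args_str[i])
            elif ch == string_char:
                in_string = False
        elif ch in ('"', "'"):
            in_string = True
            string_char = ch
            cur.append(ch)
        elif ch in ("(", "<", "[", "{"):
            depth += 1
            cur.append(ch)
        elif ch in (")", ">", "]", "}"):
            depth -= 1
            cur.append(ch)
        elif ch == "," and depth == 0:
            break  # first top-level comma: stop scanning
        else:
            cur.append(ch)
        i += 1

    if not cur:
        return None
    return "".join(cur).rstrip()


samples = [
    '"a,b", actual',                      # comma inside a string literal
    "List.of(1, 2), result",              # comma inside parentheses
    "new HashMap<String, Integer>(), m",  # comma inside generics
    "  42 , x",                           # surrounding whitespace
    '"a\\"b,c", x',                       # escaped quote inside a string
    "   ",                                # whitespace only
    ", x",                                # empty first argument
]
for s in samples:
    print(repr(s), "->", repr(extract_first_arg(s)))
```

The last two samples exercise the two `None` paths: input that is entirely whitespace, and a top-level comma before any argument text.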