Skip to content

Commit abf2c98

Browse files
chore: merge omni-java into fix/java-exception-assignment-instrumentation
Resolved conflicts by merging the best of both branches: - Kept exception_class field from PR for better exception type detection - Adopted more general variable assignment detection from omni-java - Combined exception replacement logic to use exception_class with fallback - Added double catch (specific exception + generic Exception) for robustness - Merged test cases from both branches with updated expectations Changes: - Updated AssertionMatch to include all fields: assigned_var_type, assigned_var_name, exception_class - Lambda extraction now works for all exception assertions - Exception class extraction specifically for assertThrows - Variable assignment detection handles final modifier and fully qualified types - Exception replacement uses exception_class or falls back to assigned_var_type - All 80 tests passing Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2 parents 5c302bf + 6ba6a28 commit abf2c98

4 files changed

Lines changed: 155 additions & 67 deletions

File tree

codeflash/languages/java/line_profiler.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ def instrument_source(
8383
lines = source.splitlines(keepends=True)
8484

8585
# Process functions in reverse order to preserve line numbers
86-
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
86+
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
8787
func_lines = self._instrument_function(func, lines, file_path, analyzer)
88-
start_idx = func.start_line - 1
89-
end_idx = func.end_line
88+
start_idx = func.starting_line - 1
89+
end_idx = func.ending_line
9090
lines = lines[:start_idx] + func_lines + lines[end_idx:]
9191

9292
instrumented_source = "".join(lines)
@@ -261,7 +261,7 @@ def _instrument_function(
261261
Instrumented function lines.
262262
263263
"""
264-
func_lines = lines[func.start_line - 1 : func.end_line]
264+
func_lines = lines[func.starting_line - 1 : func.ending_line]
265265
instrumented_lines = []
266266

267267
# Parse the function to find executable lines
@@ -271,15 +271,15 @@ def _instrument_function(
271271
tree = analyzer.parse(source.encode("utf8"))
272272
executable_lines = self._find_executable_lines(tree.root_node)
273273
except Exception as e:
274-
logger.warning("Failed to parse function %s: %s", func.name, e)
274+
logger.warning("Failed to parse function %s: %s", func.function_name, e)
275275
return func_lines
276276

277277
# Add profiling to each executable line
278278
function_entry_added = False
279279

280280
for local_idx, line in enumerate(func_lines):
281281
local_line_num = local_idx + 1 # 1-indexed within function
282-
global_line_num = func.start_line + local_idx # Global line number
282+
global_line_num = func.starting_line + local_idx # Global line number
283283
stripped = line.strip()
284284

285285
# Add enterFunction() call after the method's opening brace
@@ -409,7 +409,7 @@ def parse_results(profile_file: Path) -> dict:
409409
410410
"""
411411
if not profile_file.exists():
412-
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
412+
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
413413

414414
try:
415415
with profile_file.open("r") as f:
@@ -435,15 +435,17 @@ def parse_results(profile_file: Path) -> dict:
435435
"content": content,
436436
}
437437

438-
return {
438+
result = {
439439
"timings": timings,
440440
"unit": 1e-9, # nanoseconds
441441
"raw_data": data,
442442
}
443+
result["str_out"] = format_line_profile_results(result)
444+
return result
443445

444446
except Exception as e:
445447
logger.error("Failed to parse line profile results: %s", e)
446-
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
448+
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
447449

448450

449451
def format_line_profile_results(results: dict, file_path: Path | None = None) -> str:

codeflash/languages/java/remove_asserts.py

Lines changed: 76 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ class AssertionMatch:
166166
original_text: str = ""
167167
is_exception_assertion: bool = False
168168
lambda_body: str | None = None # For assertThrows lambda content
169-
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170-
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
171-
exception_class: str | None = None # Exception class from assertThrows args
169+
assigned_var_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+
assigned_var_name: str | None = None # Name of assigned variable (e.g., "exception")
171+
exception_class: str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
172172

173173

174174
class JavaAssertTransformer:
@@ -306,8 +306,11 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
306306
# - assertEquals (static import)
307307
# - Assert.assertEquals (JUnit 4)
308308
# - Assertions.assertEquals (JUnit 5)
309+
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
309310
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
310-
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
311+
pattern = re.compile(
312+
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
313+
)
311314

312315
for match in pattern.finditer(source):
313316
leading_ws = match.group(1)
@@ -332,32 +335,41 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
332335
target_calls = self._extract_target_calls(args_content, match.end())
333336
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
334337

335-
# For assertThrows, extract the lambda body and exception class
338+
# For exception assertions, extract the lambda body
336339
lambda_body = None
337340
exception_class = None
338-
if is_exception and assertion_method == "assertThrows":
341+
if is_exception:
339342
lambda_body = self._extract_lambda_body(args_content)
340-
exception_class = self._extract_exception_class(args_content)
343+
# Extract exception class specifically for assertThrows
344+
if assertion_method == "assertThrows":
345+
exception_class = self._extract_exception_class(args_content)
341346

342347
# Check if assertion is assigned to a variable
343-
var_type, var_name = self._detect_variable_assignment(source, start_pos)
344-
345-
# If variable assignment detected, adjust start_pos to include the entire line
346-
actual_start = start_pos
347-
actual_leading_ws = leading_ws
348-
if var_type:
349-
# Find the start of the line (beginning of variable declaration)
350-
line_start = source.rfind("\n", 0, start_pos)
351-
if line_start == -1:
352-
line_start = 0
348+
# Detect variable assignment: Type var = assertXxx(...)
349+
# This applies to all assertions (assertThrows, assertTimeout, etc.)
350+
assigned_var_type = None
351+
assigned_var_name = None
352+
original_text = source[start_pos:end_pos]
353+
354+
before = source[:start_pos]
355+
last_nl_idx = before.rfind("\n")
356+
if last_nl_idx >= 0:
357+
line_prefix = source[last_nl_idx + 1 : start_pos]
358+
else:
359+
line_prefix = source[:start_pos]
360+
361+
var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
362+
if var_match:
363+
if last_nl_idx >= 0:
364+
start_pos = last_nl_idx
365+
leading_ws = "\n" + var_match.group(1)
353366
else:
354-
line_start += 1
355-
actual_start = line_start
356-
# Extract the actual leading whitespace from the start of the line
357-
line_content = source[line_start:start_pos]
358-
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
367+
start_pos = 0
368+
leading_ws = var_match.group(1)
359369

360-
original_text = source[actual_start:end_pos]
370+
assigned_var_type = var_match.group(2)
371+
assigned_var_name = var_match.group(3)
372+
original_text = source[start_pos:end_pos] # Update with adjusted range
361373

362374
# Determine statement type based on detected framework
363375
detected = self._detected_framework or "junit5"
@@ -368,17 +380,17 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
368380

369381
assertions.append(
370382
AssertionMatch(
371-
start_pos=actual_start,
383+
start_pos=start_pos,
372384
end_pos=end_pos,
373385
statement_type=stmt_type,
374386
assertion_method=assertion_method,
375387
target_calls=target_calls,
376-
leading_whitespace=actual_leading_ws,
388+
leading_whitespace=leading_ws,
377389
original_text=original_text,
378390
is_exception_assertion=is_exception,
379391
lambda_body=lambda_body,
380-
variable_type=var_type,
381-
variable_name=var_name,
392+
assigned_var_type=assigned_var_type,
393+
assigned_var_name=assigned_var_name,
382394
exception_class=exception_class,
383395
)
384396
)
@@ -709,9 +721,9 @@ def _extract_lambda_body(self, content: str) -> str | None:
709721
return brace_content.strip()
710722
else:
711723
# Expression lambda: () -> expr
712-
# Find the end (before the closing paren of assertThrows)
724+
# Find the end (before the closing paren of assertThrows, or comma at depth 0)
713725
depth = 0
714-
end = body_start
726+
end = len(content)
715727
for i, ch in enumerate(content[body_start:]):
716728
if ch == "(":
717729
depth += 1
@@ -720,6 +732,9 @@ def _extract_lambda_body(self, content: str) -> str | None:
720732
end = body_start + i
721733
break
722734
depth -= 1
735+
elif ch == "," and depth == 0:
736+
end = body_start + i
737+
break
723738
return content[body_start:end].strip()
724739

725740
return None
@@ -851,14 +866,17 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
851866
To:
852867
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
853868
854-
For variable assignments:
869+
When assigned to a variable:
855870
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
856871
To:
857872
IllegalArgumentException ex = null;
858-
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
873+
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
859874
860875
"""
861876
self.invocation_counter += 1
877+
counter = self.invocation_counter
878+
ws = assertion.leading_whitespace
879+
base_indent = ws.lstrip("\n\r")
862880

863881
# Extract code to run from lambda body or target calls
864882
code_to_run = None
@@ -867,38 +885,39 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
867885
# Use a direct last-character check instead of .endswith for lower overhead
868886
if code_to_run and code_to_run[-1] != ";":
869887
code_to_run += ";"
870-
elif assertion.target_calls:
871-
call = assertion.target_calls[0]
872-
code_to_run = call.full_call + ";"
873-
874-
if not code_to_run:
875-
# Fallback: comment out the assertion
876-
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
877888

878-
# Check if assertion is assigned to a variable
879-
if assertion.variable_name and assertion.variable_type:
880-
# Generate proper exception capture with variable assignment
881-
exception_type = assertion.exception_class or assertion.variable_type
882-
var_name = assertion.variable_name
883-
884-
# Use a unique catch variable name to avoid conflicts
885-
catch_var = f"_cf_caught{self.invocation_counter}"
889+
# Handle variable assignment: Type var = assertThrows(...)
890+
if assertion.assigned_var_name and assertion.assigned_var_type:
891+
var_type = assertion.assigned_var_type
892+
var_name = assertion.assigned_var_name
893+
if assertion.assertion_method == "assertDoesNotThrow":
894+
if ";" not in assertion.lambda_body.strip():
895+
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
896+
return f"{ws}{code_to_run}"
897+
# For assertThrows with variable assignment, use exception_class if available
898+
exception_type = assertion.exception_class or var_type
899+
return (
900+
f"{ws}{var_type} {var_name} = null;\n"
901+
f"{base_indent}try {{ {code_to_run} }} "
902+
f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }} "
903+
f"catch (Exception _cf_ignored{counter}) {{}}"
904+
)
886905

887-
# Get base indentation from leading whitespace (without newlines)
888-
base_indent = assertion.leading_whitespace.lstrip("\n\r")
906+
return (
907+
f"{ws}try {{ {code_to_run} }} "
908+
f"catch (Exception _cf_ignored{counter}) {{}}"
909+
)
889910

911+
# If no lambda body found, try to extract from target calls
912+
if assertion.target_calls:
913+
call = assertion.target_calls[0]
890914
return (
891-
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
892-
f"{base_indent}try {{ {code_to_run} }} "
893-
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
894-
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
915+
f"{ws}try {{ {call.full_call}; }} "
916+
f"catch (Exception _cf_ignored{counter}) {{}}"
895917
)
896918

897-
# No variable assignment, use simple try-catch
898-
return (
899-
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
900-
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
901-
)
919+
# Fallback: comment out the assertion
920+
return f"{ws}// Removed assertThrows: could not extract callable"
902921

903922

904923
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:

codeflash/languages/java/support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def instrument_source_for_line_profiler(
322322

323323
return True
324324
except Exception as e:
325-
logger.error("Failed to instrument %s for line profiling: %s", func_info.name, e)
325+
logger.error("Failed to instrument %s for line profiling: %s", func_info.function_name, e)
326326
return False
327327

328328
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:

tests/test_java_assertion_removal.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,41 @@ def test_concurrent_assertion_with_assertj(self):
12571257
assert result == expected
12581258

12591259

1260+
class TestFullyQualifiedAssertions:
1261+
"""Tests for fully qualified assertion calls like org.junit.jupiter.api.Assertions.assertXxx."""
1262+
1263+
def test_assert_timeout_fully_qualified_with_variable_assignment(self):
1264+
source = """\
1265+
@Test
1266+
void testLargeInput() {
1267+
Long result = org.junit.jupiter.api.Assertions.assertTimeout(
1268+
Duration.ofSeconds(1),
1269+
() -> Fibonacci.fibonacci(100_000)
1270+
);
1271+
}"""
1272+
expected = """\
1273+
@Test
1274+
void testLargeInput() {
1275+
Object _cf_result1 = Fibonacci.fibonacci(100_000);
1276+
}"""
1277+
result = transform_java_assertions(source, "fibonacci")
1278+
assert result == expected
1279+
1280+
def test_assert_equals_fully_qualified(self):
1281+
source = """\
1282+
@Test
1283+
void testAdd() {
1284+
org.junit.jupiter.api.Assertions.assertEquals(5, calc.add(2, 3));
1285+
}"""
1286+
expected = """\
1287+
@Test
1288+
void testAdd() {
1289+
Object _cf_result1 = calc.add(2, 3);
1290+
}"""
1291+
result = transform_java_assertions(source, "add")
1292+
assert result == expected
1293+
1294+
12601295
class TestAssertThrowsVariableAssignment:
12611296
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
12621297

@@ -1358,3 +1393,35 @@ def test_assert_throws_with_variable_and_multi_line_lambda(self):
13581393
}"""
13591394
result = transform_java_assertions(source, "execute")
13601395
assert result == expected
1396+
1397+
def test_assert_throws_assigned_with_final_modifier(self):
1398+
"""Test assertThrows with final modifier on variable."""
1399+
source = """\
1400+
@Test
1401+
void testDivideByZero() {
1402+
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
1403+
}"""
1404+
expected = """\
1405+
@Test
1406+
void testDivideByZero() {
1407+
IllegalArgumentException ex = null;
1408+
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
1409+
}"""
1410+
result = transform_java_assertions(source, "divide")
1411+
assert result == expected
1412+
1413+
def test_assert_throws_assigned_with_qualified_assertions(self):
1414+
"""Test assertThrows with qualified assertion (Assertions.assertThrows)."""
1415+
source = """\
1416+
@Test
1417+
void testDivideByZero() {
1418+
IllegalArgumentException ex = Assertions.assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
1419+
}"""
1420+
expected = """\
1421+
@Test
1422+
void testDivideByZero() {
1423+
IllegalArgumentException ex = null;
1424+
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
1425+
}"""
1426+
result = transform_java_assertions(source, "divide")
1427+
assert result == expected

0 commit comments

Comments
 (0)