Skip to content

Commit b836321

Browse files
Merge pull request #1443 from codeflash-ai/fix/java-exception-assignment-instrumentation
fix: handle assertThrows variable assignment in Java instrumentation
2 parents 6ba6a28 + abf2c98 commit b836321

2 files changed

Lines changed: 180 additions & 37 deletions

File tree

codeflash/languages/java/remove_asserts.py

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ class AssertionMatch:
166166
original_text: str = ""
167167
is_exception_assertion: bool = False
168168
lambda_body: str | None = None # For assertThrows lambda content
169-
assigned_var_type: str | None = None # For Type var = assertThrows(...)
170-
assigned_var_name: str | None = None
169+
assigned_var_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+
assigned_var_name: str | None = None # Name of assigned variable (e.g., "exception")
171+
exception_class: str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
171172

172173

173174
class JavaAssertTransformer:
@@ -187,6 +188,9 @@ def __init__(
187188
self.invocation_counter = 0
188189
self._detected_framework: str | None = None
189190

191+
# Precompile the assignment-detection regex to avoid recompiling on each call.
192+
self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")
193+
190194
def transform(self, source: str) -> str:
191195
"""Remove assertions from source code, preserving target function calls.
192196
@@ -333,15 +337,19 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
333337

334338
# For exception assertions, extract the lambda body
335339
lambda_body = None
340+
exception_class = None
336341
if is_exception:
337342
lambda_body = self._extract_lambda_body(args_content)
343+
# Extract exception class specifically for assertThrows
344+
if assertion_method == "assertThrows":
345+
exception_class = self._extract_exception_class(args_content)
338346

339-
original_text = source[start_pos:end_pos]
340-
347+
# Check if assertion is assigned to a variable
341348
# Detect variable assignment: Type var = assertXxx(...)
342349
# This applies to all assertions (assertThrows, assertTimeout, etc.)
343350
assigned_var_type = None
344351
assigned_var_name = None
352+
original_text = source[start_pos:end_pos]
345353

346354
before = source[:start_pos]
347355
last_nl_idx = before.rfind("\n")
@@ -361,7 +369,7 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
361369

362370
assigned_var_type = var_match.group(2)
363371
assigned_var_name = var_match.group(3)
364-
original_text = source[start_pos:end_pos]
372+
original_text = source[start_pos:end_pos] # Update with adjusted range
365373

366374
# Determine statement type based on detected framework
367375
detected = self._detected_framework or "junit5"
@@ -383,6 +391,7 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
383391
lambda_body=lambda_body,
384392
assigned_var_type=assigned_var_type,
385393
assigned_var_name=assigned_var_name,
394+
exception_class=exception_class,
386395
)
387396
)
388397

@@ -612,6 +621,83 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
612621

613622
return target_calls
614623

624+
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
625+
"""Check if assertion is assigned to a variable.
626+
627+
Detects patterns like:
628+
IllegalArgumentException exception = assertThrows(...)
629+
Exception ex = assertThrows(...)
630+
631+
Args:
632+
source: The full source code.
633+
assertion_start: Start position of the assertion.
634+
635+
Returns:
636+
Tuple of (variable_type, variable_name) or (None, None).
637+
638+
"""
639+
# Look backwards from assertion_start to beginning of line
640+
line_start = source.rfind("\n", 0, assertion_start)
641+
if line_start == -1:
642+
line_start = 0
643+
else:
644+
line_start += 1
645+
646+
# Pattern: Type varName = assertXxx(...)
647+
# Handle generic types: Type<Generic> varName = ...
648+
match = self._assign_re.search(source, line_start, assertion_start)
649+
650+
651+
if match:
652+
var_type = match.group(1).strip()
653+
var_name = match.group(2).strip()
654+
return var_type, var_name
655+
656+
return None, None
657+
658+
def _extract_exception_class(self, args_content: str) -> str | None:
659+
"""Extract exception class from assertThrows arguments.
660+
661+
Args:
662+
args_content: Content inside assertThrows parentheses.
663+
664+
Returns:
665+
Exception class name (e.g., "IllegalArgumentException") or None.
666+
667+
Example:
668+
assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
669+
670+
"""
671+
# First argument is the exception class reference (e.g., "IllegalArgumentException.class")
672+
# Split by comma, but respect nested parentheses and generics
673+
depth = 0
674+
current = []
675+
parts = []
676+
677+
for char in args_content:
678+
if char in "(<":
679+
depth += 1
680+
current.append(char)
681+
elif char in ")>":
682+
depth -= 1
683+
current.append(char)
684+
elif char == "," and depth == 0:
685+
parts.append("".join(current).strip())
686+
current = []
687+
else:
688+
current.append(char)
689+
690+
if current:
691+
parts.append("".join(current).strip())
692+
693+
if parts:
694+
exception_arg = parts[0].strip()
695+
# Remove .class suffix
696+
if exception_arg.endswith(".class"):
697+
return exception_arg[:-6].strip()
698+
699+
return None
700+
615701
def _extract_lambda_body(self, content: str) -> str | None:
616702
"""Extract the body of a lambda expression from assertThrows arguments.
617703
@@ -781,20 +867,23 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
781867
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
782868
783869
When assigned to a variable:
784-
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
870+
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
785871
To:
786872
IllegalArgumentException ex = null;
787-
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
873+
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
788874
789875
"""
790876
self.invocation_counter += 1
791877
counter = self.invocation_counter
792878
ws = assertion.leading_whitespace
793879
base_indent = ws.lstrip("\n\r")
794880

881+
# Extract code to run from lambda body or target calls
882+
code_to_run = None
795883
if assertion.lambda_body:
796884
code_to_run = assertion.lambda_body
797-
if not code_to_run.endswith(";"):
885+
# Use a direct last-character check instead of .endswith for lower overhead
886+
if code_to_run and code_to_run[-1] != ";":
798887
code_to_run += ";"
799888

800889
# Handle variable assignment: Type var = assertThrows(...)
@@ -805,10 +894,13 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
805894
if ";" not in assertion.lambda_body.strip():
806895
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
807896
return f"{ws}{code_to_run}"
897+
# For assertThrows with variable assignment, use exception_class if available
898+
exception_type = assertion.exception_class or var_type
808899
return (
809900
f"{ws}{var_type} {var_name} = null;\n"
810901
f"{base_indent}try {{ {code_to_run} }} "
811-
f"catch ({var_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }}"
902+
f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }} "
903+
f"catch (Exception _cf_ignored{counter}) {{}}"
812904
)
813905

814906
return (

tests/test_java_assertion_removal.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,74 +1293,125 @@ def test_assert_equals_fully_qualified(self):
12931293

12941294

12951295
class TestAssertThrowsVariableAssignment:
1296-
"""Tests for assertThrows assigned to a variable: Type var = assertThrows(...)."""
1296+
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
12971297

1298-
def test_assert_throws_assigned_to_variable(self):
1298+
def test_assert_throws_with_variable_assignment_expression_lambda(self):
1299+
"""Test assertThrows assigned to variable with expression lambda."""
12991300
source = """\
13001301
@Test
1301-
void testDivideByZero() {
1302-
Calculator calc = new Calculator();
1303-
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
1304-
assertEquals("Cannot divide by zero", ex.getMessage());
1302+
void testNegativeInput() {
1303+
IllegalArgumentException exception = assertThrows(
1304+
IllegalArgumentException.class,
1305+
() -> calculator.fibonacci(-1)
1306+
);
1307+
assertEquals("Negative input not allowed", exception.getMessage());
13051308
}"""
13061309
expected = """\
13071310
@Test
1308-
void testDivideByZero() {
1309-
Calculator calc = new Calculator();
1310-
IllegalArgumentException ex = null;
1311-
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
1312-
assertEquals("Cannot divide by zero", ex.getMessage());
1311+
void testNegativeInput() {
1312+
IllegalArgumentException exception = null;
1313+
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
1314+
assertEquals("Negative input not allowed", exception.getMessage());
13131315
}"""
1314-
result = transform_java_assertions(source, "divide")
1316+
result = transform_java_assertions(source, "fibonacci")
13151317
assert result == expected
13161318

1317-
def test_assert_throws_assigned_to_variable_block_lambda(self):
1319+
def test_assert_throws_with_variable_assignment_block_lambda(self):
1320+
"""Test assertThrows assigned to variable with block lambda."""
13181321
source = """\
13191322
@Test
1320-
void testDivideByZero() {
1323+
void testInvalidOperation() {
13211324
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
1322-
calculator.divide(1, 0);
1325+
calculator.divide(10, 0);
13231326
});
1327+
assertEquals("Division by zero", ex.getMessage());
13241328
}"""
13251329
expected = """\
13261330
@Test
1327-
void testDivideByZero() {
1331+
void testInvalidOperation() {
13281332
ArithmeticException ex = null;
1329-
try { calculator.divide(1, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; }
1333+
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
1334+
assertEquals("Division by zero", ex.getMessage());
13301335
}"""
13311336
result = transform_java_assertions(source, "divide")
13321337
assert result == expected
13331338

1334-
def test_assert_throws_assigned_with_final_modifier(self):
1339+
def test_assert_throws_with_variable_assignment_generic_exception(self):
1340+
"""Test assertThrows with generic Exception type."""
13351341
source = """\
13361342
@Test
1337-
void testDivideByZero() {
1338-
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
1343+
void testGenericException() {
1344+
Exception e = assertThrows(Exception.class, () -> processor.process(null));
1345+
assertNotNull(e.getMessage());
13391346
}"""
13401347
expected = """\
13411348
@Test
1342-
void testDivideByZero() {
1343-
IllegalArgumentException ex = null;
1344-
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
1349+
void testGenericException() {
1350+
Exception e = null;
1351+
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
1352+
assertNotNull(e.getMessage());
13451353
}"""
1346-
result = transform_java_assertions(source, "divide")
1354+
result = transform_java_assertions(source, "process")
13471355
assert result == expected
13481356

1349-
def test_assert_throws_not_assigned_unchanged(self):
1357+
def test_assert_throws_without_variable_assignment(self):
1358+
"""Test assertThrows without variable assignment still works (no regression)."""
1359+
source = """\
1360+
@Test
1361+
void testThrowsException() {
1362+
assertThrows(IllegalArgumentException.class, () -> calculator.fibonacci(-1));
1363+
}"""
1364+
expected = """\
1365+
@Test
1366+
void testThrowsException() {
1367+
try { calculator.fibonacci(-1); } catch (Exception _cf_ignored1) {}
1368+
}"""
1369+
result = transform_java_assertions(source, "fibonacci")
1370+
assert result == expected
1371+
1372+
def test_assert_throws_with_variable_and_multi_line_lambda(self):
1373+
"""Test assertThrows with variable assignment and multi-line lambda."""
1374+
source = """\
1375+
@Test
1376+
void testComplexException() {
1377+
IllegalStateException exception = assertThrows(
1378+
IllegalStateException.class,
1379+
() -> {
1380+
processor.initialize();
1381+
processor.execute();
1382+
}
1383+
);
1384+
assertTrue(exception.getMessage().contains("not initialized"));
1385+
}"""
1386+
expected = """\
1387+
@Test
1388+
void testComplexException() {
1389+
IllegalStateException exception = null;
1390+
try { processor.initialize();
1391+
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
1392+
assertTrue(exception.getMessage().contains("not initialized"));
1393+
}"""
1394+
result = transform_java_assertions(source, "execute")
1395+
assert result == expected
1396+
1397+
def test_assert_throws_assigned_with_final_modifier(self):
1398+
"""Test assertThrows with final modifier on variable."""
13501399
source = """\
13511400
@Test
13521401
void testDivideByZero() {
1353-
assertThrows(IllegalArgumentException.class, () -> calculator.divide(1, 0));
1402+
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
13541403
}"""
13551404
expected = """\
13561405
@Test
13571406
void testDivideByZero() {
1358-
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
1407+
IllegalArgumentException ex = null;
1408+
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
13591409
}"""
13601410
result = transform_java_assertions(source, "divide")
13611411
assert result == expected
13621412

13631413
def test_assert_throws_assigned_with_qualified_assertions(self):
1414+
"""Test assertThrows with qualified assertion (Assertions.assertThrows)."""
13641415
source = """\
13651416
@Test
13661417
void testDivideByZero() {
@@ -1370,7 +1421,7 @@ def test_assert_throws_assigned_with_qualified_assertions(self):
13701421
@Test
13711422
void testDivideByZero() {
13721423
IllegalArgumentException ex = null;
1373-
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
1424+
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
13741425
}"""
13751426
result = transform_java_assertions(source, "divide")
13761427
assert result == expected

0 commit comments

Comments
 (0)