Skip to content

Commit e207b83

Browse files
fix: handle assertThrows variable assignment in Java instrumentation
When assertThrows was assigned to a variable to validate exception properties, the transformation generated invalid Java syntax by replacing the assertThrows call with try-catch while leaving the variable assignment intact. Example of invalid output: IllegalArgumentException e = try { code(); } catch (Exception) {} This fix detects variable assignments, extracts the exception type from assertThrows arguments, and generates proper exception capture: IllegalArgumentException e = null; try { code(); } catch (IllegalArgumentException _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {} Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent df5b6a2 commit e207b83

2 files changed

Lines changed: 247 additions & 15 deletions

File tree

codeflash/languages/java/remove_asserts.py

Lines changed: 144 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ class AssertionMatch:
166166
original_text: str = ""
167167
is_exception_assertion: bool = False
168168
lambda_body: str | None = None # For assertThrows lambda content
169+
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
171+
exception_class: str | None = None # Exception class from assertThrows args
169172

170173

171174
class JavaAssertTransformer:
@@ -326,12 +329,32 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
326329
target_calls = self._extract_target_calls(args_content, match.end())
327330
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
328331

329-
# For assertThrows, extract the lambda body
332+
# For assertThrows, extract the lambda body and exception class
330333
lambda_body = None
334+
exception_class = None
331335
if is_exception and assertion_method == "assertThrows":
332336
lambda_body = self._extract_lambda_body(args_content)
337+
exception_class = self._extract_exception_class(args_content)
338+
339+
# Check if assertion is assigned to a variable
340+
var_type, var_name = self._detect_variable_assignment(source, start_pos)
341+
342+
# If variable assignment detected, adjust start_pos to include the entire line
343+
actual_start = start_pos
344+
actual_leading_ws = leading_ws
345+
if var_type:
346+
# Find the start of the line (beginning of variable declaration)
347+
line_start = source.rfind("\n", 0, start_pos)
348+
if line_start == -1:
349+
line_start = 0
350+
else:
351+
line_start += 1
352+
actual_start = line_start
353+
# Extract the actual leading whitespace from the start of the line
354+
line_content = source[line_start:start_pos]
355+
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
333356

334-
original_text = source[start_pos:end_pos]
357+
original_text = source[actual_start:end_pos]
335358

336359
# Determine statement type based on detected framework
337360
detected = self._detected_framework or "junit5"
@@ -342,15 +365,18 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
342365

343366
assertions.append(
344367
AssertionMatch(
345-
start_pos=start_pos,
368+
start_pos=actual_start,
346369
end_pos=end_pos,
347370
statement_type=stmt_type,
348371
assertion_method=assertion_method,
349372
target_calls=target_calls,
350-
leading_whitespace=leading_ws,
373+
leading_whitespace=actual_leading_ws,
351374
original_text=original_text,
352375
is_exception_assertion=is_exception,
353376
lambda_body=lambda_body,
377+
variable_type=var_type,
378+
variable_name=var_name,
379+
exception_class=exception_class,
354380
)
355381
)
356382

@@ -580,6 +606,85 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
580606

581607
return target_calls
582608

609+
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
610+
"""Check if assertion is assigned to a variable.
611+
612+
Detects patterns like:
613+
IllegalArgumentException exception = assertThrows(...)
614+
Exception ex = assertThrows(...)
615+
616+
Args:
617+
source: The full source code.
618+
assertion_start: Start position of the assertion.
619+
620+
Returns:
621+
Tuple of (variable_type, variable_name) or (None, None).
622+
623+
"""
624+
# Look backwards from assertion_start to beginning of line
625+
line_start = source.rfind("\n", 0, assertion_start)
626+
if line_start == -1:
627+
line_start = 0
628+
else:
629+
line_start += 1
630+
631+
line_before_assert = source[line_start:assertion_start]
632+
633+
# Pattern: Type varName = assertXxx(...)
634+
# Handle generic types: Type<Generic> varName = ...
635+
pattern = r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$"
636+
match = re.search(pattern, line_before_assert)
637+
638+
if match:
639+
var_type = match.group(1).strip()
640+
var_name = match.group(2).strip()
641+
return var_type, var_name
642+
643+
return None, None
644+
645+
def _extract_exception_class(self, args_content: str) -> str | None:
646+
"""Extract exception class from assertThrows arguments.
647+
648+
Args:
649+
args_content: Content inside assertThrows parentheses.
650+
651+
Returns:
652+
Exception class name (e.g., "IllegalArgumentException") or None.
653+
654+
Example:
655+
assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
656+
657+
"""
658+
# First argument is the exception class reference (e.g., "IllegalArgumentException.class")
659+
# Split by comma, but respect nested parentheses and generics
660+
depth = 0
661+
current = []
662+
parts = []
663+
664+
for char in args_content:
665+
if char in "(<":
666+
depth += 1
667+
current.append(char)
668+
elif char in ")>":
669+
depth -= 1
670+
current.append(char)
671+
elif char == "," and depth == 0:
672+
parts.append("".join(current).strip())
673+
current = []
674+
else:
675+
current.append(char)
676+
677+
if current:
678+
parts.append("".join(current).strip())
679+
680+
if parts:
681+
exception_arg = parts[0].strip()
682+
# Remove .class suffix
683+
if exception_arg.endswith(".class"):
684+
return exception_arg[:-6].strip()
685+
686+
return None
687+
583688
def _extract_lambda_body(self, content: str) -> str | None:
584689
"""Extract the body of a lambda expression from assertThrows arguments.
585690
@@ -745,29 +850,53 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
745850
To:
746851
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
747852
853+
For variable assignments:
854+
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
855+
To:
856+
IllegalArgumentException ex = null;
857+
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
858+
748859
"""
749860
self.invocation_counter += 1
750861

862+
# Extract code to run from lambda body or target calls
863+
code_to_run = None
751864
if assertion.lambda_body:
752-
# Extract the actual code from the lambda
753865
code_to_run = assertion.lambda_body
754866
if not code_to_run.endswith(";"):
755867
code_to_run += ";"
756-
return (
757-
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
758-
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
759-
)
760-
761-
# If no lambda body found, try to extract from target calls
762-
if assertion.target_calls:
868+
elif assertion.target_calls:
763869
call = assertion.target_calls[0]
870+
code_to_run = call.full_call + ";"
871+
872+
if not code_to_run:
873+
# Fallback: comment out the assertion
874+
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
875+
876+
# Check if assertion is assigned to a variable
877+
if assertion.variable_name and assertion.variable_type:
878+
# Generate proper exception capture with variable assignment
879+
exception_type = assertion.exception_class or assertion.variable_type
880+
var_name = assertion.variable_name
881+
882+
# Use a unique catch variable name to avoid conflicts
883+
catch_var = f"_cf_caught{self.invocation_counter}"
884+
885+
# Get base indentation from leading whitespace (without newlines)
886+
base_indent = assertion.leading_whitespace.lstrip("\n\r")
887+
764888
return (
765-
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
889+
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
890+
f"{base_indent}try {{ {code_to_run} }} "
891+
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
766892
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
767893
)
768894

769-
# Fallback: comment out the assertion
770-
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
895+
# No variable assignment, use simple try-catch
896+
return (
897+
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
898+
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
899+
)
771900

772901

773902
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:

tests/test_java_assertion_removal.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,3 +1255,106 @@ def test_concurrent_assertion_with_assertj(self):
12551255
}"""
12561256
result = transform_java_assertions(source, "incrementAndGet")
12571257
assert result == expected
1258+
1259+
1260+
class TestAssertThrowsVariableAssignment:
1261+
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
1262+
1263+
def test_assert_throws_with_variable_assignment_expression_lambda(self):
1264+
"""Test assertThrows assigned to variable with expression lambda."""
1265+
source = """\
1266+
@Test
1267+
void testNegativeInput() {
1268+
IllegalArgumentException exception = assertThrows(
1269+
IllegalArgumentException.class,
1270+
() -> calculator.fibonacci(-1)
1271+
);
1272+
assertEquals("Negative input not allowed", exception.getMessage());
1273+
}"""
1274+
expected = """\
1275+
@Test
1276+
void testNegativeInput() {
1277+
IllegalArgumentException exception = null;
1278+
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
1279+
assertEquals("Negative input not allowed", exception.getMessage());
1280+
}"""
1281+
result = transform_java_assertions(source, "fibonacci")
1282+
assert result == expected
1283+
1284+
def test_assert_throws_with_variable_assignment_block_lambda(self):
1285+
"""Test assertThrows assigned to variable with block lambda."""
1286+
source = """\
1287+
@Test
1288+
void testInvalidOperation() {
1289+
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
1290+
calculator.divide(10, 0);
1291+
});
1292+
assertEquals("Division by zero", ex.getMessage());
1293+
}"""
1294+
expected = """\
1295+
@Test
1296+
void testInvalidOperation() {
1297+
ArithmeticException ex = null;
1298+
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
1299+
assertEquals("Division by zero", ex.getMessage());
1300+
}"""
1301+
result = transform_java_assertions(source, "divide")
1302+
assert result == expected
1303+
1304+
def test_assert_throws_with_variable_assignment_generic_exception(self):
1305+
"""Test assertThrows with generic Exception type."""
1306+
source = """\
1307+
@Test
1308+
void testGenericException() {
1309+
Exception e = assertThrows(Exception.class, () -> processor.process(null));
1310+
assertNotNull(e.getMessage());
1311+
}"""
1312+
expected = """\
1313+
@Test
1314+
void testGenericException() {
1315+
Exception e = null;
1316+
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
1317+
assertNotNull(e.getMessage());
1318+
}"""
1319+
result = transform_java_assertions(source, "process")
1320+
assert result == expected
1321+
1322+
def test_assert_throws_without_variable_assignment(self):
1323+
"""Test assertThrows without variable assignment still works (no regression)."""
1324+
source = """\
1325+
@Test
1326+
void testThrowsException() {
1327+
assertThrows(IllegalArgumentException.class, () -> calculator.fibonacci(-1));
1328+
}"""
1329+
expected = """\
1330+
@Test
1331+
void testThrowsException() {
1332+
try { calculator.fibonacci(-1); } catch (Exception _cf_ignored1) {}
1333+
}"""
1334+
result = transform_java_assertions(source, "fibonacci")
1335+
assert result == expected
1336+
1337+
def test_assert_throws_with_variable_and_multi_line_lambda(self):
1338+
"""Test assertThrows with variable assignment and multi-line lambda."""
1339+
source = """\
1340+
@Test
1341+
void testComplexException() {
1342+
IllegalStateException exception = assertThrows(
1343+
IllegalStateException.class,
1344+
() -> {
1345+
processor.initialize();
1346+
processor.execute();
1347+
}
1348+
);
1349+
assertTrue(exception.getMessage().contains("not initialized"));
1350+
}"""
1351+
expected = """\
1352+
@Test
1353+
void testComplexException() {
1354+
IllegalStateException exception = null;
1355+
try { processor.initialize();
1356+
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
1357+
assertTrue(exception.getMessage().contains("not initialized"));
1358+
}"""
1359+
result = transform_java_assertions(source, "execute")
1360+
assert result == expected

0 commit comments

Comments
 (0)