@@ -166,8 +166,9 @@ class AssertionMatch:
166166 original_text : str = ""
167167 is_exception_assertion : bool = False
168168 lambda_body : str | None = None # For assertThrows lambda content
169- assigned_var_type : str | None = None # For Type var = assertThrows(...)
170- assigned_var_name : str | None = None
169+ assigned_var_type : str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+ assigned_var_name : str | None = None # Name of assigned variable (e.g., "exception")
171+ exception_class : str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
171172
172173
173174class JavaAssertTransformer :
@@ -187,6 +188,9 @@ def __init__(
187188 self .invocation_counter = 0
188189 self ._detected_framework : str | None = None
189190
191+ # Precompile the assignment-detection regex to avoid recompiling on each call.
192+ self ._assign_re = re .compile (r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$" )
193+
190194 def transform (self , source : str ) -> str :
191195 """Remove assertions from source code, preserving target function calls.
192196
@@ -333,15 +337,19 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
333337
334338 # For exception assertions, extract the lambda body
335339 lambda_body = None
340+ exception_class = None
336341 if is_exception :
337342 lambda_body = self ._extract_lambda_body (args_content )
343+ # Extract exception class specifically for assertThrows
344+ if assertion_method == "assertThrows" :
345+ exception_class = self ._extract_exception_class (args_content )
338346
339- original_text = source [start_pos :end_pos ]
340-
347+ # Check if assertion is assigned to a variable
341348 # Detect variable assignment: Type var = assertXxx(...)
342349 # This applies to all assertions (assertThrows, assertTimeout, etc.)
343350 assigned_var_type = None
344351 assigned_var_name = None
352+ original_text = source [start_pos :end_pos ]
345353
346354 before = source [:start_pos ]
347355 last_nl_idx = before .rfind ("\n " )
@@ -361,7 +369,7 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
361369
362370 assigned_var_type = var_match .group (2 )
363371 assigned_var_name = var_match .group (3 )
364- original_text = source [start_pos :end_pos ]
372+ original_text = source [start_pos :end_pos ] # Update with adjusted range
365373
366374 # Determine statement type based on detected framework
367375 detected = self ._detected_framework or "junit5"
@@ -383,6 +391,7 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
383391 lambda_body = lambda_body ,
384392 assigned_var_type = assigned_var_type ,
385393 assigned_var_name = assigned_var_name ,
394+ exception_class = exception_class ,
386395 )
387396 )
388397
@@ -612,6 +621,83 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
612621
613622 return target_calls
614623
624+ def _detect_variable_assignment (self , source : str , assertion_start : int ) -> tuple [str | None , str | None ]:
625+ """Check if assertion is assigned to a variable.
626+
627+ Detects patterns like:
628+ IllegalArgumentException exception = assertThrows(...)
629+ Exception ex = assertThrows(...)
630+
631+ Args:
632+ source: The full source code.
633+ assertion_start: Start position of the assertion.
634+
635+ Returns:
636+ Tuple of (variable_type, variable_name) or (None, None).
637+
638+ """
639+ # Look backwards from assertion_start to beginning of line
640+ line_start = source .rfind ("\n " , 0 , assertion_start )
641+ if line_start == - 1 :
642+ line_start = 0
643+ else :
644+ line_start += 1
645+
646+ # Pattern: Type varName = assertXxx(...)
647+ # Handle generic types: Type<Generic> varName = ...
648+ match = self ._assign_re .search (source , line_start , assertion_start )
649+
650+
651+ if match :
652+ var_type = match .group (1 ).strip ()
653+ var_name = match .group (2 ).strip ()
654+ return var_type , var_name
655+
656+ return None , None
657+
658+ def _extract_exception_class (self , args_content : str ) -> str | None :
659+ """Extract exception class from assertThrows arguments.
660+
661+ Args:
662+ args_content: Content inside assertThrows parentheses.
663+
664+ Returns:
665+ Exception class name (e.g., "IllegalArgumentException") or None.
666+
667+ Example:
668+ assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
669+
670+ """
671+ # First argument is the exception class reference (e.g., "IllegalArgumentException.class")
672+ # Split by comma, but respect nested parentheses and generics
673+ depth = 0
674+ current = []
675+ parts = []
676+
677+ for char in args_content :
678+ if char in "(<" :
679+ depth += 1
680+ current .append (char )
681+ elif char in ")>" :
682+ depth -= 1
683+ current .append (char )
684+ elif char == "," and depth == 0 :
685+ parts .append ("" .join (current ).strip ())
686+ current = []
687+ else :
688+ current .append (char )
689+
690+ if current :
691+ parts .append ("" .join (current ).strip ())
692+
693+ if parts :
694+ exception_arg = parts [0 ].strip ()
695+ # Remove .class suffix
696+ if exception_arg .endswith (".class" ):
697+ return exception_arg [:- 6 ].strip ()
698+
699+ return None
700+
615701 def _extract_lambda_body (self , content : str ) -> str | None :
616702 """Extract the body of a lambda expression from assertThrows arguments.
617703
@@ -781,20 +867,23 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
781867 try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
782868
783869 When assigned to a variable:
784- IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0 ));
870+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code( ));
785871 To:
786872 IllegalArgumentException ex = null;
787- try { calc.divide(1, 0 ); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
873+ try { code( ); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) { }
788874
789875 """
790876 self .invocation_counter += 1
791877 counter = self .invocation_counter
792878 ws = assertion .leading_whitespace
793879 base_indent = ws .lstrip ("\n \r " )
794880
881+ # Extract code to run from lambda body or target calls
882+ code_to_run = None
795883 if assertion .lambda_body :
796884 code_to_run = assertion .lambda_body
797- if not code_to_run .endswith (";" ):
885+ # Use a direct last-character check instead of .endswith for lower overhead
886+ if code_to_run and code_to_run [- 1 ] != ";" :
798887 code_to_run += ";"
799888
800889 # Handle variable assignment: Type var = assertThrows(...)
@@ -805,10 +894,13 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
805894 if ";" not in assertion .lambda_body .strip ():
806895 return f"{ ws } { var_type } { var_name } = { assertion .lambda_body .strip ()} ;"
807896 return f"{ ws } { code_to_run } "
897+ # For assertThrows with variable assignment, use exception_class if available
898+ exception_type = assertion .exception_class or var_type
808899 return (
809900 f"{ ws } { var_type } { var_name } = null;\n "
810901 f"{ base_indent } try {{ { code_to_run } }} "
811- f"catch ({ var_type } _cf_caught{ counter } ) {{ { var_name } = _cf_caught{ counter } ; }}"
902+ f"catch ({ exception_type } _cf_caught{ counter } ) {{ { var_name } = _cf_caught{ counter } ; }} "
903+ f"catch (Exception _cf_ignored{ counter } ) {{}}"
812904 )
813905
814906 return (
0 commit comments