@@ -166,6 +166,9 @@ class AssertionMatch:
166166 original_text : str = ""
167167 is_exception_assertion : bool = False
168168 lambda_body : str | None = None # For assertThrows lambda content
169+ variable_type : str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+ variable_name : str | None = None # Name of assigned variable (e.g., "exception")
171+ exception_class : str | None = None # Exception class from assertThrows args
169172
170173
171174class JavaAssertTransformer :
@@ -326,12 +329,32 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
326329 target_calls = self ._extract_target_calls (args_content , match .end ())
327330 is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
328331
329- # For assertThrows, extract the lambda body
332+ # For assertThrows, extract the lambda body and exception class
330333 lambda_body = None
334+ exception_class = None
331335 if is_exception and assertion_method == "assertThrows" :
332336 lambda_body = self ._extract_lambda_body (args_content )
337+ exception_class = self ._extract_exception_class (args_content )
338+
339+ # Check if assertion is assigned to a variable
340+ var_type , var_name = self ._detect_variable_assignment (source , start_pos )
341+
342+ # If variable assignment detected, adjust start_pos to include the entire line
343+ actual_start = start_pos
344+ actual_leading_ws = leading_ws
345+ if var_type :
346+ # Find the start of the line (beginning of variable declaration)
347+ line_start = source .rfind ("\n " , 0 , start_pos )
348+ if line_start == - 1 :
349+ line_start = 0
350+ else :
351+ line_start += 1
352+ actual_start = line_start
353+ # Extract the actual leading whitespace from the start of the line
354+ line_content = source [line_start :start_pos ]
355+ actual_leading_ws = line_content [:len (line_content ) - len (line_content .lstrip ())]
333356
334- original_text = source [start_pos :end_pos ]
357+ original_text = source [actual_start :end_pos ]
335358
336359 # Determine statement type based on detected framework
337360 detected = self ._detected_framework or "junit5"
@@ -342,15 +365,18 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
342365
343366 assertions .append (
344367 AssertionMatch (
345- start_pos = start_pos ,
368+ start_pos = actual_start ,
346369 end_pos = end_pos ,
347370 statement_type = stmt_type ,
348371 assertion_method = assertion_method ,
349372 target_calls = target_calls ,
350- leading_whitespace = leading_ws ,
373+ leading_whitespace = actual_leading_ws ,
351374 original_text = original_text ,
352375 is_exception_assertion = is_exception ,
353376 lambda_body = lambda_body ,
377+ variable_type = var_type ,
378+ variable_name = var_name ,
379+ exception_class = exception_class ,
354380 )
355381 )
356382
@@ -580,6 +606,85 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
580606
581607 return target_calls
582608
609+ def _detect_variable_assignment (self , source : str , assertion_start : int ) -> tuple [str | None , str | None ]:
610+ """Check if assertion is assigned to a variable.
611+
612+ Detects patterns like:
613+ IllegalArgumentException exception = assertThrows(...)
614+ Exception ex = assertThrows(...)
615+
616+ Args:
617+ source: The full source code.
618+ assertion_start: Start position of the assertion.
619+
620+ Returns:
621+ Tuple of (variable_type, variable_name) or (None, None).
622+
623+ """
624+ # Look backwards from assertion_start to beginning of line
625+ line_start = source .rfind ("\n " , 0 , assertion_start )
626+ if line_start == - 1 :
627+ line_start = 0
628+ else :
629+ line_start += 1
630+
631+ line_before_assert = source [line_start :assertion_start ]
632+
633+ # Pattern: Type varName = assertXxx(...)
634+ # Handle generic types: Type<Generic> varName = ...
635+ pattern = r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$"
636+ match = re .search (pattern , line_before_assert )
637+
638+ if match :
639+ var_type = match .group (1 ).strip ()
640+ var_name = match .group (2 ).strip ()
641+ return var_type , var_name
642+
643+ return None , None
644+
645+ def _extract_exception_class (self , args_content : str ) -> str | None :
646+ """Extract exception class from assertThrows arguments.
647+
648+ Args:
649+ args_content: Content inside assertThrows parentheses.
650+
651+ Returns:
652+ Exception class name (e.g., "IllegalArgumentException") or None.
653+
654+ Example:
655+ assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
656+
657+ """
658+ # First argument is the exception class reference (e.g., "IllegalArgumentException.class")
659+ # Split by comma, but respect nested parentheses and generics
660+ depth = 0
661+ current = []
662+ parts = []
663+
664+ for char in args_content :
665+ if char in "(<" :
666+ depth += 1
667+ current .append (char )
668+ elif char in ")>" :
669+ depth -= 1
670+ current .append (char )
671+ elif char == "," and depth == 0 :
672+ parts .append ("" .join (current ).strip ())
673+ current = []
674+ else :
675+ current .append (char )
676+
677+ if current :
678+ parts .append ("" .join (current ).strip ())
679+
680+ if parts :
681+ exception_arg = parts [0 ].strip ()
682+ # Remove .class suffix
683+ if exception_arg .endswith (".class" ):
684+ return exception_arg [:- 6 ].strip ()
685+
686+ return None
687+
583688 def _extract_lambda_body (self , content : str ) -> str | None :
584689 """Extract the body of a lambda expression from assertThrows arguments.
585690
@@ -745,29 +850,53 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
745850 To:
746851 try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
747852
853+ For variable assignments:
854+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
855+ To:
856+ IllegalArgumentException ex = null;
857+ try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
858+
748859 """
749860 self .invocation_counter += 1
750861
862+ # Extract code to run from lambda body or target calls
863+ code_to_run = None
751864 if assertion .lambda_body :
752- # Extract the actual code from the lambda
753865 code_to_run = assertion .lambda_body
754866 if not code_to_run .endswith (";" ):
755867 code_to_run += ";"
756- return (
757- f"{ assertion .leading_whitespace } try {{ { code_to_run } }} "
758- f"catch (Exception _cf_ignored{ self .invocation_counter } ) {{}}"
759- )
760-
761- # If no lambda body found, try to extract from target calls
762- if assertion .target_calls :
868+ elif assertion .target_calls :
763869 call = assertion .target_calls [0 ]
870+ code_to_run = call .full_call + ";"
871+
872+ if not code_to_run :
873+ # Fallback: comment out the assertion
874+ return f"{ assertion .leading_whitespace } // Removed assertThrows: could not extract callable"
875+
876+ # Check if assertion is assigned to a variable
877+ if assertion .variable_name and assertion .variable_type :
878+ # Generate proper exception capture with variable assignment
879+ exception_type = assertion .exception_class or assertion .variable_type
880+ var_name = assertion .variable_name
881+
882+ # Use a unique catch variable name to avoid conflicts
883+ catch_var = f"_cf_caught{ self .invocation_counter } "
884+
885+ # Get base indentation from leading whitespace (without newlines)
886+ base_indent = assertion .leading_whitespace .lstrip ("\n \r " )
887+
764888 return (
765- f"{ assertion .leading_whitespace } try {{ { call .full_call } ; }} "
889+ f"{ assertion .leading_whitespace } { assertion .variable_type } { var_name } = null;\n "
890+ f"{ base_indent } try {{ { code_to_run } }} "
891+ f"catch ({ exception_type } { catch_var } ) {{ { var_name } = { catch_var } ; }} "
766892 f"catch (Exception _cf_ignored{ self .invocation_counter } ) {{}}"
767893 )
768894
769- # Fallback: comment out the assertion
770- return f"{ assertion .leading_whitespace } // Removed assertThrows: could not extract callable"
895+ # No variable assignment, use simple try-catch
896+ return (
897+ f"{ assertion .leading_whitespace } try {{ { code_to_run } }} "
898+ f"catch (Exception _cf_ignored{ self .invocation_counter } ) {{}}"
899+ )
771900
772901
773902def transform_java_assertions (source : str , function_name : str , qualified_name : str | None = None ) -> str :
0 commit comments