@@ -166,9 +166,9 @@ class AssertionMatch:
166166 original_text : str = ""
167167 is_exception_assertion : bool = False
168168 lambda_body : str | None = None # For assertThrows lambda content
169- variable_type : str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170- variable_name : str | None = None # Name of assigned variable (e.g., "exception")
171- exception_class : str | None = None # Exception class from assertThrows args
169+ assigned_var_type : str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
170+ assigned_var_name : str | None = None # Name of assigned variable (e.g., "exception")
171+ exception_class : str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
172172
173173
174174class JavaAssertTransformer :
@@ -306,8 +306,11 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
306306 # - assertEquals (static import)
307307 # - Assert.assertEquals (JUnit 4)
308308 # - Assertions.assertEquals (JUnit 5)
309+ # - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
309310 all_assertions = "|" .join (JUNIT5_ALL_ASSERTIONS )
310- pattern = re .compile (rf"(\s*)((?:Assert(?:ions)?\.)?({ all_assertions } ))\s*\(" , re .MULTILINE )
311+ pattern = re .compile (
312+ rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({ all_assertions } ))\s*\(" , re .MULTILINE
313+ )
311314
312315 for match in pattern .finditer (source ):
313316 leading_ws = match .group (1 )
@@ -332,32 +335,41 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
332335 target_calls = self ._extract_target_calls (args_content , match .end ())
333336 is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
334337
335- # For assertThrows , extract the lambda body and exception class
338+ # For exception assertions , extract the lambda body
336339 lambda_body = None
337340 exception_class = None
338- if is_exception and assertion_method == "assertThrows" :
341+ if is_exception :
339342 lambda_body = self ._extract_lambda_body (args_content )
340- exception_class = self ._extract_exception_class (args_content )
343+ # Extract exception class specifically for assertThrows
344+ if assertion_method == "assertThrows" :
345+ exception_class = self ._extract_exception_class (args_content )
341346
342347 # Check if assertion is assigned to a variable
343- var_type , var_name = self ._detect_variable_assignment (source , start_pos )
344-
345- # If variable assignment detected, adjust start_pos to include the entire line
346- actual_start = start_pos
347- actual_leading_ws = leading_ws
348- if var_type :
349- # Find the start of the line (beginning of variable declaration)
350- line_start = source .rfind ("\n " , 0 , start_pos )
351- if line_start == - 1 :
352- line_start = 0
348+ # Detect variable assignment: Type var = assertXxx(...)
349+ # This applies to all assertions (assertThrows, assertTimeout, etc.)
350+ assigned_var_type = None
351+ assigned_var_name = None
352+ original_text = source [start_pos :end_pos ]
353+
354+ before = source [:start_pos ]
355+ last_nl_idx = before .rfind ("\n " )
356+ if last_nl_idx >= 0 :
357+ line_prefix = source [last_nl_idx + 1 : start_pos ]
358+ else :
359+ line_prefix = source [:start_pos ]
360+
361+ var_match = re .match (r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$" , line_prefix )
362+ if var_match :
363+ if last_nl_idx >= 0 :
364+ start_pos = last_nl_idx
365+ leading_ws = "\n " + var_match .group (1 )
353366 else :
354- line_start += 1
355- actual_start = line_start
356- # Extract the actual leading whitespace from the start of the line
357- line_content = source [line_start :start_pos ]
358- actual_leading_ws = line_content [:len (line_content ) - len (line_content .lstrip ())]
367+ start_pos = 0
368+ leading_ws = var_match .group (1 )
359369
360- original_text = source [actual_start :end_pos ]
370+ assigned_var_type = var_match .group (2 )
371+ assigned_var_name = var_match .group (3 )
372+ original_text = source [start_pos :end_pos ] # Update with adjusted range
361373
362374 # Determine statement type based on detected framework
363375 detected = self ._detected_framework or "junit5"
@@ -368,17 +380,17 @@ def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
368380
369381 assertions .append (
370382 AssertionMatch (
371- start_pos = actual_start ,
383+ start_pos = start_pos ,
372384 end_pos = end_pos ,
373385 statement_type = stmt_type ,
374386 assertion_method = assertion_method ,
375387 target_calls = target_calls ,
376- leading_whitespace = actual_leading_ws ,
388+ leading_whitespace = leading_ws ,
377389 original_text = original_text ,
378390 is_exception_assertion = is_exception ,
379391 lambda_body = lambda_body ,
380- variable_type = var_type ,
381- variable_name = var_name ,
392+ assigned_var_type = assigned_var_type ,
393+ assigned_var_name = assigned_var_name ,
382394 exception_class = exception_class ,
383395 )
384396 )
@@ -709,9 +721,9 @@ def _extract_lambda_body(self, content: str) -> str | None:
709721 return brace_content .strip ()
710722 else :
711723 # Expression lambda: () -> expr
712- # Find the end (before the closing paren of assertThrows)
724+ # Find the end (before the closing paren of assertThrows, or comma at depth 0 )
713725 depth = 0
714- end = body_start
726+ end = len ( content )
715727 for i , ch in enumerate (content [body_start :]):
716728 if ch == "(" :
717729 depth += 1
@@ -720,6 +732,9 @@ def _extract_lambda_body(self, content: str) -> str | None:
720732 end = body_start + i
721733 break
722734 depth -= 1
735+ elif ch == "," and depth == 0 :
736+ end = body_start + i
737+ break
723738 return content [body_start :end ].strip ()
724739
725740 return None
@@ -851,14 +866,17 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
851866 To:
852867 try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
853868
854- For variable assignments :
869+ When assigned to a variable :
855870 IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
856871 To:
857872 IllegalArgumentException ex = null;
858- try { code(); } catch (IllegalArgumentException e ) { ex = e ; } catch (Exception _cf_ignored1) {}
873+ try { code(); } catch (IllegalArgumentException _cf_caught1 ) { ex = _cf_caught1 ; } catch (Exception _cf_ignored1) {}
859874
860875 """
861876 self .invocation_counter += 1
877+ counter = self .invocation_counter
878+ ws = assertion .leading_whitespace
879+ base_indent = ws .lstrip ("\n \r " )
862880
863881 # Extract code to run from lambda body or target calls
864882 code_to_run = None
@@ -867,38 +885,39 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
867885 # Use a direct last-character check instead of .endswith for lower overhead
868886 if code_to_run and code_to_run [- 1 ] != ";" :
869887 code_to_run += ";"
870- elif assertion .target_calls :
871- call = assertion .target_calls [0 ]
872- code_to_run = call .full_call + ";"
873-
874- if not code_to_run :
875- # Fallback: comment out the assertion
876- return f"{ assertion .leading_whitespace } // Removed assertThrows: could not extract callable"
877888
878- # Check if assertion is assigned to a variable
879- if assertion .variable_name and assertion .variable_type :
880- # Generate proper exception capture with variable assignment
881- exception_type = assertion .exception_class or assertion .variable_type
882- var_name = assertion .variable_name
883-
884- # Use a unique catch variable name to avoid conflicts
885- catch_var = f"_cf_caught{ self .invocation_counter } "
889+ # Handle variable assignment: Type var = assertThrows(...)
890+ if assertion .assigned_var_name and assertion .assigned_var_type :
891+ var_type = assertion .assigned_var_type
892+ var_name = assertion .assigned_var_name
893+ if assertion .assertion_method == "assertDoesNotThrow" :
894+ if ";" not in assertion .lambda_body .strip ():
895+ return f"{ ws } { var_type } { var_name } = { assertion .lambda_body .strip ()} ;"
896+ return f"{ ws } { code_to_run } "
897+ # For assertThrows with variable assignment, use exception_class if available
898+ exception_type = assertion .exception_class or var_type
899+ return (
900+ f"{ ws } { var_type } { var_name } = null;\n "
901+ f"{ base_indent } try {{ { code_to_run } }} "
902+ f"catch ({ exception_type } _cf_caught{ counter } ) {{ { var_name } = _cf_caught{ counter } ; }} "
903+ f"catch (Exception _cf_ignored{ counter } ) {{}}"
904+ )
886905
887- # Get base indentation from leading whitespace (without newlines)
888- base_indent = assertion .leading_whitespace .lstrip ("\n \r " )
906+ return (
907+ f"{ ws } try {{ { code_to_run } }} "
908+ f"catch (Exception _cf_ignored{ counter } ) {{}}"
909+ )
889910
911+ # If no lambda body found, try to extract from target calls
912+ if assertion .target_calls :
913+ call = assertion .target_calls [0 ]
890914 return (
891- f"{ assertion .leading_whitespace } { assertion .variable_type } { var_name } = null;\n "
892- f"{ base_indent } try {{ { code_to_run } }} "
893- f"catch ({ exception_type } { catch_var } ) {{ { var_name } = { catch_var } ; }} "
894- f"catch (Exception _cf_ignored{ self .invocation_counter } ) {{}}"
915+ f"{ ws } try {{ { call .full_call } ; }} "
916+ f"catch (Exception _cf_ignored{ counter } ) {{}}"
895917 )
896918
897- # No variable assignment, use simple try-catch
898- return (
899- f"{ assertion .leading_whitespace } try {{ { code_to_run } }} "
900- f"catch (Exception _cf_ignored{ self .invocation_counter } ) {{}}"
901- )
919+ # Fallback: comment out the assertion
920+ return f"{ ws } // Removed assertThrows: could not extract callable"
902921
903922
904923def transform_java_assertions (source : str , function_name : str , qualified_name : str | None = None ) -> str :
0 commit comments