@@ -198,6 +198,14 @@ def __init__(
198198 # Precompile regex to find next special character (quotes, parens, braces).
199199 self ._special_re = re .compile (r"[\"'{}()]" )
200200
201+ # Precompile literal/cast regexes to avoid recompilation on each literal check.
202+ self ._LONG_LITERAL_RE = re .compile (r"^-?\d+[lL]$" )
203+ self ._INT_LITERAL_RE = re .compile (r"^-?\d+$" )
204+ self ._DOUBLE_LITERAL_RE = re .compile (r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$" )
205+ self ._FLOAT_LITERAL_RE = re .compile (r"^-?\d+\.?\d*[fF]$" )
206+ self ._CHAR_LITERAL_RE = re .compile (r"^'.'$|^'\\.'$" )
207+ self ._cast_re = re .compile (r"^\((\w+)\)" )
208+
201209 def transform (self , source : str ) -> str :
202210 """Remove assertions from source code, preserving target function calls.
203211
@@ -894,6 +902,143 @@ def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | N
894902
895903 return code [open_brace_pos + 1 : pos - 1 ], pos
896904
905+ def _infer_return_type (self , assertion : AssertionMatch ) -> str :
906+ """Infer the Java return type from the assertion context.
907+
908+ For assertEquals(expected, actual) patterns, the expected literal determines the type.
909+ For assertTrue/assertFalse, the result is boolean.
910+ Falls back to Object when the type cannot be determined.
911+ """
912+ method = assertion .assertion_method
913+
914+ # assertTrue/assertFalse always deal with boolean values
915+ if method in {"assertTrue" , "assertFalse" }:
916+ return "boolean"
917+
918+ # assertNull/assertNotNull — keep Object (reference type)
919+ if method in {"assertNull" , "assertNotNull" }:
920+ return "Object"
921+
922+ # For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal
923+ if method in JUNIT5_VALUE_ASSERTIONS :
924+ return self ._infer_type_from_assertion_args (assertion .original_text , method )
925+
926+ # For fluent assertions (assertThat), type inference is harder — keep Object
927+ return "Object"
928+
929+ # Regex patterns for Java literal type inference
930+ _LONG_LITERAL_RE = re .compile (r"^-?\d+[lL]$" )
931+ _INT_LITERAL_RE = re .compile (r"^-?\d+$" )
932+ _DOUBLE_LITERAL_RE = re .compile (r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$" )
933+ _FLOAT_LITERAL_RE = re .compile (r"^-?\d+\.?\d*[fF]$" )
934+ _CHAR_LITERAL_RE = re .compile (r"^'.'$|^'\\.'$" )
935+
936+ def _infer_type_from_assertion_args (self , original_text : str , method : str ) -> str :
937+ """Infer the return type from assertEquals/assertNotEquals expected value."""
938+ # Extract the args portion from the assertion text
939+ # Pattern: assertXxx( args... )
940+ paren_idx = original_text .find ("(" )
941+ if paren_idx < 0 :
942+ return "Object"
943+
944+ args_str = original_text [paren_idx + 1 :]
945+ # Remove trailing ");", whitespace
946+ args_str = args_str .rstrip ()
947+ if args_str .endswith (");" ):
948+ args_str = args_str [:- 2 ]
949+ elif args_str .endswith (")" ):
950+ args_str = args_str [:- 1 ]
951+
952+ # Fast-path: only extract the first top-level argument instead of splitting all arguments.
953+ first_arg = self ._extract_first_arg (args_str )
954+ if not first_arg :
955+ return "Object"
956+
957+ expected = first_arg .strip ()
958+
959+ # JUnit 4 has assertEquals(String message, expected, actual) where the first arg is a message.
960+ # If the first arg is a string literal, check if there are 3+ args — if so, the real expected
961+ # value is the second argument, not the message string.
962+ if expected .startswith ('"' ) and method in ("assertEquals" , "assertNotEquals" ):
963+ all_args = self ._split_top_level_args (args_str )
964+ if len (all_args ) >= 3 :
965+ expected = all_args [1 ].strip ()
966+
967+ return self ._type_from_literal (expected )
968+
969+ def _type_from_literal (self , value : str ) -> str :
970+ """Determine the Java type of a literal value."""
971+ if value in ("true" , "false" ):
972+ return "boolean"
973+ if value == "null" :
974+ return "Object"
975+ if self ._FLOAT_LITERAL_RE .match (value ):
976+ return "float"
977+ if self ._DOUBLE_LITERAL_RE .match (value ):
978+ return "double"
979+ if self ._LONG_LITERAL_RE .match (value ):
980+ return "long"
981+ if self ._INT_LITERAL_RE .match (value ):
982+ return "int"
983+ if self ._CHAR_LITERAL_RE .match (value ):
984+ return "char"
985+ if value .startswith ('"' ):
986+ return "String"
987+ # Cast expression like (byte)0, (short)1
988+ cast_match = self ._cast_re .match (value )
989+ if cast_match :
990+ return cast_match .group (1 )
991+ return "Object"
992+
993+ def _split_top_level_args (self , args_str : str ) -> list [str ]:
994+ """Split assertion arguments at top-level commas, respecting parens/strings/generics."""
995+ # Fast-path: if there are no special delimiters that require parsing,
996+ # we can use a simple split which is much faster for common simple cases.
997+ if not self ._special_re .search (args_str ):
998+ # Preserve original behavior of returning a list with the single unstripped string
999+ # when there are no commas, otherwise split on commas.
1000+ if "," in args_str :
1001+ return args_str .split ("," )
1002+ return [args_str ]
1003+
1004+ args : list [str ] = []
1005+ depth = 0
1006+ current : list [str ] = []
1007+ i = 0
1008+ in_string = False
1009+ string_char = ""
1010+
1011+ while i < len (args_str ):
1012+ ch = args_str [i ]
1013+
1014+ if in_string :
1015+ current .append (ch )
1016+ if ch == "\\ " and i + 1 < len (args_str ):
1017+ i += 1
1018+ current .append (args_str [i ])
1019+ elif ch == string_char :
1020+ in_string = False
1021+ elif ch in ('"' , "'" ):
1022+ in_string = True
1023+ string_char = ch
1024+ current .append (ch )
1025+ elif ch in ("(" , "<" , "[" , "{" ):
1026+ depth += 1
1027+ current .append (ch )
1028+ elif ch in (")" , ">" , "]" , "}" ):
1029+ depth -= 1
1030+ current .append (ch )
1031+ elif ch == "," and depth == 0 :
1032+ args .append ("" .join (current ))
1033+ current = []
1034+ else :
1035+ current .append (ch )
1036+ i += 1
1037+
1038+ if current :
1039+ args .append ("" .join (current ))
1040+ return args
1041+
8971042 def _generate_replacement (self , assertion : AssertionMatch ) -> str :
8981043 """Generate replacement code for an assertion.
8991044
@@ -912,18 +1057,34 @@ def _generate_replacement(self, assertion: AssertionMatch) -> str:
9121057 if not assertion .target_calls :
9131058 return ""
9141059
1060+ # Infer the return type from assertion context to avoid Object→primitive cast errors
1061+ return_type = self ._infer_return_type (assertion )
1062+
9151063 # Generate capture statements for each target call
916- replacements = []
1064+ replacements : list [ str ] = []
9171065 # For the first replacement, use the full leading whitespace
9181066 # For subsequent ones, strip leading newlines to avoid extra blank lines
919- base_indent = assertion .leading_whitespace .lstrip ("\n \r " )
920- for i , call in enumerate (assertion .target_calls ):
921- self .invocation_counter += 1
922- var_name = f"_cf_result{ self .invocation_counter } "
923- if i == 0 :
924- replacements .append (f"{ assertion .leading_whitespace } Object { var_name } = { call .full_call } ;" )
925- else :
926- replacements .append (f"{ base_indent } Object { var_name } = { call .full_call } ;" )
1067+ leading_ws = assertion .leading_whitespace
1068+ base_indent = leading_ws .lstrip ("\n \r " )
1069+
1070+ # Use a local counter to minimize attribute write overhead in the loop.
1071+ inv = self .invocation_counter
1072+
1073+ calls = assertion .target_calls
1074+ # Handle first call explicitly to avoid a per-iteration branch
1075+ if calls :
1076+ inv += 1
1077+ var_name = "_cf_result" + str (inv )
1078+ replacements .append (f"{ leading_ws } { return_type } { var_name } = { calls [0 ].full_call } ;" )
1079+
1080+ # Handle remaining calls
1081+ for call in calls [1 :]:
1082+ inv += 1
1083+ var_name = "_cf_result" + str (inv )
1084+ replacements .append (f"{ base_indent } { return_type } { var_name } = { call .full_call } ;" )
1085+
1086+ # Write back the counter
1087+ self .invocation_counter = inv
9271088
9281089 return "\n " .join (replacements )
9291090
@@ -942,8 +1103,10 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
9421103 try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
9431104
9441105 """
945- self .invocation_counter += 1
946- counter = self .invocation_counter
1106+ # Increment invocation counter once for this exception handling
1107+ inv = self .invocation_counter + 1
1108+ self .invocation_counter = inv
1109+ counter = inv
9471110 ws = assertion .leading_whitespace
9481111 base_indent = ws .lstrip ("\n \r " )
9491112
@@ -982,6 +1145,58 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
9821145 # Fallback: comment out the assertion
9831146 return f"{ ws } // Removed assertThrows: could not extract callable"
9841147
1148+ def _extract_first_arg (self , args_str : str ) -> str | None :
1149+ """Extract the first top-level argument from args_str.
1150+
1151+ This is a lightweight alternative to splitting all top-level arguments;
1152+ it stops at the first top-level comma, respects nested delimiters and strings,
1153+ and avoids constructing the full argument list for better performance.
1154+ """
1155+ n = len (args_str )
1156+ i = 0
1157+
1158+ # skip leading whitespace
1159+ while i < n and args_str [i ].isspace ():
1160+ i += 1
1161+ if i >= n :
1162+ return None
1163+
1164+ depth = 0
1165+ in_string = False
1166+ string_char = ""
1167+ cur : list [str ] = []
1168+
1169+ while i < n :
1170+ ch = args_str [i ]
1171+
1172+ if in_string :
1173+ cur .append (ch )
1174+ if ch == "\\ " and i + 1 < n :
1175+ i += 1
1176+ cur .append (args_str [i ])
1177+ elif ch == string_char :
1178+ in_string = False
1179+ elif ch in ('"' , "'" ):
1180+ in_string = True
1181+ string_char = ch
1182+ cur .append (ch )
1183+ elif ch in ("(" , "<" , "[" , "{" ):
1184+ depth += 1
1185+ cur .append (ch )
1186+ elif ch in (")" , ">" , "]" , "}" ):
1187+ depth -= 1
1188+ cur .append (ch )
1189+ elif ch == "," and depth == 0 :
1190+ break
1191+ else :
1192+ cur .append (ch )
1193+ i += 1
1194+
1195+ # Trim trailing whitespace from the extracted argument
1196+ if not cur :
1197+ return None
1198+ return "" .join (cur ).rstrip ()
1199+
9851200
9861201def transform_java_assertions (source : str , function_name : str , qualified_name : str | None = None ) -> str :
9871202 """Transform Java test code by removing assertions and capturing function calls.
0 commit comments