diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index a881aa208..ee7700f5e 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -1325,32 +1325,3 @@ def instrument_generated_java_test( logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode) return modified_code - - -def _add_import(source: str, import_statement: str) -> str: - """Add an import statement to the source. - - Args: - source: The source code. - import_statement: The import to add. - - Returns: - Source with import added. - - """ - lines = source.splitlines(keepends=True) - insert_idx = 0 - - # Find the last import or package statement - for i, line in enumerate(lines): - stripped = line.strip() - if stripped.startswith(("import ", "package ")): - insert_idx = i + 1 - elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"): - # First non-import, non-comment line - if insert_idx == 0: - insert_idx = i - break - - lines.insert(insert_idx, import_statement + "\n") - return "".join(lines) diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index 8544a771d..470e0d62e 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -198,6 +198,14 @@ def __init__( # Precompile regex to find next special character (quotes, parens, braces). self._special_re = re.compile(r"[\"'{}()]") + # Precompile literal/cast regexes to avoid recompilation on each literal check. + self._LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$") + self._INT_LITERAL_RE = re.compile(r"^-?\d+$") + self._DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$") + self._FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$") + self._CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$") + self._cast_re = re.compile(r"^\((\w+)\)") + def transform(self, source: str) -> str: """Remove assertions from source code, preserving target function calls. @@ -894,6 +902,143 @@ def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | N return code[open_brace_pos + 1 : pos - 1], pos + def _infer_return_type(self, assertion: AssertionMatch) -> str: + """Infer the Java return type from the assertion context. + + For assertEquals(expected, actual) patterns, the expected literal determines the type. + For assertTrue/assertFalse, the result is boolean. + Falls back to Object when the type cannot be determined. + """ + method = assertion.assertion_method + + # assertTrue/assertFalse always deal with boolean values + if method in {"assertTrue", "assertFalse"}: + return "boolean" + + # assertNull/assertNotNull — keep Object (reference type) + if method in {"assertNull", "assertNotNull"}: + return "Object" + + # For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal + if method in JUNIT5_VALUE_ASSERTIONS: + return self._infer_type_from_assertion_args(assertion.original_text, method) + + # For fluent assertions (assertThat), type inference is harder — keep Object + return "Object" + + # Regex patterns for Java literal type inference + _LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$") + _INT_LITERAL_RE = re.compile(r"^-?\d+$") + _DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$") + _FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$") + _CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$") + + def _infer_type_from_assertion_args(self, original_text: str, method: str) -> str: + """Infer the return type from assertEquals/assertNotEquals expected value.""" + # Extract the args portion from the assertion text + # Pattern: assertXxx( args... ) + paren_idx = original_text.find("(") + if paren_idx < 0: + return "Object" + + args_str = original_text[paren_idx + 1 :] + # Remove trailing ");", whitespace + args_str = args_str.rstrip() + if args_str.endswith(");"): + args_str = args_str[:-2] + elif args_str.endswith(")"): + args_str = args_str[:-1] + + # Fast-path: only extract the first top-level argument instead of splitting all arguments. + first_arg = self._extract_first_arg(args_str) + if not first_arg: + return "Object" + + expected = first_arg.strip() + + # JUnit 4 has assertEquals(String message, expected, actual) where the first arg is a message. + # If the first arg is a string literal, check if there are 3+ args — if so, the real expected + # value is the second argument, not the message string. + if expected.startswith('"') and method in ("assertEquals", "assertNotEquals"): + all_args = self._split_top_level_args(args_str) + if len(all_args) >= 3: + expected = all_args[1].strip() + + return self._type_from_literal(expected) + + def _type_from_literal(self, value: str) -> str: + """Determine the Java type of a literal value.""" + if value in ("true", "false"): + return "boolean" + if value == "null": + return "Object" + if self._FLOAT_LITERAL_RE.match(value): + return "float" + if self._DOUBLE_LITERAL_RE.match(value): + return "double" + if self._LONG_LITERAL_RE.match(value): + return "long" + if self._INT_LITERAL_RE.match(value): + return "int" + if self._CHAR_LITERAL_RE.match(value): + return "char" + if value.startswith('"'): + return "String" + # Cast expression like (byte)0, (short)1 + cast_match = self._cast_re.match(value) + if cast_match: + return cast_match.group(1) + return "Object" + + def _split_top_level_args(self, args_str: str) -> list[str]: + """Split assertion arguments at top-level commas, respecting parens/strings/generics.""" + # Fast-path: if there are no special delimiters that require parsing, + # we can use a simple split which is much faster for common simple cases. + if not self._special_re.search(args_str): + # Preserve original behavior of returning a list with the single unstripped string + # when there are no commas, otherwise split on commas. + if "," in args_str: + return args_str.split(",") + return [args_str] + + args: list[str] = [] + depth = 0 + current: list[str] = [] + i = 0 + in_string = False + string_char = "" + + while i < len(args_str): + ch = args_str[i] + + if in_string: + current.append(ch) + if ch == "\\" and i + 1 < len(args_str): + i += 1 + current.append(args_str[i]) + elif ch == string_char: + in_string = False + elif ch in ('"', "'"): + in_string = True + string_char = ch + current.append(ch) + elif ch in ("(", "<", "[", "{"): + depth += 1 + current.append(ch) + elif ch in (")", ">", "]", "}"): + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + args.append("".join(current)) + current = [] + else: + current.append(ch) + i += 1 + + if current: + args.append("".join(current)) + return args + def _generate_replacement(self, assertion: AssertionMatch) -> str: """Generate replacement code for an assertion. @@ -912,18 +1057,34 @@ def _generate_replacement(self, assertion: AssertionMatch) -> str: if not assertion.target_calls: return "" + # Infer the return type from assertion context to avoid Object→primitive cast errors + return_type = self._infer_return_type(assertion) + # Generate capture statements for each target call - replacements = [] + replacements: list[str] = [] # For the first replacement, use the full leading whitespace # For subsequent ones, strip leading newlines to avoid extra blank lines - base_indent = assertion.leading_whitespace.lstrip("\n\r") - for i, call in enumerate(assertion.target_calls): - self.invocation_counter += 1 - var_name = f"_cf_result{self.invocation_counter}" - if i == 0: - replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};") - else: - replacements.append(f"{base_indent}Object {var_name} = {call.full_call};") + leading_ws = assertion.leading_whitespace + base_indent = leading_ws.lstrip("\n\r") + + # Use a local counter to minimize attribute write overhead in the loop. + inv = self.invocation_counter + + calls = assertion.target_calls + # Handle first call explicitly to avoid a per-iteration branch + if calls: + inv += 1 + var_name = "_cf_result" + str(inv) + replacements.append(f"{leading_ws}{return_type} {var_name} = {calls[0].full_call};") + + # Handle remaining calls + for call in calls[1:]: + inv += 1 + var_name = "_cf_result" + str(inv) + replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};") + + # Write back the counter + self.invocation_counter = inv return "\n".join(replacements) @@ -942,8 +1103,10 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str: try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {} """ - self.invocation_counter += 1 - counter = self.invocation_counter + # Increment invocation counter once for this exception handling + inv = self.invocation_counter + 1 + self.invocation_counter = inv + counter = inv ws = assertion.leading_whitespace base_indent = ws.lstrip("\n\r") @@ -982,6 +1145,58 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str: # Fallback: comment out the assertion return f"{ws}// Removed assertThrows: could not extract callable" + def _extract_first_arg(self, args_str: str) -> str | None: + """Extract the first top-level argument from args_str. + + This is a lightweight alternative to splitting all top-level arguments; + it stops at the first top-level comma, respects nested delimiters and strings, + and avoids constructing the full argument list for better performance. + """ + n = len(args_str) + i = 0 + + # skip leading whitespace + while i < n and args_str[i].isspace(): + i += 1 + if i >= n: + return None + + depth = 0 + in_string = False + string_char = "" + cur: list[str] = [] + + while i < n: + ch = args_str[i] + + if in_string: + cur.append(ch) + if ch == "\\" and i + 1 < n: + i += 1 + cur.append(args_str[i]) + elif ch == string_char: + in_string = False + elif ch in ('"', "'"): + in_string = True + string_char = ch + cur.append(ch) + elif ch in ("(", "<", "[", "{"): + depth += 1 + cur.append(ch) + elif ch in (")", ">", "]", "}"): + depth -= 1 + cur.append(ch) + elif ch == "," and depth == 0: + break + else: + cur.append(ch) + i += 1 + + # Trim trailing whitespace from the extracted argument + if not cur: + return None + return "".join(cur).rstrip() + def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str: """Transform Java test code by removing assertions and capturing function calls. diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index a40150432..fd01d2623 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -47,6 +47,17 @@ # Allows: letters, digits, underscores, dots, and dollar signs (inner classes) _VALID_JAVA_CLASS_NAME = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$.]*$") +# Skip validation/analysis plugins that reject generated instrumented files +# (e.g. Apache Rat rejects missing license headers, Checkstyle rejects naming, etc.) +_MAVEN_VALIDATION_SKIP_FLAGS = [ + "-Drat.skip=true", + "-Dcheckstyle.skip=true", + "-Dspotbugs.skip=true", + "-Dpmd.skip=true", + "-Denforcer.skip=true", + "-Djapicmp.skip=true", +] + def _run_cmd_kill_pg_on_timeout( cmd: list[str], @@ -85,9 +96,7 @@ def _run_cmd_kill_pg_on_timeout( # Windows does not have POSIX process groups / killpg. Fall back to # the standard subprocess.run() behaviour (kills parent only). try: - return subprocess.run( - cmd, cwd=cwd, env=env, capture_output=True, text=text, timeout=timeout, check=False - ) + return subprocess.run(cmd, cwd=cwd, env=env, capture_output=True, text=text, timeout=timeout, check=False) except subprocess.TimeoutExpired: return subprocess.CompletedProcess( args=cmd, returncode=-2, stdout="", stderr=f"Process timed out after {timeout}s" @@ -509,8 +518,9 @@ def run_behavioral_tests( add_jacoco_plugin_to_pom(pom_path) coverage_xml_path = get_jacoco_xml_path(project_root) - # Use a minimum timeout of 60s for Java builds (120s when coverage is enabled due to verify phase) - min_timeout = 120 if enable_coverage else 60 + # Use a minimum timeout of 60s for Java builds (300s when coverage is enabled due to verify phase + # which runs full compilation + instrumentation + test execution in multi-module projects) + min_timeout = 300 if enable_coverage else 60 effective_timeout = max(timeout or 300, min_timeout) if enable_coverage: @@ -591,6 +601,7 @@ def _compile_tests( return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found") cmd = [mvn, "test-compile", "-e", "-B"] # Show errors but not verbose output; -B for batch mode (no ANSI colors) + cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS) if test_module: cmd.extend(["-pl", test_module, "-am"]) @@ -1526,6 +1537,7 @@ def _run_maven_tests( # JaCoCo's report goal is bound to the verify phase to get post-test execution data maven_goal = "verify" if enable_coverage else "test" cmd = [mvn, maven_goal, "-fae", "-B"] # Fail at end to run all tests; -B for batch mode (no ANSI colors) + cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS) # Add --add-opens flags for Java 16+ module system compatibility. # The codeflash-runtime Serializer uses Kryo which needs reflective access to @@ -1562,7 +1574,16 @@ def _run_maven_tests( # -am = also make dependencies # -DfailIfNoTests=false allows dependency modules without tests to pass # -DskipTests=false overrides any skipTests=true in pom.xml - cmd.extend(["-pl", test_module, "-am", "-DfailIfNoTests=false", "-DskipTests=false"]) + cmd.extend( + [ + "-pl", + test_module, + "-am", + "-DfailIfNoTests=false", + "-Dsurefire.failIfNoSpecifiedTests=false", + "-DskipTests=false", + ] + ) if test_filter: # Validate test filter to prevent command injection diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index d44a347fc..dedc814dd 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -786,7 +786,13 @@ def parse_test_xml( if class_name is not None and class_name.startswith(test_module_path): test_class = class_name[len(test_module_path) + 1 :] # +1 for the dot, gets Unittest class name - loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1 + loop_index = 1 + if testcase.name and "[" in testcase.name: + bracket_content = testcase.name.rsplit("[", 1)[-1].rstrip("]").strip() + try: + loop_index = int(bracket_content) + except ValueError: + loop_index = 1 timed_out = False if len(testcase.result) > 1: diff --git a/tests/test_java_assertion_removal.py b/tests/test_java_assertion_removal.py index 7b991db99..ec37e7c27 100644 --- a/tests/test_java_assertion_removal.py +++ b/tests/test_java_assertion_removal.py @@ -24,7 +24,7 @@ def test_assert_equals_basic(self): expected = """\ @Test void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -38,7 +38,7 @@ def test_assert_equals_with_message(self): expected = """\ @Test void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -52,7 +52,7 @@ def test_assert_true(self): expected = """\ @Test void testIsValid() { - Object _cf_result1 = validator.isValid("test"); + boolean _cf_result1 = validator.isValid("test"); }""" result = transform_java_assertions(source, "isValid") assert result == expected @@ -66,7 +66,7 @@ def test_assert_false(self): expected = """\ @Test void testIsInvalid() { - Object _cf_result1 = validator.isValid(""); + boolean _cf_result1 = validator.isValid(""); }""" result = transform_java_assertions(source, "isValid") assert result == expected @@ -108,7 +108,7 @@ def test_assert_not_equals(self): expected = """\ @Test void testDifferent() { - Object _cf_result1 = calculator.add(1, 2); + int _cf_result1 = calculator.add(1, 2); }""" result = transform_java_assertions(source, "add") assert result == expected @@ -154,7 +154,7 @@ def test_assertions_prefix(self): expected = """\ @Test void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -168,7 +168,7 @@ def test_assert_prefix(self): expected = """\ @Test void testAdd() { - Object _cf_result1 = calculator.add(2, 3); + int _cf_result1 = calculator.add(2, 3); }""" result = transform_java_assertions(source, "add") assert result == expected @@ -234,7 +234,7 @@ def test_static_method_call(self): expected = """\ @Test void testQuickAdd() { - Object _cf_result1 = Calculator.quickAdd(10.0, 5.0); + double _cf_result1 = Calculator.quickAdd(10.0, 5.0); }""" result = transform_java_assertions(source, "quickAdd") assert result == expected @@ -248,7 +248,7 @@ def test_static_method_fully_qualified(self): expected = """\ @Test void testReverse() { - Object _cf_result1 = com.example.StringUtils.reverse("hello"); + String _cf_result1 = com.example.StringUtils.reverse("hello"); }""" result = transform_java_assertions(source, "reverse") assert result == expected @@ -268,9 +268,9 @@ def test_multiple_assertions_same_function(self): expected = """\ @Test void testFibonacciSequence() { - Object _cf_result1 = calculator.fibonacci(0); - Object _cf_result2 = calculator.fibonacci(1); - Object _cf_result3 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(0); + int _cf_result2 = calculator.fibonacci(1); + int _cf_result3 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -285,7 +285,7 @@ def test_multiple_assertions_different_functions(self): expected = """\ @Test void testCalculator() { - Object _cf_result1 = calculator.add(2, 3); + int _cf_result1 = calculator.add(2, 3); }""" result = transform_java_assertions(source, "add") assert result == expected @@ -421,7 +421,7 @@ def test_target_nested_in_non_target_call(self): expected = """\ @Test void testSubtract() { - Object _cf_result1 = add(2, subtract(2, 2)); + int _cf_result1 = add(2, subtract(2, 2)); }""" result = transform_java_assertions(source, "subtract") assert result == expected @@ -435,7 +435,7 @@ def test_non_target_nested_in_target_call(self): expected = """\ @Test void testAdd() { - Object _cf_result1 = subtract(2, add(2, 3)); + int _cf_result1 = subtract(2, add(2, 3)); }""" result = transform_java_assertions(source, "add") assert result == expected @@ -449,7 +449,7 @@ def test_multiple_targets_nested_in_same_outer_call(self): expected = """\ @Test void testOuter() { - Object _cf_result1 = outer(subtract(1, 1), subtract(2, 2)); + int _cf_result1 = outer(subtract(1, 1), subtract(2, 2)); }""" result = transform_java_assertions(source, "subtract") assert result == expected @@ -467,7 +467,7 @@ def test_preserves_indentation(self): expected = """\ @Test void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -510,7 +510,7 @@ def test_string_with_parentheses(self): expected = """\ @Test void testFormat() { - Object _cf_result1 = formatter.format("hello", "world"); + String _cf_result1 = formatter.format("hello", "world"); }""" result = transform_java_assertions(source, "format") assert result == expected @@ -524,7 +524,7 @@ def test_string_with_quotes(self): expected = """\ @Test void testEscape() { - Object _cf_result1 = formatter.escape("hello \\"world\\""); + String _cf_result1 = formatter.escape("hello \\"world\\""); }""" result = transform_java_assertions(source, "escape") assert result == expected @@ -538,7 +538,7 @@ def test_string_with_newlines(self): expected = """\ @Test void testMultiline() { - Object _cf_result1 = processor.join("line1", "line2"); + String _cf_result1 = processor.join("line1", "line2"); }""" result = transform_java_assertions(source, "join") assert result == expected @@ -560,7 +560,7 @@ def test_setup_code_preserved(self): void testWithSetup() { Calculator calc = new Calculator(2); int input = 10; - Object _cf_result1 = calc.fibonacci(input); + int _cf_result1 = calc.fibonacci(input); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -577,7 +577,7 @@ def test_other_method_calls_preserved(self): @Test void testWithHelper() { helper.setup(); - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); helper.cleanup(); }""" result = transform_java_assertions(source, "fibonacci") @@ -649,7 +649,7 @@ class FibonacciTests { class FibonacciTests { @Test void testBasic() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); } }""" result = transform_java_assertions(source, "fibonacci") @@ -670,7 +670,7 @@ def test_mockito_when_preserved(self): @Test void testWithMock() { when(mockService.getData()).thenReturn("test"); - Object _cf_result1 = processor.process(mockService); + String _cf_result1 = processor.process(mockService); }""" result = transform_java_assertions(source, "process") assert result == expected @@ -733,7 +733,7 @@ def test_function_name_in_string(self): expected = """\ @Test void testWithStringContainingFunctionName() { - Object _cf_result1 = formatter.format("fibonacci", 10, 55); + String _cf_result1 = formatter.format("fibonacci", 10, 55); }""" result = transform_java_assertions(source, "format") assert result == expected @@ -755,7 +755,7 @@ def test_junit4_assert_equals(self): @Test public void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -769,11 +769,58 @@ def test_junit4_with_message_first(self): expected = """\ @Test public void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected + def test_junit4_message_first_with_string_expected(self): + """When assertEquals has 3 args and the first is a message but the second is also a string, + the type should be inferred from the second arg (the real expected value), not the message.""" + source = """\ +@Test +public void testGetName() { + assertEquals("Name should match", "Alice", user.getName()); +}""" + expected = """\ +@Test +public void testGetName() { + String _cf_result1 = user.getName(); +}""" + result = transform_java_assertions(source, "getName") + assert result == expected + + def test_junit4_message_first_with_boolean_expected(self): + """JUnit 4 assertEquals with message, boolean expected, and actual.""" + source = """\ +@Test +public void testIsValid() { + assertEquals("Should be true", true, validator.isValid(input)); +}""" + expected = """\ +@Test +public void testIsValid() { + boolean _cf_result1 = validator.isValid(input); +}""" + result = transform_java_assertions(source, "isValid") + assert result == expected + + def test_two_arg_string_expected_not_treated_as_message(self): + """When assertEquals has only 2 args and the first is a string, it IS the expected value, + not a message. This tests that we don't incorrectly skip the first arg.""" + source = """\ +@Test +public void testGetGreeting() { + assertEquals("hello", greeter.getGreeting()); +}""" + expected = """\ +@Test +public void testGetGreeting() { + String _cf_result1 = greeter.getGreeting(); +}""" + result = transform_java_assertions(source, "getGreeting") + assert result == expected + class TestAssertAll: """Tests for assertAll grouped assertions.""" @@ -813,8 +860,8 @@ def test_invocation_counter_increments(self): expected = """\ @Test void test() { - Object _cf_result1 = calc.fibonacci(0); - Object _cf_result2 = calc.fibonacci(1); + int _cf_result1 = calc.fibonacci(0); + int _cf_result2 = calc.fibonacci(1); }""" result = transformer.transform(source) assert result == expected @@ -983,8 +1030,8 @@ def test_string_utils_pattern(self): @Test @DisplayName("should reverse a simple string") void testReverseSimple() { - Object _cf_result1 = StringUtils.reverse("hello"); - Object _cf_result2 = StringUtils.reverse("world"); + String _cf_result1 = StringUtils.reverse("hello"); + String _cf_result2 = StringUtils.reverse("world"); }""" result = transform_java_assertions(source, "reverse") assert result == expected @@ -1012,7 +1059,7 @@ def test_with_before_each_setup(self): @Test void testFibonacci() { - Object _cf_result1 = calculator.fibonacci(10); + int _cf_result1 = calculator.fibonacci(10); }""" result = transform_java_assertions(source, "fibonacci") assert result == expected @@ -1039,7 +1086,7 @@ def test_synchronized_method_assertion_removal(self): @Test void testSynchronizedAccess() { synchronized (lock) { - Object _cf_result1 = counter.incrementAndGet(); + int _cf_result1 = counter.incrementAndGet(); } }""" result = transform_java_assertions(source, "incrementAndGet") @@ -1055,7 +1102,7 @@ def test_volatile_field_read_preserved(self): expected = """\ @Test void testVolatileRead() { - Object _cf_result1 = buffer.isReady(); + boolean _cf_result1 = buffer.isReady(); }""" result = transform_java_assertions(source, "isReady") assert result == expected @@ -1076,7 +1123,7 @@ def test_synchronized_block_with_multiple_assertions(self): @Test void testSynchronizedBlock() { synchronized (cache) { - Object _cf_result1 = cache.size(); + int _cf_result1 = cache.size(); } }""" result = transform_java_assertions(source, "size") @@ -1114,8 +1161,8 @@ def test_atomic_operations_preserved(self): expected = """\ @Test void testAtomicCounter() { - Object _cf_result1 = counter.incrementAndGet(); - Object _cf_result2 = counter.incrementAndGet(); + int _cf_result1 = counter.incrementAndGet(); + int _cf_result2 = counter.incrementAndGet(); }""" result = transform_java_assertions(source, "incrementAndGet") assert result == expected @@ -1130,7 +1177,7 @@ def test_concurrent_collection_assertion(self): expected = """\ @Test void testConcurrentMap() { - Object _cf_result1 = concurrentMap.putIfAbsent("key", "value"); + String _cf_result1 = concurrentMap.putIfAbsent("key", "value"); }""" result = transform_java_assertions(source, "putIfAbsent") assert result == expected @@ -1147,7 +1194,7 @@ def test_thread_sleep_with_assertion(self): @Test void testWithThreadSleep() throws InterruptedException { Thread.sleep(100); - Object _cf_result1 = processor.getResult(); + int _cf_result1 = processor.getResult(); }""" result = transform_java_assertions(source, "getResult") assert result == expected @@ -1162,7 +1209,7 @@ def test_synchronized_method_signature_preserved(self): expected = """\ @Test synchronized void testSyncMethod() { - Object _cf_result1 = calculator.compute(5); + int _cf_result1 = calculator.compute(5); }""" result = transform_java_assertions(source, "compute") assert result == expected @@ -1183,7 +1230,7 @@ def test_wait_notify_pattern_preserved(self): synchronized (monitor) { monitor.notify(); } - Object _cf_result1 = listener.wasNotified(); + boolean _cf_result1 = listener.wasNotified(); }""" result = transform_java_assertions(source, "wasNotified") assert result == expected @@ -1205,7 +1252,7 @@ def test_reentrant_lock_pattern_preserved(self): void testReentrantLock() { lock.lock(); try { - Object _cf_result1 = sharedResource.getValue(); + int _cf_result1 = sharedResource.getValue(); } finally { lock.unlock(); } @@ -1227,7 +1274,7 @@ def test_count_down_latch_pattern_preserved(self): void testCountDownLatch() throws InterruptedException { latch.countDown(); latch.await(); - Object _cf_result1 = collector.getTotal(); + int _cf_result1 = collector.getTotal(); }""" result = transform_java_assertions(source, "getTotal") assert result == expected @@ -1245,8 +1292,8 @@ def test_token_bucket_synchronized_method(self): @Test void testTokenBucketAllowRequest() { TokenBucket bucket = new TokenBucket(10, 1); - Object _cf_result1 = bucket.allowRequest(); - Object _cf_result2 = bucket.allowRequest(); + boolean _cf_result1 = bucket.allowRequest(); + boolean _cf_result2 = bucket.allowRequest(); }""" result = transform_java_assertions(source, "allowRequest") assert result == expected @@ -1268,9 +1315,9 @@ def test_circular_buffer_atomic_integer_pattern(self): @Test void testCircularBufferOperations() { CircularBuffer buffer = new CircularBuffer<>(3); - Object _cf_result1 = buffer.isEmpty(); + boolean _cf_result1 = buffer.isEmpty(); buffer.put(1); - Object _cf_result2 = buffer.isEmpty(); + boolean _cf_result2 = buffer.isEmpty(); }""" result = transform_java_assertions(source, "isEmpty") assert result == expected @@ -1328,7 +1375,7 @@ def test_assert_equals_fully_qualified(self): expected = """\ @Test void testAdd() { - Object _cf_result1 = calc.add(2, 3); + int _cf_result1 = calc.add(2, 3); }""" result = transform_java_assertions(source, "add") assert result == expected diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index e0d6de086..a7e1e769f 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -1324,7 +1324,7 @@ def test_instrument_generated_test_behavior_mode(self): } } } - Object _cf_result1 = _cf_result1_1; + int _cf_result1 = (int)_cf_result1_1; } } """ diff --git a/tests/test_languages/test_java/test_remove_asserts.py b/tests/test_languages/test_java/test_remove_asserts.py index e0a252ad8..edc7138ce 100644 --- a/tests/test_languages/test_java/test_remove_asserts.py +++ b/tests/test_languages/test_java/test_remove_asserts.py @@ -41,7 +41,7 @@ def test_assertfalse_with_message(self): public class BitSetTest { @Test public void testGet_IndexZero_ReturnsFalse() { - Object _cf_result1 = instance.get(0); + boolean _cf_result1 = instance.get(0); } } """ @@ -67,7 +67,7 @@ def test_asserttrue_with_message(self): public class BitSetTest { @Test public void testGet_SetBit_DetectedTrue() { - Object _cf_result1 = bs.get(67); + boolean _cf_result1 = bs.get(67); } } """ @@ -93,7 +93,7 @@ def test_assertequals_with_static_call(self): public class FibonacciTest { @Test public void testFibonacci() { - Object _cf_result1 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(10); } } """ @@ -121,7 +121,7 @@ def test_assertequals_with_instance_call(self): @Test public void testAdd() { Calculator calc = new Calculator(); - Object _cf_result1 = calc.add(2, 2); + int _cf_result1 = calc.add(2, 2); } } """ @@ -199,7 +199,7 @@ def test_assertnotequals(self): public class CalculatorTest { @Test public void testSubtract() { - Object _cf_result1 = calc.subtract(5, 3); + int _cf_result1 = calc.subtract(5, 3); } } """ @@ -251,7 +251,7 @@ def test_qualified_assert_call(self): public class CalculatorTest { @Test public void testAdd() { - Object _cf_result1 = calc.add(2, 2); + int _cf_result1 = calc.add(2, 2); } } """ @@ -298,9 +298,9 @@ def test_assertequals_static_import(self): public class FibonacciTest { @Test void testFibonacci() { - Object _cf_result1 = Fibonacci.fibonacci(0); - Object _cf_result2 = Fibonacci.fibonacci(1); - Object _cf_result3 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result2 = Fibonacci.fibonacci(1); + int _cf_result3 = Fibonacci.fibonacci(10); } } """ @@ -326,7 +326,7 @@ def test_assertequals_qualified(self): public class FibonacciTest { @Test void testFibonacci() { - Object _cf_result1 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(10); } } """ @@ -485,7 +485,7 @@ def test_asserttrue_boolean_call(self): public class FibonacciTest { @Test void testIsFibonacci() { - Object _cf_result1 = Fibonacci.isFibonacci(5); + boolean _cf_result1 = Fibonacci.isFibonacci(5); } } """ @@ -511,7 +511,7 @@ def test_assertfalse_boolean_call(self): public class FibonacciTest { @Test void testIsNotFibonacci() { - Object _cf_result1 = Fibonacci.isFibonacci(4); + boolean _cf_result1 = Fibonacci.isFibonacci(4); } } """ @@ -709,7 +709,7 @@ def test_multiple_calls_in_one_assertion(self): public class FibonacciTest { @Test void testConsecutive() { - Object _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6)); + boolean _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6)); } } """ @@ -739,11 +739,11 @@ def test_multiple_assertions_in_one_method(self): public class FibonacciTest { @Test void testMultiple() { - Object _cf_result1 = Fibonacci.fibonacci(0); - Object _cf_result2 = Fibonacci.fibonacci(1); - Object _cf_result3 = Fibonacci.fibonacci(2); - Object _cf_result4 = Fibonacci.fibonacci(3); - Object _cf_result5 = Fibonacci.fibonacci(5); + int _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result2 = Fibonacci.fibonacci(1); + int _cf_result3 = Fibonacci.fibonacci(2); + int _cf_result4 = Fibonacci.fibonacci(3); + int _cf_result5 = Fibonacci.fibonacci(5); } } """ @@ -774,7 +774,7 @@ def test_assertion_without_target_removed(self): public class SetupTest { @Test void testSetup() { - Object _cf_result1 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(10); } } """ @@ -829,7 +829,7 @@ def test_multiline_assertion(self): public class FibonacciTest { @Test void testFibonacci() { - Object _cf_result1 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(10); } } """ @@ -855,7 +855,7 @@ def test_assertion_with_string_containing_parens(self): public class ParserTest { @Test void testParse() { - Object _cf_result1 = parser.parse("input(1)"); + String _cf_result1 = parser.parse("input(1)"); } } """ @@ -911,7 +911,7 @@ def test_nested_method_calls(self): public class FibonacciTest { @Test void testIndex() { - Object _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10)); + int _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10)); } } """ @@ -937,7 +937,7 @@ def test_chained_method_on_result(self): public class FibonacciTest { @Test void testUpTo() { - Object _cf_result1 = Fibonacci.fibonacciUpTo(20); + int _cf_result1 = Fibonacci.fibonacciUpTo(20); } } """ @@ -1053,24 +1053,24 @@ class TestBitSetLikeQuestDB: @Test public void testGet_IndexZero_ReturnsFalse() { - Object _cf_result1 = instance.get(0); + boolean _cf_result1 = instance.get(0); } @Test public void testGet_SpecificIndexWithinRange_ReturnsFalse() { - Object _cf_result2 = instance.get(100); + boolean _cf_result2 = instance.get(100); } @Test public void testGet_LastIndexOfInitialRange_ReturnsFalse() { int lastIndex = 16 * BitSet.BITS_PER_WORD - 1; - Object _cf_result3 = instance.get(lastIndex); + boolean _cf_result3 = instance.get(lastIndex); } @Test public void testGet_IndexBeyondAllocated_ReturnsFalse() { int beyond = 16 * BitSet.BITS_PER_WORD; - Object _cf_result4 = instance.get(beyond); + boolean _cf_result4 = instance.get(beyond); } @Test(expected = ArrayIndexOutOfBoundsException.class) @@ -1086,22 +1086,22 @@ class TestBitSetLikeQuestDB: long[] words = new long[2]; words[1] = 1L << 3; wordsField.set(bs, words); - Object _cf_result5 = bs.get(64 + 3); + boolean _cf_result5 = bs.get(64 + 3); } @Test public void testGet_LargeIndexDoesNotThrow_ReturnsFalse() { - Object _cf_result6 = instance.get(Integer.MAX_VALUE); + boolean _cf_result6 = instance.get(Integer.MAX_VALUE); } @Test public void testGet_BitBoundaryWordEdge63_ReturnsFalse() { - Object _cf_result7 = instance.get(63); + boolean _cf_result7 = instance.get(63); } @Test public void testGet_BitBoundaryWordEdge64_ReturnsFalse() { - Object _cf_result8 = instance.get(64); + boolean _cf_result8 = instance.get(64); } @Test @@ -1109,7 +1109,7 @@ class TestBitSetLikeQuestDB: int nBits = 1_000_000; BitSet big = new BitSet(nBits); int last = nBits - 1; - Object _cf_result9 = big.get(last); + boolean _cf_result9 = big.get(last); } } """ @@ -1240,15 +1240,15 @@ def test_counters_assigned_in_source_order(self): public class FibTest { @Test void testA() { - Object _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result1 = Fibonacci.fibonacci(0); } @Test void testB() { - Object _cf_result2 = Fibonacci.fibonacci(10); + int _cf_result2 = Fibonacci.fibonacci(10); } @Test void testC() { - Object _cf_result3 = Fibonacci.fibonacci(1); + int _cf_result3 = Fibonacci.fibonacci(1); } } """ @@ -1329,7 +1329,7 @@ def test_non_nested_assertions_all_replaced(self): public class FibTest { @Test void test() { - Object _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result1 = Fibonacci.fibonacci(0); } } """ @@ -1362,11 +1362,11 @@ def test_reverse_replacement_preserves_all_positions(self): public class CalcTest { @Test void test() { - Object _cf_result1 = engine.compute(1); - Object _cf_result2 = engine.compute(2); - Object _cf_result3 = engine.compute(3); - Object _cf_result4 = engine.compute(4); - Object _cf_result5 = engine.compute(5); + int _cf_result1 = engine.compute(1); + int _cf_result2 = engine.compute(2); + int _cf_result3 = engine.compute(3); + int _cf_result4 = engine.compute(4); + int _cf_result5 = engine.compute(5); } } """ @@ -1400,8 +1400,8 @@ def test_mixed_assertions_all_removed(self): public class FibTest { @Test void test() { - Object _cf_result1 = Fibonacci.fibonacci(0); - Object _cf_result2 = Fibonacci.fibonacci(1); + int _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result2 = Fibonacci.fibonacci(1); } } """ @@ -1461,7 +1461,7 @@ def test_single_assertion_exact_output(self): public class FibTest { @Test void test() { - Object _cf_result1 = Fibonacci.fibonacci(10); + int _cf_result1 = Fibonacci.fibonacci(10); } } """ @@ -1489,8 +1489,8 @@ def test_multiple_assertions_exact_output(self): public class CalcTest { @Test void test() { - Object _cf_result1 = calc.add(1, 2); - Object _cf_result2 = calc.add(3, 4); + int _cf_result1 = calc.add(1, 2); + int _cf_result2 = calc.add(3, 4); } } """ @@ -1546,12 +1546,12 @@ def test_invocation_counter_increments(self): public class FibTest { @Test void test1() { - Object _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result1 = Fibonacci.fibonacci(0); } @Test void test2() { - Object _cf_result2 = Fibonacci.fibonacci(10); + int _cf_result2 = Fibonacci.fibonacci(10); } } """ @@ -1784,9 +1784,9 @@ class TestAllAssertionsRemoved: @Test void testFibonacci() { - Object _cf_result1 = Fibonacci.fibonacci(0); - Object _cf_result2 = Fibonacci.fibonacci(1); - Object _cf_result3 = Fibonacci.fibonacci(5); + int _cf_result1 = Fibonacci.fibonacci(0); + int _cf_result2 = Fibonacci.fibonacci(1); + int _cf_result3 = Fibonacci.fibonacci(5); } @Test @@ -1846,7 +1846,7 @@ def test_preserves_non_assertion_code(self): void testAdd() { Calculator calc = new Calculator(); int result = calc.setup(); - Object _cf_result1 = calc.add(2, 3); + int _cf_result1 = calc.add(2, 3); } } """ @@ -1902,7 +1902,7 @@ def test_mixed_frameworks_all_removed(self): public class MixedTest { @Test void test() { - Object _cf_result1 = obj.target(1); + int _cf_result1 = obj.target(1); } } """