Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.TypeUtils;

public class AssertTrueComparisonToAssertEquals extends Recipe {
private static final MethodMatcher ASSERT_TRUE = new MethodMatcher(
Expand Down Expand Up @@ -99,14 +100,36 @@ private boolean isEqualBinary(J.MethodInvocation method) {
return false;
}

// Prevent breaking identity comparison.
// Objects that are compared with == should not be compared with `.equals()` instead.
// Out of the primitives == is not allowed when both are of type String
return binary.getLeft().getType() instanceof JavaType.Primitive &&
binary.getRight().getType() instanceof JavaType.Primitive &&
!(binary.getLeft().getType() == JavaType.Primitive.String &&
binary.getRight().getType() == JavaType.Primitive.String);
// Prevent breaking identity comparison: wrapper-wrapper and reference-reference == are
// reference equality and belong to UseAssertSame, not assertEquals. We rewrite when:
// - both operands are primitive (excluding String == String, which is reference identity), OR
// - one operand is primitive and the other a numeric wrapper (Java unboxes; == is value).
JavaType leftType = binary.getLeft().getType();
JavaType rightType = binary.getRight().getType();
if (leftType instanceof JavaType.Primitive && rightType instanceof JavaType.Primitive) {
return !(leftType == JavaType.Primitive.String && rightType == JavaType.Primitive.String);
}
return (isPrimitiveValueType(leftType) && isNumericWrapper(rightType)) ||
(isNumericWrapper(leftType) && isPrimitiveValueType(rightType));
}
});
}

private static boolean isPrimitiveValueType(JavaType type) {
return type instanceof JavaType.Primitive &&
type != JavaType.Primitive.Null &&
type != JavaType.Primitive.String &&
type != JavaType.Primitive.None;
}

private static boolean isNumericWrapper(JavaType type) {
return TypeUtils.isOfClassType(type, "java.lang.Boolean") ||
TypeUtils.isOfClassType(type, "java.lang.Byte") ||
TypeUtils.isOfClassType(type, "java.lang.Character") ||
TypeUtils.isOfClassType(type, "java.lang.Short") ||
TypeUtils.isOfClassType(type, "java.lang.Integer") ||
TypeUtils.isOfClassType(type, "java.lang.Long") ||
TypeUtils.isOfClassType(type, "java.lang.Float") ||
TypeUtils.isOfClassType(type, "java.lang.Double");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public class UseAssertSame extends Recipe {

@Getter
final String description = "Prefers the usage of `assertSame` or `assertNotSame` methods instead of using of vanilla `assertTrue` " +
"or `assertFalse` with a boolean comparison.";
"or `assertFalse` with a boolean comparison. Only applies when both operands are reference types — " +
"primitive operands are handled by `AssertTrueComparisonToAssertEquals`.";

private static final MethodMatcher ASSERT_TRUE_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertTrue(..)");
private static final MethodMatcher ASSERT_FALSE_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertFalse(..)");
Expand Down Expand Up @@ -67,9 +68,10 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat
binary.getRight().getType() == JavaType.Primitive.Null) {
return mi;
}
// Skip primitive comparisons — `==` is value equality, not reference equality
if (binary.getLeft().getType() instanceof JavaType.Primitive ||
binary.getRight().getType() instanceof JavaType.Primitive) {
// Skip when either operand has value-equality semantics under == (primitives, plus
// the wrapper-primitive case where Java unboxes). Defer to AssertTrueComparisonToAssertEquals.
if (isPrimitiveValueType(binary.getLeft().getType()) ||
isPrimitiveValueType(binary.getRight().getType())) {
return mi;
}
List<Expression> newArguments = new ArrayList<>();
Expand Down Expand Up @@ -111,4 +113,12 @@ private JavaType.Method assertSameMethodType(J.MethodInvocation mi, String newMe
visitor);
}

// String is a JavaType.Primitive enum value but its == is reference equality (assertSame is correct).
private static boolean isPrimitiveValueType(JavaType type) {
return type instanceof JavaType.Primitive &&
type != JavaType.Primitive.Null &&
type != JavaType.Primitive.String &&
type != JavaType.Primitive.None;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,57 @@ void test() {
);
}

@SuppressWarnings({"NumberEquality", "SimplifiableAssertion"})
@Test
void onlyRewritesWhenAtLeastOneOperandIsPrimitive() {
// Mixed primitive-wrapper rewrites because Java unboxes (== is value equality).
// Wrapper-wrapper is reference equality and stays for UseAssertSame to emit assertSame.
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Assertions;

import java.util.HashMap;
import java.util.Map;

public class Test {
void test() {
Map<String, Double> map = new HashMap<>();
map.put("k", 0.5);
double d = 0.5;
int i = 5;
Integer iBoxed = 5;
Double a = 0.5;
Double b = a;
Assertions.assertTrue(map.get("k") == d);
Assertions.assertTrue(i == iBoxed);
Assertions.assertTrue(a == b);
}
}
""",
"""
import org.junit.jupiter.api.Assertions;

import java.util.HashMap;
import java.util.Map;

public class Test {
void test() {
Map<String, Double> map = new HashMap<>();
map.put("k", 0.5);
double d = 0.5;
int i = 5;
Integer iBoxed = 5;
Double a = 0.5;
Double b = a;
Assertions.assertEquals(map.get("k"), d);
Assertions.assertEquals(i, iBoxed);
Assertions.assertTrue(a == b);
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,50 @@ public void test(Object obj) {
)
);
}

@Test
void onlyConvertsReferenceComparisons() {
// Wrappers are reference types so == is reference equality (assertSame).
// With a primitive operand Java unboxes, so defer to AssertTrueComparisonToAssertEquals.
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertTrue;

class MyTest {

@Test
public void test() {
int primitive = 5;
Integer a = 5;
Integer b = a;
assertTrue(primitive == a);
assertTrue(a == b);
}
}
""",
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;

class MyTest {

@Test
public void test() {
int primitive = 5;
Integer a = 5;
Integer b = a;
assertTrue(primitive == a);
assertSame(a, b);
}
}
"""
)
);
}
}
Loading