Skip to content

Commit 9321cbc

Browse files
authored
Convert assertThrows with an Executable variable (#1046)
* Convert `assertThrows` with an `Executable` variable, retyping to `ThrowingCallable` Fixes #511 * Address review: inline safe via reduce; drop unused import; hoist ThrowingCallable type
1 parent 0eb4765 commit 9321cbc

2 files changed

Lines changed: 257 additions & 1 deletion

File tree

src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
package org.openrewrite.java.testing.assertj;
1717

1818
import lombok.Getter;
19+
import lombok.RequiredArgsConstructor;
1920
import org.openrewrite.ExecutionContext;
2021
import org.openrewrite.Preconditions;
2122
import org.openrewrite.Recipe;
23+
import org.openrewrite.Tree;
2224
import org.openrewrite.TreeVisitor;
2325
import org.openrewrite.java.JavaIsoVisitor;
2426
import org.openrewrite.java.JavaParser;
@@ -27,14 +29,25 @@
2729
import org.openrewrite.java.search.UsesMethod;
2830
import org.openrewrite.java.tree.Expression;
2931
import org.openrewrite.java.tree.J;
32+
import org.openrewrite.java.tree.JavaType;
33+
import org.openrewrite.java.tree.Space;
34+
import org.openrewrite.java.tree.TypeTree;
35+
import org.openrewrite.java.tree.TypeUtils;
36+
import org.openrewrite.marker.Markers;
3037

3138
import java.util.List;
39+
import java.util.Objects;
3240
import java.util.Optional;
41+
import java.util.concurrent.atomic.AtomicBoolean;
42+
43+
import static java.util.Collections.emptyList;
3344

3445
public class JUnitAssertThrowsToAssertExceptionType extends Recipe {
3546

3647
private static final String JUNIT_ASSERTIONS = "org.junit.jupiter.api.Assertions";
3748
private static final String ASSERTJ_ASSERTIONS = "org.assertj.core.api.Assertions";
49+
private static final String JUNIT_EXECUTABLE = "org.junit.jupiter.api.function.Executable";
50+
private static final String THROWING_CALLABLE = "org.assertj.core.api.ThrowableAssert$ThrowingCallable";
3851
private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher(JUNIT_ASSERTIONS + " assertThrows(..)");
3952

4053
@Getter
@@ -60,13 +73,33 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
6073
return mi;
6174
}
6275

76+
List<Expression> args = mi.getArguments();
77+
78+
// When the executable is a variable typed as JUnit's `Executable`, AssertJ's `isThrownBy`
79+
// expects a `ThrowingCallable` instead, so the variable declaration has to be retyped to keep
80+
// it compiling. That's only safe when every usage of the variable is an `assertThrows`
81+
// executable argument; otherwise leave the call untouched as the types would conflict.
82+
Expression executable = args.get(1);
83+
JavaType.Variable executableVariable = null;
84+
if (executable instanceof J.Identifier) {
85+
JavaType.Variable fieldType = ((J.Identifier) executable).getFieldType();
86+
if (fieldType != null && TypeUtils.isOfClassType(fieldType.getType(), JUNIT_EXECUTABLE)) {
87+
if (!allUsagesAreAssertThrowsExecutable(fieldType)) {
88+
return mi;
89+
}
90+
executableVariable = fieldType;
91+
}
92+
}
93+
6394
boolean returnActual = hasReturnType.get();
6495

6596
maybeRemoveImport(JUNIT_ASSERTIONS);
6697
maybeRemoveImport(JUNIT_ASSERTIONS + ".assertThrows");
6798
maybeAddImport(ASSERTJ_ASSERTIONS, "assertThatExceptionOfType");
6899

69-
List<Expression> args = mi.getArguments();
100+
if (executableVariable != null) {
101+
doAfterVisit(new RetypeExecutableVariable(executableVariable));
102+
}
70103

71104
if (args.size() == 2) {
72105
String code = "assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})";
@@ -120,6 +153,84 @@ private Optional<Boolean> hasReturnType() {
120153
// Unknown parent type so not supported
121154
return Optional.empty();
122155
}
156+
157+
/**
158+
* Determine whether every reference to the given variable can safely be retyped from JUnit's
159+
* {@code Executable} to AssertJ's {@code ThrowingCallable}. That is the case when each reference is
160+
* either the variable's own declaration/assignment target or the executable argument of an
161+
* {@code assertThrows} call (all of which are converted by this recipe). Any other usage would still
162+
* require an {@code Executable} and conflict with the retyped variable.
163+
*/
164+
private boolean allUsagesAreAssertThrowsExecutable(JavaType.Variable variable) {
165+
J.CompilationUnit cu = getCursor().firstEnclosing(J.CompilationUnit.class);
166+
if (cu == null) {
167+
return false;
168+
}
169+
return new JavaIsoVisitor<AtomicBoolean>() {
170+
@Override
171+
public J.Identifier visitIdentifier(J.Identifier identifier, AtomicBoolean safe) {
172+
if (safe.get() && Objects.equals(identifier.getFieldType(), variable)) {
173+
Object parent = getCursor().getParentTreeCursor().getValue();
174+
boolean ok = false;
175+
if (parent instanceof J.VariableDeclarations.NamedVariable) {
176+
ok = ((J.VariableDeclarations.NamedVariable) parent).getName() == identifier;
177+
} else if (parent instanceof J.Assignment) {
178+
ok = ((J.Assignment) parent).getVariable() == identifier;
179+
} else if (parent instanceof J.MethodInvocation) {
180+
J.MethodInvocation enclosing = (J.MethodInvocation) parent;
181+
ok = ASSERT_THROWS_MATCHER.matches(enclosing) &&
182+
enclosing.getArguments().size() >= 2 &&
183+
enclosing.getArguments().get(1) == identifier;
184+
}
185+
if (!ok) {
186+
safe.set(false);
187+
}
188+
}
189+
return super.visitIdentifier(identifier, safe);
190+
}
191+
}.reduce(cu, new AtomicBoolean(true)).get();
192+
}
123193
});
124194
}
195+
196+
/**
197+
* Retypes the declaration of an {@code org.junit.jupiter.api.function.Executable} variable to
198+
* {@code org.assertj.core.api.ThrowableAssert.ThrowingCallable}, so that a variable previously passed to
199+
* {@code assertThrows} keeps compiling when passed to AssertJ's {@code isThrownBy}.
200+
*/
201+
@RequiredArgsConstructor
202+
private static class RetypeExecutableVariable extends JavaIsoVisitor<ExecutionContext> {
203+
204+
private static final JavaType.ShallowClass THROWING_CALLABLE_TYPE = JavaType.ShallowClass.build(THROWING_CALLABLE);
205+
206+
private final JavaType.Variable variable;
207+
208+
@Override
209+
public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) {
210+
J.VariableDeclarations mv = super.visitVariableDeclarations(multiVariable, ctx);
211+
if (!TypeUtils.isOfClassType(mv.getTypeAsFullyQualified(), JUNIT_EXECUTABLE) ||
212+
mv.getVariables().stream().noneMatch(v -> Objects.equals(v.getVariableType(), variable))) {
213+
return mv;
214+
}
215+
216+
maybeRemoveImport(JUNIT_EXECUTABLE);
217+
maybeAddImport(THROWING_CALLABLE);
218+
219+
TypeTree typeExpression = mv.getTypeExpression();
220+
J.Identifier newTypeExpression = new J.Identifier(Tree.randomId(),
221+
typeExpression == null ? Space.EMPTY : typeExpression.getPrefix(),
222+
Markers.EMPTY, emptyList(), "ThrowingCallable", THROWING_CALLABLE_TYPE, null);
223+
return mv.withTypeExpression(newTypeExpression);
224+
}
225+
226+
@Override
227+
public J.Identifier visitIdentifier(J.Identifier identifier, ExecutionContext ctx) {
228+
J.Identifier id = super.visitIdentifier(identifier, ctx);
229+
JavaType.Variable fieldType = id.getFieldType();
230+
if (Objects.equals(fieldType, variable)) {
231+
return id.withType(THROWING_CALLABLE_TYPE).withFieldType(fieldType.withType(THROWING_CALLABLE_TYPE));
232+
}
233+
return id;
234+
}
235+
}
125236
}

src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,151 @@ void foo() {
7272
);
7373
}
7474

75+
@Test
76+
void variableExecutable() {
77+
//language=java
78+
rewriteRun(
79+
java(
80+
"""
81+
import org.junit.jupiter.api.function.Executable;
82+
83+
import static org.junit.jupiter.api.Assertions.assertThrows;
84+
85+
public class SimpleExpectedExceptionTest {
86+
public void throwsExceptionWithSpecificType() {
87+
Executable executable = () -> foo();
88+
assertThrows(NullPointerException.class, executable);
89+
}
90+
void foo() {
91+
throw new NullPointerException();
92+
}
93+
}
94+
""",
95+
"""
96+
import org.assertj.core.api.ThrowableAssert.ThrowingCallable;
97+
98+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
99+
100+
public class SimpleExpectedExceptionTest {
101+
public void throwsExceptionWithSpecificType() {
102+
ThrowingCallable executable = () -> foo();
103+
assertThatExceptionOfType(NullPointerException.class).isThrownBy(executable);
104+
}
105+
void foo() {
106+
throw new NullPointerException();
107+
}
108+
}
109+
"""
110+
)
111+
);
112+
}
113+
114+
@Test
115+
void variableExecutableWithMessage() {
116+
//language=java
117+
rewriteRun(
118+
java(
119+
"""
120+
import org.junit.jupiter.api.function.Executable;
121+
122+
import static org.junit.jupiter.api.Assertions.assertThrows;
123+
124+
public class SimpleExpectedExceptionTest {
125+
public void throwsExceptionWithSpecificType() {
126+
Executable executable = () -> foo();
127+
assertThrows(NullPointerException.class, executable, "message");
128+
}
129+
void foo() {
130+
throw new NullPointerException();
131+
}
132+
}
133+
""",
134+
"""
135+
import org.assertj.core.api.ThrowableAssert.ThrowingCallable;
136+
137+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
138+
139+
public class SimpleExpectedExceptionTest {
140+
public void throwsExceptionWithSpecificType() {
141+
ThrowingCallable executable = () -> foo();
142+
assertThatExceptionOfType(NullPointerException.class).as("message").isThrownBy(executable);
143+
}
144+
void foo() {
145+
throw new NullPointerException();
146+
}
147+
}
148+
"""
149+
)
150+
);
151+
}
152+
153+
@Test
154+
void fieldExecutable() {
155+
//language=java
156+
rewriteRun(
157+
java(
158+
"""
159+
import org.junit.jupiter.api.function.Executable;
160+
161+
import static org.junit.jupiter.api.Assertions.assertThrows;
162+
163+
public class SimpleExpectedExceptionTest {
164+
private final Executable executable = () -> foo();
165+
166+
public void throwsExceptionWithSpecificType() {
167+
assertThrows(NullPointerException.class, executable);
168+
}
169+
void foo() {
170+
throw new NullPointerException();
171+
}
172+
}
173+
""",
174+
"""
175+
import org.assertj.core.api.ThrowableAssert.ThrowingCallable;
176+
177+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
178+
179+
public class SimpleExpectedExceptionTest {
180+
private final ThrowingCallable executable = () -> foo();
181+
182+
public void throwsExceptionWithSpecificType() {
183+
assertThatExceptionOfType(NullPointerException.class).isThrownBy(executable);
184+
}
185+
void foo() {
186+
throw new NullPointerException();
187+
}
188+
}
189+
"""
190+
)
191+
);
192+
}
193+
194+
@Test
195+
void doNotRetypeExecutableSharedWithOtherJUnitAssertion() {
196+
//language=java
197+
rewriteRun(
198+
java(
199+
"""
200+
import org.junit.jupiter.api.function.Executable;
201+
202+
import static org.junit.jupiter.api.Assertions.assertAll;
203+
import static org.junit.jupiter.api.Assertions.assertThrows;
204+
205+
public class SimpleExpectedExceptionTest {
206+
public void throwsExceptionWithSpecificType() {
207+
Executable executable = () -> foo();
208+
assertThrows(NullPointerException.class, executable);
209+
assertAll(executable);
210+
}
211+
void foo() {
212+
throw new NullPointerException();
213+
}
214+
}
215+
"""
216+
)
217+
);
218+
}
219+
75220
@Test
76221
void memberReference() {
77222
//language=java

0 commit comments

Comments
 (0)