1616package org .openrewrite .java .testing .assertj ;
1717
1818import lombok .Getter ;
19+ import lombok .RequiredArgsConstructor ;
1920import org .openrewrite .ExecutionContext ;
2021import org .openrewrite .Preconditions ;
2122import org .openrewrite .Recipe ;
23+ import org .openrewrite .Tree ;
2224import org .openrewrite .TreeVisitor ;
2325import org .openrewrite .java .JavaIsoVisitor ;
2426import org .openrewrite .java .JavaParser ;
2729import org .openrewrite .java .search .UsesMethod ;
2830import org .openrewrite .java .tree .Expression ;
2931import org .openrewrite .java .tree .J ;
32+ import org .openrewrite .java .tree .JavaType ;
33+ import org .openrewrite .java .tree .Space ;
34+ import org .openrewrite .java .tree .TypeTree ;
35+ import org .openrewrite .java .tree .TypeUtils ;
36+ import org .openrewrite .marker .Markers ;
3037
3138import java .util .List ;
39+ import java .util .Objects ;
3240import java .util .Optional ;
41+ import java .util .concurrent .atomic .AtomicBoolean ;
42+
43+ import static java .util .Collections .emptyList ;
3344
3445public class JUnitAssertThrowsToAssertExceptionType extends Recipe {
3546
3647 private static final String JUNIT_ASSERTIONS = "org.junit.jupiter.api.Assertions" ;
3748 private static final String ASSERTJ_ASSERTIONS = "org.assertj.core.api.Assertions" ;
49+ private static final String JUNIT_EXECUTABLE = "org.junit.jupiter.api.function.Executable" ;
50+ private static final String THROWING_CALLABLE = "org.assertj.core.api.ThrowableAssert$ThrowingCallable" ;
3851 private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher (JUNIT_ASSERTIONS + " assertThrows(..)" );
3952
4053 @ Getter
@@ -60,13 +73,33 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
6073 return mi ;
6174 }
6275
76+ List <Expression > args = mi .getArguments ();
77+
78+ // When the executable is a variable typed as JUnit's `Executable`, AssertJ's `isThrownBy`
79+ // expects a `ThrowingCallable` instead, so the variable declaration has to be retyped to keep
80+ // it compiling. That's only safe when every usage of the variable is an `assertThrows`
81+ // executable argument; otherwise leave the call untouched as the types would conflict.
82+ Expression executable = args .get (1 );
83+ JavaType .Variable executableVariable = null ;
84+ if (executable instanceof J .Identifier ) {
85+ JavaType .Variable fieldType = ((J .Identifier ) executable ).getFieldType ();
86+ if (fieldType != null && TypeUtils .isOfClassType (fieldType .getType (), JUNIT_EXECUTABLE )) {
87+ if (!allUsagesAreAssertThrowsExecutable (fieldType )) {
88+ return mi ;
89+ }
90+ executableVariable = fieldType ;
91+ }
92+ }
93+
6394 boolean returnActual = hasReturnType .get ();
6495
6596 maybeRemoveImport (JUNIT_ASSERTIONS );
6697 maybeRemoveImport (JUNIT_ASSERTIONS + ".assertThrows" );
6798 maybeAddImport (ASSERTJ_ASSERTIONS , "assertThatExceptionOfType" );
6899
69- List <Expression > args = mi .getArguments ();
100+ if (executableVariable != null ) {
101+ doAfterVisit (new RetypeExecutableVariable (executableVariable ));
102+ }
70103
71104 if (args .size () == 2 ) {
72105 String code = "assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})" ;
@@ -120,6 +153,84 @@ private Optional<Boolean> hasReturnType() {
120153 // Unknown parent type so not supported
121154 return Optional .empty ();
122155 }
156+
157+ /**
158+ * Determine whether every reference to the given variable can safely be retyped from JUnit's
159+ * {@code Executable} to AssertJ's {@code ThrowingCallable}. That is the case when each reference is
160+ * either the variable's own declaration/assignment target or the executable argument of an
161+ * {@code assertThrows} call (all of which are converted by this recipe). Any other usage would still
162+ * require an {@code Executable} and conflict with the retyped variable.
163+ */
164+ private boolean allUsagesAreAssertThrowsExecutable (JavaType .Variable variable ) {
165+ J .CompilationUnit cu = getCursor ().firstEnclosing (J .CompilationUnit .class );
166+ if (cu == null ) {
167+ return false ;
168+ }
169+ return new JavaIsoVisitor <AtomicBoolean >() {
170+ @ Override
171+ public J .Identifier visitIdentifier (J .Identifier identifier , AtomicBoolean safe ) {
172+ if (safe .get () && Objects .equals (identifier .getFieldType (), variable )) {
173+ Object parent = getCursor ().getParentTreeCursor ().getValue ();
174+ boolean ok = false ;
175+ if (parent instanceof J .VariableDeclarations .NamedVariable ) {
176+ ok = ((J .VariableDeclarations .NamedVariable ) parent ).getName () == identifier ;
177+ } else if (parent instanceof J .Assignment ) {
178+ ok = ((J .Assignment ) parent ).getVariable () == identifier ;
179+ } else if (parent instanceof J .MethodInvocation ) {
180+ J .MethodInvocation enclosing = (J .MethodInvocation ) parent ;
181+ ok = ASSERT_THROWS_MATCHER .matches (enclosing ) &&
182+ enclosing .getArguments ().size () >= 2 &&
183+ enclosing .getArguments ().get (1 ) == identifier ;
184+ }
185+ if (!ok ) {
186+ safe .set (false );
187+ }
188+ }
189+ return super .visitIdentifier (identifier , safe );
190+ }
191+ }.reduce (cu , new AtomicBoolean (true )).get ();
192+ }
123193 });
124194 }
195+
196+ /**
197+ * Retypes the declaration of an {@code org.junit.jupiter.api.function.Executable} variable to
198+ * {@code org.assertj.core.api.ThrowableAssert.ThrowingCallable}, so that a variable previously passed to
199+ * {@code assertThrows} keeps compiling when passed to AssertJ's {@code isThrownBy}.
200+ */
201+ @ RequiredArgsConstructor
202+ private static class RetypeExecutableVariable extends JavaIsoVisitor <ExecutionContext > {
203+
204+ private static final JavaType .ShallowClass THROWING_CALLABLE_TYPE = JavaType .ShallowClass .build (THROWING_CALLABLE );
205+
206+ private final JavaType .Variable variable ;
207+
208+ @ Override
209+ public J .VariableDeclarations visitVariableDeclarations (J .VariableDeclarations multiVariable , ExecutionContext ctx ) {
210+ J .VariableDeclarations mv = super .visitVariableDeclarations (multiVariable , ctx );
211+ if (!TypeUtils .isOfClassType (mv .getTypeAsFullyQualified (), JUNIT_EXECUTABLE ) ||
212+ mv .getVariables ().stream ().noneMatch (v -> Objects .equals (v .getVariableType (), variable ))) {
213+ return mv ;
214+ }
215+
216+ maybeRemoveImport (JUNIT_EXECUTABLE );
217+ maybeAddImport (THROWING_CALLABLE );
218+
219+ TypeTree typeExpression = mv .getTypeExpression ();
220+ J .Identifier newTypeExpression = new J .Identifier (Tree .randomId (),
221+ typeExpression == null ? Space .EMPTY : typeExpression .getPrefix (),
222+ Markers .EMPTY , emptyList (), "ThrowingCallable" , THROWING_CALLABLE_TYPE , null );
223+ return mv .withTypeExpression (newTypeExpression );
224+ }
225+
226+ @ Override
227+ public J .Identifier visitIdentifier (J .Identifier identifier , ExecutionContext ctx ) {
228+ J .Identifier id = super .visitIdentifier (identifier , ctx );
229+ JavaType .Variable fieldType = id .getFieldType ();
230+ if (Objects .equals (fieldType , variable )) {
231+ return id .withType (THROWING_CALLABLE_TYPE ).withFieldType (fieldType .withType (THROWING_CALLABLE_TYPE ));
232+ }
233+ return id ;
234+ }
235+ }
125236}
0 commit comments