Skip to content

Commit 32c7752

Browse files
authored
Handle Kotlin sources in MockitoWhenOnStaticToMockStatic (#1003)
1 parent 6e33f5b commit 32c7752

2 files changed

Lines changed: 356 additions & 26 deletions

File tree

src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java

Lines changed: 140 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
import org.openrewrite.java.*;
2323
import org.openrewrite.java.search.UsesMethod;
2424
import org.openrewrite.java.tree.*;
25+
import org.openrewrite.kotlin.KotlinIsoVisitor;
26+
import org.openrewrite.kotlin.KotlinParser;
27+
import org.openrewrite.kotlin.KotlinTemplate;
28+
import org.openrewrite.kotlin.tree.K;
2529

2630
import java.util.*;
2731
import java.util.concurrent.atomic.AtomicBoolean;
@@ -60,7 +64,23 @@ public class MockitoWhenOnStaticToMockStatic extends Recipe {
6064

6165
@Override
6266
public TreeVisitor<?, ExecutionContext> getVisitor() {
63-
return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor<ExecutionContext>() {
67+
return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new TreeVisitor<Tree, ExecutionContext>() {
68+
@Override
69+
public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) {
70+
stopAfterPreVisit();
71+
if (tree instanceof J.CompilationUnit) {
72+
return javaVisitor().visit(tree, ctx);
73+
}
74+
if (tree instanceof K.CompilationUnit) {
75+
return kotlinVisitor().visit(tree, ctx);
76+
}
77+
return tree;
78+
}
79+
});
80+
}
81+
82+
private JavaIsoVisitor<ExecutionContext> javaVisitor() {
83+
return new JavaIsoVisitor<ExecutionContext>() {
6484
@Override
6585
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
6686
J.MethodDeclaration containingMethod = getCursor().firstEnclosing(J.MethodDeclaration.class);
@@ -116,30 +136,6 @@ private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Bloc
116136
});
117137
}
118138

119-
private J.@Nullable MethodInvocation getWhenArg(Statement statement) {
120-
if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) {
121-
J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect();
122-
if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) {
123-
J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0);
124-
if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Static)) {
125-
return whenArg;
126-
}
127-
}
128-
}
129-
return null;
130-
}
131-
132-
private JavaType.@Nullable Class getTypeFromInvocation(J.MethodInvocation whenArg) {
133-
J.Identifier clazz = null;
134-
// Having a fieldType implies that something is a field rather than a class itself
135-
if (whenArg.getSelect() instanceof J.Identifier && ((J.Identifier) whenArg.getSelect()).getFieldType() == null) {
136-
clazz = (J.Identifier) whenArg.getSelect();
137-
} else if (whenArg.getSelect() instanceof J.FieldAccess && ((J.FieldAccess) whenArg.getSelect()).getTarget() instanceof J.Identifier) {
138-
clazz = (J.Identifier) ((J.FieldAccess) whenArg.getSelect()).getTarget();
139-
}
140-
return clazz != null && clazz.getType() != null ? (JavaType.Class) clazz.getType() : null;
141-
}
142-
143139
private J.Try tryWithMockedStatic(J.Block block, List<Statement> statements, Integer index,
144140
J.MethodInvocation statement, String className, J.MethodInvocation whenArg, ExecutionContext ctx) {
145141
String variableName = generateVariableName("mock" + className + ++varCounter, updateCursor(block), INCREMENT_NUMBER);
@@ -258,7 +254,125 @@ private JavaTemplate javaTemplateMockStatic(String code, ExecutionContext ctx) {
258254
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5"))
259255
.build();
260256
}
261-
});
257+
};
258+
}
259+
260+
private KotlinIsoVisitor<ExecutionContext> kotlinVisitor() {
261+
return new KotlinIsoVisitor<ExecutionContext>() {
262+
@Override
263+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
264+
J.MethodInvocation m = super.visitMethodInvocation(method, ctx);
265+
J.MethodInvocation whenArg = getWhenArg(m);
266+
if (whenArg == null) {
267+
return m;
268+
}
269+
JavaType.Class invokedType = getTypeFromInvocation(whenArg);
270+
if (invokedType == null) {
271+
return m;
272+
}
273+
274+
String paramName = findEnclosingMockStaticUseParam(getCursor(), invokedType);
275+
if (paramName == null) {
276+
J.Identifier classProperty = findMockedStaticVariable(getCursor(), invokedType);
277+
if (classProperty == null) {
278+
return m;
279+
}
280+
paramName = classProperty.getSimpleName();
281+
}
282+
283+
String returnTypeName = getReturnTypeName(whenArg);
284+
Expression thenReturnArg = m.getArguments().get(0);
285+
J.MethodInvocation rewritten = KotlinTemplate.builder(String.format(
286+
"%s.`when`<%s> { #{any()} }.thenReturn(#{any()})",
287+
paramName, returnTypeName))
288+
.imports("org.mockito.MockedStatic")
289+
.parser(KotlinParser.builder().classpathFromResources(ctx, "mockito-core-5"))
290+
.build()
291+
.apply(getCursor(), m.getCoordinates().replace(), whenArg, thenReturnArg);
292+
maybeRemoveImport("org.mockito.Mockito.when");
293+
return rewritten;
294+
}
295+
};
296+
}
297+
298+
private static J.@Nullable MethodInvocation getWhenArg(Statement statement) {
299+
if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) {
300+
J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect();
301+
if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) {
302+
J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0);
303+
if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Static)) {
304+
return whenArg;
305+
}
306+
}
307+
}
308+
return null;
309+
}
310+
311+
private static JavaType.@Nullable Class getTypeFromInvocation(J.MethodInvocation whenArg) {
312+
J.Identifier clazz = null;
313+
// Having a fieldType implies that something is a field rather than a class itself
314+
if (whenArg.getSelect() instanceof J.Identifier && ((J.Identifier) whenArg.getSelect()).getFieldType() == null) {
315+
clazz = (J.Identifier) whenArg.getSelect();
316+
} else if (whenArg.getSelect() instanceof J.FieldAccess && ((J.FieldAccess) whenArg.getSelect()).getTarget() instanceof J.Identifier) {
317+
clazz = (J.Identifier) ((J.FieldAccess) whenArg.getSelect()).getTarget();
318+
}
319+
return clazz != null && clazz.getType() != null ? (JavaType.Class) clazz.getType() : null;
320+
}
321+
322+
private static @Nullable String findEnclosingMockStaticUseParam(Cursor cursor, JavaType invokedType) {
323+
Cursor c = cursor;
324+
while ((c = c.getParent()) != null) {
325+
if (!(c.getValue() instanceof J.Lambda)) {
326+
continue;
327+
}
328+
J.Lambda lambda = c.getValue();
329+
Cursor p = c;
330+
while ((p = p.getParent()) != null) {
331+
Object pv = p.getValue();
332+
if (pv instanceof J.MethodInvocation) {
333+
J.MethodInvocation parentInv = (J.MethodInvocation) pv;
334+
if ("use".equals(parentInv.getSimpleName()) &&
335+
parentInv.getSelect() instanceof J.MethodInvocation) {
336+
J.MethodInvocation receiver = (J.MethodInvocation) parentInv.getSelect();
337+
if (receiver.getMethodType() != null &&
338+
isMockedStaticOfType(invokedType, receiver.getMethodType().getReturnType())) {
339+
return lambdaParamName(lambda);
340+
}
341+
}
342+
break;
343+
}
344+
if (pv instanceof J) {
345+
// Lambda's nearest J ancestor isn't a method invocation — not a `.use { }` block
346+
break;
347+
}
348+
}
349+
}
350+
return null;
351+
}
352+
353+
private static String lambdaParamName(J.Lambda lambda) {
354+
List<J> params = lambda.getParameters().getParameters();
355+
if (params.isEmpty() || params.get(0) instanceof J.Empty) {
356+
return "it";
357+
}
358+
J first = params.get(0);
359+
if (first instanceof J.VariableDeclarations) {
360+
return ((J.VariableDeclarations) first).getVariables().get(0).getSimpleName();
361+
}
362+
return "it";
363+
}
364+
365+
private static String getReturnTypeName(J.MethodInvocation whenArg) {
366+
if (whenArg.getMethodType() != null) {
367+
JavaType returnType = whenArg.getMethodType().getReturnType();
368+
if (returnType instanceof JavaType.FullyQualified) {
369+
return ((JavaType.FullyQualified) returnType).getClassName();
370+
}
371+
if (returnType instanceof JavaType.Primitive) {
372+
return ((JavaType.Primitive) returnType).getKeyword();
373+
}
374+
}
375+
return "Any";
262376
}
263377

264378
private static List<J.Try.Resource> getMatchingFilteredResources(@Nullable List<J.Try.Resource> resources, JavaType className) {

0 commit comments

Comments
 (0)