|
22 | 22 | import org.openrewrite.java.*; |
23 | 23 | import org.openrewrite.java.search.UsesMethod; |
24 | 24 | import org.openrewrite.java.tree.*; |
| 25 | +import org.openrewrite.kotlin.KotlinIsoVisitor; |
| 26 | +import org.openrewrite.kotlin.KotlinParser; |
| 27 | +import org.openrewrite.kotlin.KotlinTemplate; |
| 28 | +import org.openrewrite.kotlin.tree.K; |
25 | 29 |
|
26 | 30 | import java.util.*; |
27 | 31 | import java.util.concurrent.atomic.AtomicBoolean; |
@@ -60,7 +64,23 @@ public class MockitoWhenOnStaticToMockStatic extends Recipe { |
60 | 64 |
|
61 | 65 | @Override |
62 | 66 | public TreeVisitor<?, ExecutionContext> getVisitor() { |
63 | | - return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor<ExecutionContext>() { |
| 67 | + return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new TreeVisitor<Tree, ExecutionContext>() { |
| 68 | + @Override |
| 69 | + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { |
| 70 | + stopAfterPreVisit(); |
| 71 | + if (tree instanceof J.CompilationUnit) { |
| 72 | + return javaVisitor().visit(tree, ctx); |
| 73 | + } |
| 74 | + if (tree instanceof K.CompilationUnit) { |
| 75 | + return kotlinVisitor().visit(tree, ctx); |
| 76 | + } |
| 77 | + return tree; |
| 78 | + } |
| 79 | + }); |
| 80 | + } |
| 81 | + |
| 82 | + private JavaIsoVisitor<ExecutionContext> javaVisitor() { |
| 83 | + return new JavaIsoVisitor<ExecutionContext>() { |
64 | 84 | @Override |
65 | 85 | public J.Block visitBlock(J.Block block, ExecutionContext ctx) { |
66 | 86 | J.MethodDeclaration containingMethod = getCursor().firstEnclosing(J.MethodDeclaration.class); |
@@ -116,30 +136,6 @@ private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Bloc |
116 | 136 | }); |
117 | 137 | } |
118 | 138 |
|
119 | | - private J.@Nullable MethodInvocation getWhenArg(Statement statement) { |
120 | | - if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { |
121 | | - J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); |
122 | | - if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { |
123 | | - J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); |
124 | | - if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Static)) { |
125 | | - return whenArg; |
126 | | - } |
127 | | - } |
128 | | - } |
129 | | - return null; |
130 | | - } |
131 | | - |
132 | | - private JavaType.@Nullable Class getTypeFromInvocation(J.MethodInvocation whenArg) { |
133 | | - J.Identifier clazz = null; |
134 | | - // Having a fieldType implies that something is a field rather than a class itself |
135 | | - if (whenArg.getSelect() instanceof J.Identifier && ((J.Identifier) whenArg.getSelect()).getFieldType() == null) { |
136 | | - clazz = (J.Identifier) whenArg.getSelect(); |
137 | | - } else if (whenArg.getSelect() instanceof J.FieldAccess && ((J.FieldAccess) whenArg.getSelect()).getTarget() instanceof J.Identifier) { |
138 | | - clazz = (J.Identifier) ((J.FieldAccess) whenArg.getSelect()).getTarget(); |
139 | | - } |
140 | | - return clazz != null && clazz.getType() != null ? (JavaType.Class) clazz.getType() : null; |
141 | | - } |
142 | | - |
143 | 139 | private J.Try tryWithMockedStatic(J.Block block, List<Statement> statements, Integer index, |
144 | 140 | J.MethodInvocation statement, String className, J.MethodInvocation whenArg, ExecutionContext ctx) { |
145 | 141 | String variableName = generateVariableName("mock" + className + ++varCounter, updateCursor(block), INCREMENT_NUMBER); |
@@ -258,7 +254,125 @@ private JavaTemplate javaTemplateMockStatic(String code, ExecutionContext ctx) { |
258 | 254 | .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5")) |
259 | 255 | .build(); |
260 | 256 | } |
261 | | - }); |
| 257 | + }; |
| 258 | + } |
| 259 | + |
| 260 | + private KotlinIsoVisitor<ExecutionContext> kotlinVisitor() { |
| 261 | + return new KotlinIsoVisitor<ExecutionContext>() { |
| 262 | + @Override |
| 263 | + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { |
| 264 | + J.MethodInvocation m = super.visitMethodInvocation(method, ctx); |
| 265 | + J.MethodInvocation whenArg = getWhenArg(m); |
| 266 | + if (whenArg == null) { |
| 267 | + return m; |
| 268 | + } |
| 269 | + JavaType.Class invokedType = getTypeFromInvocation(whenArg); |
| 270 | + if (invokedType == null) { |
| 271 | + return m; |
| 272 | + } |
| 273 | + |
| 274 | + String paramName = findEnclosingMockStaticUseParam(getCursor(), invokedType); |
| 275 | + if (paramName == null) { |
| 276 | + J.Identifier classProperty = findMockedStaticVariable(getCursor(), invokedType); |
| 277 | + if (classProperty == null) { |
| 278 | + return m; |
| 279 | + } |
| 280 | + paramName = classProperty.getSimpleName(); |
| 281 | + } |
| 282 | + |
| 283 | + String returnTypeName = getReturnTypeName(whenArg); |
| 284 | + Expression thenReturnArg = m.getArguments().get(0); |
| 285 | + J.MethodInvocation rewritten = KotlinTemplate.builder(String.format( |
| 286 | + "%s.`when`<%s> { #{any()} }.thenReturn(#{any()})", |
| 287 | + paramName, returnTypeName)) |
| 288 | + .imports("org.mockito.MockedStatic") |
| 289 | + .parser(KotlinParser.builder().classpathFromResources(ctx, "mockito-core-5")) |
| 290 | + .build() |
| 291 | + .apply(getCursor(), m.getCoordinates().replace(), whenArg, thenReturnArg); |
| 292 | + maybeRemoveImport("org.mockito.Mockito.when"); |
| 293 | + return rewritten; |
| 294 | + } |
| 295 | + }; |
| 296 | + } |
| 297 | + |
| 298 | + private static J.@Nullable MethodInvocation getWhenArg(Statement statement) { |
| 299 | + if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { |
| 300 | + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); |
| 301 | + if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { |
| 302 | + J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); |
| 303 | + if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Static)) { |
| 304 | + return whenArg; |
| 305 | + } |
| 306 | + } |
| 307 | + } |
| 308 | + return null; |
| 309 | + } |
| 310 | + |
| 311 | + private static JavaType.@Nullable Class getTypeFromInvocation(J.MethodInvocation whenArg) { |
| 312 | + J.Identifier clazz = null; |
| 313 | + // Having a fieldType implies that something is a field rather than a class itself |
| 314 | + if (whenArg.getSelect() instanceof J.Identifier && ((J.Identifier) whenArg.getSelect()).getFieldType() == null) { |
| 315 | + clazz = (J.Identifier) whenArg.getSelect(); |
| 316 | + } else if (whenArg.getSelect() instanceof J.FieldAccess && ((J.FieldAccess) whenArg.getSelect()).getTarget() instanceof J.Identifier) { |
| 317 | + clazz = (J.Identifier) ((J.FieldAccess) whenArg.getSelect()).getTarget(); |
| 318 | + } |
| 319 | + return clazz != null && clazz.getType() != null ? (JavaType.Class) clazz.getType() : null; |
| 320 | + } |
| 321 | + |
| 322 | + private static @Nullable String findEnclosingMockStaticUseParam(Cursor cursor, JavaType invokedType) { |
| 323 | + Cursor c = cursor; |
| 324 | + while ((c = c.getParent()) != null) { |
| 325 | + if (!(c.getValue() instanceof J.Lambda)) { |
| 326 | + continue; |
| 327 | + } |
| 328 | + J.Lambda lambda = c.getValue(); |
| 329 | + Cursor p = c; |
| 330 | + while ((p = p.getParent()) != null) { |
| 331 | + Object pv = p.getValue(); |
| 332 | + if (pv instanceof J.MethodInvocation) { |
| 333 | + J.MethodInvocation parentInv = (J.MethodInvocation) pv; |
| 334 | + if ("use".equals(parentInv.getSimpleName()) && |
| 335 | + parentInv.getSelect() instanceof J.MethodInvocation) { |
| 336 | + J.MethodInvocation receiver = (J.MethodInvocation) parentInv.getSelect(); |
| 337 | + if (receiver.getMethodType() != null && |
| 338 | + isMockedStaticOfType(invokedType, receiver.getMethodType().getReturnType())) { |
| 339 | + return lambdaParamName(lambda); |
| 340 | + } |
| 341 | + } |
| 342 | + break; |
| 343 | + } |
| 344 | + if (pv instanceof J) { |
| 345 | + // Lambda's nearest J ancestor isn't a method invocation — not a `.use { }` block |
| 346 | + break; |
| 347 | + } |
| 348 | + } |
| 349 | + } |
| 350 | + return null; |
| 351 | + } |
| 352 | + |
| 353 | + private static String lambdaParamName(J.Lambda lambda) { |
| 354 | + List<J> params = lambda.getParameters().getParameters(); |
| 355 | + if (params.isEmpty() || params.get(0) instanceof J.Empty) { |
| 356 | + return "it"; |
| 357 | + } |
| 358 | + J first = params.get(0); |
| 359 | + if (first instanceof J.VariableDeclarations) { |
| 360 | + return ((J.VariableDeclarations) first).getVariables().get(0).getSimpleName(); |
| 361 | + } |
| 362 | + return "it"; |
| 363 | + } |
| 364 | + |
| 365 | + private static String getReturnTypeName(J.MethodInvocation whenArg) { |
| 366 | + if (whenArg.getMethodType() != null) { |
| 367 | + JavaType returnType = whenArg.getMethodType().getReturnType(); |
| 368 | + if (returnType instanceof JavaType.FullyQualified) { |
| 369 | + return ((JavaType.FullyQualified) returnType).getClassName(); |
| 370 | + } |
| 371 | + if (returnType instanceof JavaType.Primitive) { |
| 372 | + return ((JavaType.Primitive) returnType).getKeyword(); |
| 373 | + } |
| 374 | + } |
| 375 | + return "Any"; |
262 | 376 | } |
263 | 377 |
|
264 | 378 | private static List<J.Try.Resource> getMatchingFilteredResources(@Nullable List<J.Try.Resource> resources, JavaType className) { |
|
0 commit comments