|
22 | 22 | import org.openrewrite.java.JavaIsoVisitor; |
23 | 23 | import org.openrewrite.java.JavaParser; |
24 | 24 | import org.openrewrite.java.JavaTemplate; |
| 25 | +import org.openrewrite.java.JavaVisitor; |
25 | 26 | import org.openrewrite.java.MethodMatcher; |
26 | 27 | import org.openrewrite.java.search.UsesType; |
27 | 28 | import org.openrewrite.java.tree.*; |
|
33 | 34 | import java.util.Iterator; |
34 | 35 | import java.util.List; |
35 | 36 | import java.util.Map; |
| 37 | +import java.util.UUID; |
36 | 38 |
|
37 | 39 | import static java.util.Collections.emptyList; |
38 | 40 | import static org.openrewrite.Tree.randomId; |
@@ -86,14 +88,15 @@ public class PowerMockWhiteboxToJavaReflection extends Recipe { |
86 | 88 | final String description = "Replace `org.powermock.reflect.Whitebox` calls (`setInternalState`, " + |
87 | 89 | "`getInternalState`, `invokeMethod`, static `invokeMethod`, `invokeConstructor`, `getField` and " + |
88 | 90 | "`getMethod`, including the `Class where` overloads) with plain Java reflection using " + |
89 | | - "`java.lang.reflect.Field`, `Method` and `Constructor`. Field/constructor lookups " + |
| 91 | + "`java.lang.reflect.Field`, `Method` and `Constructor`. Calls nested in expression position " + |
| 92 | + "(returns, arguments, conditions) are hoisted into preceding statements. Field/constructor lookups " + |
90 | 93 | "use `getDeclaredField`/`getDeclaredConstructor` on the named class, which differs from PowerMock " + |
91 | 94 | "for members inherited from a superclass."; |
92 | 95 |
|
93 | 96 | /** |
94 | | - * Where a result-producing reflection call stores its result. {@code varName} is the declared |
95 | | - * variable receiving the value, or null when the result is discarded (the call was a bare |
96 | | - * statement); {@code castType} is that variable's declared type. |
| 97 | + * Where a result-producing reflection call stores its result: an existing declared variable |
| 98 | + * (Phase A) or a generated temporary local (Phase B). {@code varName} is null when the result |
| 99 | + * is discarded (the call was a bare statement). |
97 | 100 | */ |
98 | 101 | private static final class ResultSink { |
99 | 102 | private final @Nullable String varName; |
@@ -140,8 +143,8 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex |
140 | 143 | public J.Block visitBlock(J.Block block, ExecutionContext ctx) { |
141 | 144 | J.Block b = super.visitBlock(block, ctx); |
142 | 145 |
|
143 | | - // Replace Whitebox calls that are themselves a statement or a single-variable declaration |
144 | | - // initializer. Process in reverse so coordinates remain valid after each replacement. |
| 146 | + // Phase A: replace Whitebox calls that are themselves a statement or a single-variable |
| 147 | + // declaration initializer. Process in reverse so coordinates remain valid after each replacement. |
145 | 148 | List<Statement> statements = b.getStatements(); |
146 | 149 | for (int i = statements.size() - 1; i >= 0; i--) { |
147 | 150 | Statement stmt = statements.get(i); |
@@ -169,6 +172,28 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) { |
169 | 172 | } |
170 | 173 | } |
171 | 174 |
|
| 175 | + // Phase B: hoist Whitebox calls nested in expression position (return/arguments/conditions/...), |
| 176 | + // introducing a temp local before the enclosing statement and referencing it in place. |
| 177 | + List<UUID> originalIds = new ArrayList<>(); |
| 178 | + for (Statement s : b.getStatements()) { |
| 179 | + originalIds.add(s.getId()); |
| 180 | + } |
| 181 | + for (UUID id : originalIds) { |
| 182 | + Statement s = findStatementById(b, id); |
| 183 | + while (s != null) { |
| 184 | + J.MethodInvocation nested = findNestedWhiteboxResultCall(s, ctx); |
| 185 | + if (nested == null) { |
| 186 | + break; |
| 187 | + } |
| 188 | + J.Block hoisted = hoistNestedCall(b, s, nested, ctx); |
| 189 | + if (hoisted == b) { |
| 190 | + break; // could not migrate this nested call; leave it for the comment recipe |
| 191 | + } |
| 192 | + b = hoisted; |
| 193 | + s = findStatementById(b, id); |
| 194 | + } |
| 195 | + } |
| 196 | + |
172 | 197 | return b; |
173 | 198 | } |
174 | 199 |
|
@@ -231,6 +256,129 @@ private ResultSink sinkFromStatement(Statement statement) { |
231 | 256 | return new ResultSink(null, null); |
232 | 257 | } |
233 | 258 |
|
| 259 | + private @Nullable Statement findStatementById(J.Block b, UUID id) { |
| 260 | + for (Statement s : b.getStatements()) { |
| 261 | + if (s.getId().equals(id)) { |
| 262 | + return s; |
| 263 | + } |
| 264 | + } |
| 265 | + return null; |
| 266 | + } |
| 267 | + |
| 268 | + /** |
| 269 | + * Find the first result-producing Whitebox call nested in expression position within |
| 270 | + * {@code enclosing}, excluding the call Phase A targets directly, calls in short-circuit |
| 271 | + * or ternary positions (hoisting would change evaluation), and calls inside nested |
| 272 | + * blocks/lambdas/anonymous classes (handled by their own block). |
| 273 | + */ |
| 274 | + private J.@Nullable MethodInvocation findNestedWhiteboxResultCall(Statement enclosing, ExecutionContext ctx) { |
| 275 | + J.MethodInvocation primary = extractWhiteboxInvocation(enclosing); |
| 276 | + UUID primaryId = primary != null ? primary.getId() : null; |
| 277 | + J.MethodInvocation[] holder = {null}; |
| 278 | + new JavaIsoVisitor<ExecutionContext>() { |
| 279 | + @Override |
| 280 | + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext c) { |
| 281 | + if (holder[0] == null && |
| 282 | + !method.getId().equals(primaryId) && |
| 283 | + isMigratableWhiteboxResult(method) && |
| 284 | + !inCarveOut(getCursor())) { |
| 285 | + holder[0] = method; |
| 286 | + return method; |
| 287 | + } |
| 288 | + return super.visitMethodInvocation(method, c); |
| 289 | + } |
| 290 | + |
| 291 | + @Override |
| 292 | + public J.Block visitBlock(J.Block nestedBlock, ExecutionContext c) { |
| 293 | + return nestedBlock; // nested blocks have their own scope and are handled separately |
| 294 | + } |
| 295 | + |
| 296 | + @Override |
| 297 | + public J.Lambda visitLambda(J.Lambda lambda, ExecutionContext c) { |
| 298 | + return lambda; |
| 299 | + } |
| 300 | + |
| 301 | + @Override |
| 302 | + public J.NewClass visitNewClass(J.NewClass newClass, ExecutionContext c) { |
| 303 | + return newClass; |
| 304 | + } |
| 305 | + }.visit(enclosing, ctx); |
| 306 | + return holder[0]; |
| 307 | + } |
| 308 | + |
| 309 | + // True when the call sits in a short-circuit (&&/||) right operand or a ternary branch, |
| 310 | + // where hoisting would change whether the reflective call executes. |
| 311 | + private boolean inCarveOut(Cursor callCursor) { |
| 312 | + Object child = callCursor.getValue(); |
| 313 | + for (Cursor c = callCursor.getParent(); c != null; c = c.getParent()) { |
| 314 | + Object val = c.getValue(); |
| 315 | + if (val instanceof J.Binary) { |
| 316 | + J.Binary bin = (J.Binary) val; |
| 317 | + if ((bin.getOperator() == J.Binary.Type.And || bin.getOperator() == J.Binary.Type.Or) && |
| 318 | + bin.getRight() == child) { |
| 319 | + return true; |
| 320 | + } |
| 321 | + } else if (val instanceof J.Ternary) { |
| 322 | + J.Ternary ternary = (J.Ternary) val; |
| 323 | + if (ternary.getTruePart() == child || ternary.getFalsePart() == child) { |
| 324 | + return true; |
| 325 | + } |
| 326 | + } |
| 327 | + child = val; |
| 328 | + } |
| 329 | + return false; |
| 330 | + } |
| 331 | + |
| 332 | + private J.Block hoistNestedCall(J.Block b, Statement enclosing, J.MethodInvocation nested, ExecutionContext ctx) { |
| 333 | + Cursor blockCursor = new Cursor(getCursor().getParentOrThrow(), b); |
| 334 | + JavaType.Method resolved = resolveFor(nested); |
| 335 | + String tempName = resultTempName(nested, blockCursor); |
| 336 | + ResultSink sink = new ResultSink(tempName, getCastType(nested.getType())); |
| 337 | + String template = buildReplacementTemplate(nested, sink, blockCursor, resolved); |
| 338 | + if (template == null) { |
| 339 | + return b; |
| 340 | + } |
| 341 | + J.Block rebuilt = JavaTemplate.builder(template) |
| 342 | + .contextSensitive() |
| 343 | + .javaParser(JavaParser.fromJavaVersion()) |
| 344 | + .imports(templateImports(resolved).toArray(new String[0])) |
| 345 | + .build() |
| 346 | + .apply(blockCursor, enclosing.getCoordinates().before(), buildTemplateArgs(nested, resolved)); |
| 347 | + |
| 348 | + J.Identifier ref = new J.Identifier(randomId(), nested.getPrefix(), Markers.EMPTY, emptyList(), |
| 349 | + tempName, nested.getType(), null); |
| 350 | + rebuilt = (J.Block) new JavaVisitor<ExecutionContext>() { |
| 351 | + @Override |
| 352 | + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext c) { |
| 353 | + if (method.getId().equals(nested.getId())) { |
| 354 | + return ref; |
| 355 | + } |
| 356 | + return super.visitMethodInvocation(method, c); |
| 357 | + } |
| 358 | + }.visitNonNull(rebuilt, ctx); |
| 359 | + |
| 360 | + recordReplacement(nested, resolved); |
| 361 | + return rebuilt; |
| 362 | + } |
| 363 | + |
| 364 | + private String resultTempName(J.MethodInvocation mi, Cursor scope) { |
| 365 | + String base; |
| 366 | + if (GET_FIELD.matches(mi)) { |
| 367 | + base = "reflectField"; |
| 368 | + } else if (GET_METHOD.matches(mi)) { |
| 369 | + base = "reflectMethod"; |
| 370 | + } else if (INVOKE_CONSTRUCTOR_ARGS.matches(mi) || INVOKE_CONSTRUCTOR_EXPLICIT.matches(mi)) { |
| 371 | + base = "reflectInstance"; |
| 372 | + } else if (INVOKE_METHOD.matches(mi) || INVOKE_METHOD_STATIC.matches(mi)) { |
| 373 | + String literal = extractStringLiteral(mi.getArguments().get(1)); |
| 374 | + base = literal != null ? literal + "Result" : "reflectResult"; |
| 375 | + } else { |
| 376 | + String literal = extractStringLiteral(mi.getArguments().get(1)); |
| 377 | + base = literal != null ? literal + "Value" : "reflectValue"; |
| 378 | + } |
| 379 | + return generateVariableName(base, scope, INCREMENT_NUMBER); |
| 380 | + } |
| 381 | + |
234 | 382 | private @Nullable String buildReplacementTemplate(J.MethodInvocation mi, ResultSink sink, |
235 | 383 | Cursor scope, JavaType.@Nullable Method resolvedMethod) { |
236 | 384 | List<Expression> args = mi.getArguments(); |
|
0 commit comments