From 74ba4216abb08eeabea96ea6ea237c7d384eb1c1 Mon Sep 17 00:00:00 2001 From: Liam Miller-Cushon Date: Thu, 28 May 2026 04:26:20 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 922690277 --- .../bugpatterns/AssertThrowsMinimizer.java | 32 +++++++++++-------- .../AssertThrowsMinimizerTest.java | 30 +++++++++++------ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java index 69bb2ee0777..110abd89da6 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java @@ -69,13 +69,14 @@ import java.util.stream.Stream; import javax.inject.Inject; import javax.lang.model.element.ElementKind; +import org.jspecify.annotations.Nullable; /** A {@link BugChecker}; see the associated {@link BugPattern} annotation for details. */ @BugPattern(summary = "Minimize the amount of logic in assertThrows", severity = WARNING) public class AssertThrowsMinimizer extends BugChecker implements MethodTreeMatcher { private static final Matcher MATCHER = - staticMethod().onClass("org.junit.Assert").named("assertThrows"); + anyOf(staticMethod().onClass("org.junit.Assert").named("assertThrows")); private final ConstantExpressions constantExpressions; private final boolean useVarType; @@ -119,11 +120,10 @@ private Optional matchMethodInvocation( if (!(tree.getArguments().getLast() instanceof LambdaExpressionTree lambdaExpressionTree)) { return Optional.empty(); } - Type firstArgumentType = getType(tree.getArguments().get(0)); - if (firstArgumentType.getTypeArguments().isEmpty()) { + Type exceptionType = getExceptionType(tree, state); + if (exceptionType == null) { return Optional.empty(); } - Type exceptionType = firstArgumentType.getTypeArguments().get(0); MethodInvocationTree runnable; switch (lambdaExpressionTree.getBody()) { case BlockTree blockTree -> { @@ -280,7 +280,7 @@ private boolean needsHoisting(ExpressionTree tree, Type exceptionType, VisitorSt // constant fields and string concatenation. return false; } - if (isCheckedException(exceptionType, state) && !throwsSubtypeOf(tree, exceptionType, state)) { + if (isCheckedExceptionType(exceptionType, state) && !maybeThrows(tree, exceptionType, state)) { return false; } boolean needsHoisting = @@ -327,18 +327,22 @@ private boolean newClassTreeNeedsHoisting(NewClassTree tree) { return !tree.getClassBody().getMembers().stream().allMatch(m -> m instanceof MethodTree); } - private static boolean throwsSubtypeOf( - ExpressionTree tree, Type exceptionType, VisitorState state) { - Types types = state.getTypes(); - return types.isSubtype(state.getSymtab().runtimeExceptionType, exceptionType) - || getThrownExceptions(tree, state).stream() - .anyMatch(t -> isCheckedException(t, state) && types.isSubtype(t, exceptionType)); + private static @Nullable Type getExceptionType(MethodInvocationTree tree, VisitorState state) { + Type firstArgumentType = getType(tree.getArguments().get(0)); + if (firstArgumentType.getTypeArguments().isEmpty()) { + return null; + } + return firstArgumentType.getTypeArguments().get(0); } - private static boolean isCheckedException(Type exception, VisitorState state) { + private static boolean maybeThrows(ExpressionTree tree, Type exceptionType, VisitorState state) { Types types = state.getTypes(); - return !types.isSubtype(exception, state.getSymtab().runtimeExceptionType) - && !types.isSubtype(exception, state.getSymtab().errorType); + if (types.isSubtype(state.getSymtab().runtimeExceptionType, exceptionType)) { + // The exception is Exception or Throwable, assume anything could throw it + return true; + } + return getThrownExceptions(tree, state).stream() + .anyMatch(t -> types.isAssignable(exceptionType, t)); } private static final Matcher KNOWN_SAFE = diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java index ac12dcb0657..34533534c06 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java @@ -41,6 +41,7 @@ interface Builder { Builder setBar(Bar bar); Builder setBar(Supplier bar); + Foo build(); } } @@ -514,7 +515,7 @@ public static Object getThing() throws Exception { throw new Exception(); } - public static Object getThingUnchecked() { + public static Object getThingUnchecked() { throw new RuntimeException(); } @@ -679,10 +680,11 @@ public void lambda() { .addInputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; class Test { void f() { @@ -706,10 +708,11 @@ public void methodReference() { .addInputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; class Test { void f() { @@ -725,10 +728,11 @@ Bar m() { .addOutputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; class Test { void f() { @@ -797,10 +801,11 @@ public void anonymousClass() { .addInputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; class Test { void f() { @@ -867,10 +872,11 @@ abstract class InstanceBarSupplier extends BarSupplier { .addOutputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; class Test { void f() { @@ -984,10 +990,11 @@ public void varArgs() { .addInputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; abstract class Test { void f() { @@ -1002,10 +1009,11 @@ void f() { .addOutputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; abstract class Test { void f() { @@ -1028,10 +1036,11 @@ public void cast() { .addInputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; abstract class Test { void f(String s, Object o) { @@ -1055,10 +1064,11 @@ void f(String s, Object o) { .addOutputLines( "Test.java", """ + import static org.junit.Assert.assertThrows; + import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; - import static org.junit.Assert.assertThrows; abstract class Test { void f(String s, Object o) {