Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@
import java.util.stream.Stream;
import javax.inject.Inject;
import javax.lang.model.element.ElementKind;
import org.jspecify.annotations.Nullable;

/** A {@link BugChecker}; see the associated {@link BugPattern} annotation for details. */
@BugPattern(summary = "Minimize the amount of logic in assertThrows", severity = WARNING)
public class AssertThrowsMinimizer extends BugChecker implements MethodTreeMatcher {

private static final Matcher<ExpressionTree> MATCHER =
staticMethod().onClass("org.junit.Assert").named("assertThrows");
anyOf(staticMethod().onClass("org.junit.Assert").named("assertThrows"));

private final ConstantExpressions constantExpressions;
private final boolean useVarType;
Expand Down Expand Up @@ -119,11 +120,10 @@ private Optional<AssertThrows> matchMethodInvocation(
if (!(tree.getArguments().getLast() instanceof LambdaExpressionTree lambdaExpressionTree)) {
return Optional.empty();
}
Type firstArgumentType = getType(tree.getArguments().get(0));
if (firstArgumentType.getTypeArguments().isEmpty()) {
Type exceptionType = getExceptionType(tree, state);
if (exceptionType == null) {
return Optional.empty();
}
Type exceptionType = firstArgumentType.getTypeArguments().get(0);
MethodInvocationTree runnable;
switch (lambdaExpressionTree.getBody()) {
case BlockTree blockTree -> {
Expand Down Expand Up @@ -280,7 +280,7 @@ private boolean needsHoisting(ExpressionTree tree, Type exceptionType, VisitorSt
// constant fields and string concatenation.
return false;
}
if (isCheckedException(exceptionType, state) && !throwsSubtypeOf(tree, exceptionType, state)) {
if (isCheckedExceptionType(exceptionType, state) && !maybeThrows(tree, exceptionType, state)) {
return false;
}
boolean needsHoisting =
Expand Down Expand Up @@ -327,18 +327,22 @@ private boolean newClassTreeNeedsHoisting(NewClassTree tree) {
return !tree.getClassBody().getMembers().stream().allMatch(m -> m instanceof MethodTree);
}

private static boolean throwsSubtypeOf(
ExpressionTree tree, Type exceptionType, VisitorState state) {
Types types = state.getTypes();
return types.isSubtype(state.getSymtab().runtimeExceptionType, exceptionType)
|| getThrownExceptions(tree, state).stream()
.anyMatch(t -> isCheckedException(t, state) && types.isSubtype(t, exceptionType));
private static @Nullable Type getExceptionType(MethodInvocationTree tree, VisitorState state) {
Type firstArgumentType = getType(tree.getArguments().get(0));
if (firstArgumentType.getTypeArguments().isEmpty()) {
return null;
}
return firstArgumentType.getTypeArguments().get(0);
}

private static boolean isCheckedException(Type exception, VisitorState state) {
private static boolean maybeThrows(ExpressionTree tree, Type exceptionType, VisitorState state) {
Types types = state.getTypes();
return !types.isSubtype(exception, state.getSymtab().runtimeExceptionType)
&& !types.isSubtype(exception, state.getSymtab().errorType);
if (types.isSubtype(state.getSymtab().runtimeExceptionType, exceptionType)) {
// The exception is Exception or Throwable, assume anything could throw it
return true;
}
return getThrownExceptions(tree, state).stream()
.anyMatch(t -> types.isAssignable(exceptionType, t));
}

private static final Matcher<ExpressionTree> KNOWN_SAFE =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ interface Builder {
Builder setBar(Bar bar);

Builder setBar(Supplier<Bar> bar);

Foo build();
}
}
Expand Down Expand Up @@ -514,7 +515,7 @@ public static Object getThing() throws Exception {
throw new Exception();
}

public static Object getThingUnchecked() {
public static Object getThingUnchecked() {
throw new RuntimeException();
}

Expand Down Expand Up @@ -679,10 +680,11 @@ public void lambda() {
.addInputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

class Test {
void f() {
Expand All @@ -706,10 +708,11 @@ public void methodReference() {
.addInputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

class Test {
void f() {
Expand All @@ -725,10 +728,11 @@ Bar m() {
.addOutputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

class Test {
void f() {
Expand Down Expand Up @@ -797,10 +801,11 @@ public void anonymousClass() {
.addInputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

class Test {
void f() {
Expand Down Expand Up @@ -867,10 +872,11 @@ abstract class InstanceBarSupplier extends BarSupplier {
.addOutputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

class Test {
void f() {
Expand Down Expand Up @@ -984,10 +990,11 @@ public void varArgs() {
.addInputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

abstract class Test {
void f() {
Expand All @@ -1002,10 +1009,11 @@ void f() {
.addOutputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

abstract class Test {
void f() {
Expand All @@ -1028,10 +1036,11 @@ public void cast() {
.addInputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

abstract class Test {
void f(String s, Object o) {
Expand All @@ -1055,10 +1064,11 @@ void f(String s, Object o) {
.addOutputLines(
"Test.java",
"""
import static org.junit.Assert.assertThrows;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.junit.Assert.assertThrows;

abstract class Test {
void f(String s, Object o) {
Expand Down
Loading