@@ -23,7 +23,7 @@ public class DetectThreadLeaksExtension
2323 private static final Logger LOGGER = Logger .getLogger (DetectThreadLeaksExtension .class .getName ());
2424 private static final ExtensionContext .Namespace EXTENSION_NAMESPACE =
2525 ExtensionContext .Namespace .create (DetectThreadLeaksExtension .class );
26- private static final String SNAPSHOT_KEY = "snapshot" ;
26+ private static final String THREAD_SNAPSHOT_KEY = "snapshot" ;
2727 private static final String CONCURRENT_KEY = "concurrent" ;
2828 private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler" ;
2929
@@ -32,6 +32,10 @@ public class DetectThreadLeaksExtension
3232
3333 @ Override
3434 public void beforeAll (ExtensionContext context ) {
35+ if (scope (context ) == DetectThreadLeaks .Scope .NONE ) {
36+ return ;
37+ }
38+
3539 if (context .getExecutionMode () != ExecutionMode .SAME_THREAD ) {
3640 LOGGER .warning (
3741 "Thread leak detection is disabled: tests in ["
@@ -40,23 +44,36 @@ public void beforeAll(ExtensionContext context) {
4044 context .getStore (EXTENSION_NAMESPACE ).put (CONCURRENT_KEY , Boolean .TRUE );
4145 return ;
4246 }
43- if (scope (context ) == DetectThreadLeaks .Scope .SUITE ) {
44- var store = context .getStore (EXTENSION_NAMESPACE );
45- var filter = buildFilter (context );
46- store .put (UNCAUGHT_EXCEPTION_HANDLER_KEY , installUncaughtExceptionHandler ());
47- store .put (SNAPSHOT_KEY , liveThreads (filter ));
47+
48+ var store = context .getStore (EXTENSION_NAMESPACE );
49+ var filter = buildFilter (context );
50+ store .put (UNCAUGHT_EXCEPTION_HANDLER_KEY , installUncaughtExceptionHandler ());
51+ store .put (THREAD_SNAPSHOT_KEY , liveThreads (filter ));
52+ }
53+
54+ @ Override
55+ public void beforeEach (ExtensionContext context ) {
56+ if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
57+ return ;
4858 }
59+
60+ var store = context .getStore (EXTENSION_NAMESPACE );
61+ var filter = buildFilter (context );
62+ store .put (THREAD_SNAPSHOT_KEY , liveThreads (filter ));
4963 }
5064
5165 @ Override
52- public void afterAll (ExtensionContext context ) {
53- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .SUITE ) return ;
66+ public void afterEach (ExtensionContext context ) {
67+ if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
68+ return ;
69+ }
70+
5471 var store = context .getStore (EXTENSION_NAMESPACE );
5572 var handler = store .get (UNCAUGHT_EXCEPTION_HANDLER_KEY , UncaughtExceptionsHandler .class );
5673 try {
5774 checkLeaks (
5875 store ,
59- "suite [" + context .getDisplayName () + "]" ,
76+ "test [" + context .getDisplayName () + "]" ,
6077 linger (context ),
6178 buildFilter (context ),
6279 handler );
@@ -66,23 +83,17 @@ public void afterAll(ExtensionContext context) {
6683 }
6784
6885 @ Override
69- public void beforeEach (ExtensionContext context ) {
70- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) return ;
71- var store = context .getStore (EXTENSION_NAMESPACE );
72- var filter = buildFilter (context );
73- store .put (UNCAUGHT_EXCEPTION_HANDLER_KEY , installUncaughtExceptionHandler ());
74- store .put (SNAPSHOT_KEY , liveThreads (filter ));
75- }
86+ public void afterAll (ExtensionContext context ) {
87+ if (isConcurrentMode (context ) || scope (context ) == DetectThreadLeaks .Scope .NONE ) {
88+ return ;
89+ }
7690
77- @ Override
78- public void afterEach (ExtensionContext context ) {
79- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) return ;
8091 var store = context .getStore (EXTENSION_NAMESPACE );
8192 var handler = store .get (UNCAUGHT_EXCEPTION_HANDLER_KEY , UncaughtExceptionsHandler .class );
8293 try {
8394 checkLeaks (
8495 store ,
85- "test [" + context .getDisplayName () + "]" ,
96+ "suite [" + context .getDisplayName () + "]" ,
8697 linger (context ),
8798 buildFilter (context ),
8899 handler );
@@ -119,41 +130,26 @@ private static int linger(ExtensionContext context) {
119130 * any filter matches it.
120131 */
121132 private static Predicate <Thread > buildFilter (ExtensionContext context ) {
122- var filterClasses = new LinkedHashSet <Class <? extends Predicate <Thread >>>();
123-
124- context
125- .getTestMethod ()
126- .ifPresent (
127- m -> {
128- var ann = m .getAnnotation (DetectThreadLeaks .ExcludeThreads .class );
129- if (ann != null ) {
130- for (var c : ann .value ()) filterClasses .add (c );
131- }
132- });
133+ var filterClasses = new LinkedHashSet <Predicate <Thread >>();
133134
134135 for (Class <?> cls = context .getRequiredTestClass (); cls != null ; cls = cls .getSuperclass ()) {
135136 var ann = cls .getAnnotation (DetectThreadLeaks .ExcludeThreads .class );
136137 if (ann != null ) {
137- for (var c : ann .value ()) filterClasses .add (c );
138+ for (var c : ann .value ()) {
139+ try {
140+ filterClasses .add (c .getDeclaredConstructor ().newInstance ());
141+ } catch (ReflectiveOperationException e ) {
142+ throw new RuntimeException ("Cannot instantiate thread filter: " + cls .getName (), e );
143+ }
144+ }
138145 }
139146 }
140147
141- if (filterClasses .isEmpty ()) return t -> false ;
142-
143- var predicates =
144- filterClasses .stream ()
145- .map (
146- cls -> {
147- try {
148- return (Predicate <Thread >) cls .getDeclaredConstructor ().newInstance ();
149- } catch (ReflectiveOperationException e ) {
150- throw new RuntimeException (
151- "Cannot instantiate thread filter: " + cls .getName (), e );
152- }
153- })
154- .toList ();
155-
156- return t -> predicates .stream ().anyMatch (p -> p .test (t ));
148+ if (filterClasses .isEmpty ()) {
149+ return t -> false ;
150+ }
151+
152+ return t -> filterClasses .stream ().anyMatch (p -> p .test (t ));
157153 }
158154
159155 private static boolean isConcurrentMode (ExtensionContext context ) {
@@ -172,7 +168,7 @@ private static void checkLeaks(
172168 int lingerMs ,
173169 Predicate <Thread > filter ,
174170 UncaughtExceptionsHandler handler ) {
175- var snapshot = store .get (SNAPSHOT_KEY , HashSet .class );
171+ var snapshot = store .get (THREAD_SNAPSHOT_KEY , HashSet .class );
176172 AssertionError leakError = null ;
177173
178174 if (snapshot != null ) {
@@ -196,7 +192,10 @@ private static void checkLeaks(
196192 if (!leaked .isEmpty ()) {
197193 // Suppress uncaught exception reporting during the interrupt/join phase to avoid
198194 // capturing expected InterruptedException-related exceptions from cleaned-up threads.
199- if (handler != null ) handler .stopReporting ();
195+ if (handler != null ) {
196+ handler .stopReporting ();
197+ }
198+
200199 try {
201200 leaked .keySet ().forEach (Thread ::interrupt );
202201 long joinDeadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (INTERRUPT_JOIN_MS );
@@ -211,7 +210,9 @@ private static void checkLeaks(
211210 }
212211 }
213212 } finally {
214- if (handler != null ) handler .resumeReporting ();
213+ if (handler != null ) {
214+ handler .resumeReporting ();
215+ }
215216 }
216217
217218 var sb = new StringBuilder (leaked .size () + " thread(s) leaked from " + description + ":" );
0 commit comments