11package com .carrotsearch .randomizedtesting .jupiter ;
22
3+ import java .util .Arrays ;
34import java .util .HashSet ;
5+ import java .util .LinkedHashSet ;
46import java .util .Map ;
57import java .util .concurrent .TimeUnit ;
8+ import java .util .function .Predicate ;
69import java .util .logging .Logger ;
710import java .util .stream .Collectors ;
811import org .junit .jupiter .api .extension .AfterAllCallback ;
@@ -36,7 +39,7 @@ public void beforeAll(ExtensionContext context) {
3639 return ;
3740 }
3841 if (scope (context ) == DetectThreadLeaks .Scope .SUITE ) {
39- context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads ());
42+ context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads (buildFilter ( context ) ));
4043 }
4144 }
4245
@@ -48,15 +51,16 @@ public void afterAll(ExtensionContext context) {
4851 checkLeaks (
4952 context .getStore (EXTENSION_NAMESPACE ),
5053 "suite [" + context .getDisplayName () + "]" ,
51- linger (context ));
54+ linger (context ),
55+ buildFilter (context ));
5256 }
5357
5458 @ Override
5559 public void beforeEach (ExtensionContext context ) {
5660 if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
5761 return ;
5862 }
59- context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads ());
63+ context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads (buildFilter ( context ) ));
6064 }
6165
6266 @ Override
@@ -67,7 +71,8 @@ public void afterEach(ExtensionContext context) {
6771 checkLeaks (
6872 context .getStore (EXTENSION_NAMESPACE ),
6973 "test [" + context .getDisplayName () + "]" ,
70- linger (context ));
74+ linger (context ),
75+ buildFilter (context ));
7176 }
7277
7378 private static DetectThreadLeaks .Scope scope (ExtensionContext context ) {
@@ -87,6 +92,51 @@ private static int linger(ExtensionContext context) {
8792 return classAnn == null ? 0 : classAnn .millis ();
8893 }
8994
95+ /**
96+ * Collects {@link DetectThreadLeaks.ExcludeThreads} filter classes from the entire hierarchy
97+ * (method → class → superclasses) and returns a combined predicate that excludes a thread when
98+ * any filter matches it.
99+ */
100+ private static Predicate <Thread > buildFilter (ExtensionContext context ) {
101+ var filterClasses = new LinkedHashSet <Class <? extends Predicate <Thread >>>();
102+
103+ context
104+ .getTestMethod ()
105+ .ifPresent (
106+ m -> {
107+ var ann = m .getAnnotation (DetectThreadLeaks .ExcludeThreads .class );
108+ if (ann != null ) {
109+ for (var c : ann .value ()) filterClasses .add (c );
110+ }
111+ });
112+
113+ for (Class <?> cls = context .getRequiredTestClass (); cls != null ; cls = cls .getSuperclass ()) {
114+ var ann = cls .getAnnotation (DetectThreadLeaks .ExcludeThreads .class );
115+ if (ann != null ) {
116+ filterClasses .addAll (Arrays .asList (ann .value ()));
117+ }
118+ }
119+
120+ if (filterClasses .isEmpty ()) {
121+ return t -> false ;
122+ }
123+
124+ var predicates =
125+ filterClasses .stream ()
126+ .map (
127+ cls -> {
128+ try {
129+ return (Predicate <Thread >) cls .getDeclaredConstructor ().newInstance ();
130+ } catch (Exception e ) {
131+ throw new RuntimeException (
132+ "Cannot instantiate thread filter: " + cls .getName (), e );
133+ }
134+ })
135+ .toList ();
136+
137+ return t -> predicates .stream ().anyMatch (p -> p .test (t ));
138+ }
139+
90140 private static boolean isConcurrentMode (ExtensionContext context ) {
91141 // Check the concurrent flag stored in beforeAll (class-level context = parent of method ctx).
92142 return context
@@ -98,11 +148,12 @@ private static boolean isConcurrentMode(ExtensionContext context) {
98148 .orElse (false );
99149 }
100150
101- private static void checkLeaks (ExtensionContext .Store store , String description , int lingerMs ) {
151+ private static void checkLeaks (
152+ ExtensionContext .Store store , String description , int lingerMs , Predicate <Thread > filter ) {
102153 var snapshot = store .get (SNAPSHOT_KEY , HashSet .class );
103154 if (snapshot == null ) return ;
104155
105- var leaked = leakedSince (snapshot );
156+ var leaked = leakedSince (snapshot , filter );
106157 if (leaked .isEmpty ()) return ;
107158
108159 // Linger: poll until threads self-terminate or the window expires.
@@ -116,7 +167,7 @@ private static void checkLeaks(ExtensionContext.Store store, String description,
116167 Thread .currentThread ().interrupt ();
117168 break ;
118169 }
119- leaked = leakedSince (snapshot );
170+ leaked = leakedSince (snapshot , filter );
120171 }
121172 if (leaked .isEmpty ()) return ;
122173 }
@@ -147,20 +198,22 @@ private static void checkLeaks(ExtensionContext.Store store, String description,
147198 throw new AssertionError (sb .toString ());
148199 }
149200
150- private static Map <Thread , StackTraceElement []> leakedSince (HashSet <?> snapshot ) {
151- var current = liveThreadsWithStacks ();
201+ private static Map <Thread , StackTraceElement []> leakedSince (
202+ HashSet <?> snapshot , Predicate <Thread > filter ) {
203+ var current = liveThreadsWithStacks (filter );
152204 current .keySet ().removeAll (snapshot );
153205 return current ;
154206 }
155207
156- private static HashSet <Thread > liveThreads () {
157- return new HashSet <>(liveThreadsWithStacks ().keySet ());
208+ private static HashSet <Thread > liveThreads (Predicate < Thread > filter ) {
209+ return new HashSet <>(liveThreadsWithStacks (filter ).keySet ());
158210 }
159211
160- private static Map <Thread , StackTraceElement []> liveThreadsWithStacks () {
212+ private static Map <Thread , StackTraceElement []> liveThreadsWithStacks (Predicate < Thread > filter ) {
161213 return Thread .getAllStackTraces ().entrySet ().stream ()
162214 .filter (e -> e .getKey ().isAlive ())
163215 .filter (e -> !isKnownSystemThread (e .getKey ()))
216+ .filter (e -> !filter .test (e .getKey ()))
164217 .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
165218 }
166219
0 commit comments