Skip to content

Commit b9b0f6b

Browse files
brucearctortconley1428Sushisource
authored
Fix ContextClassLoader propagation for poll tasks (#2808)
When PollerBehaviorAutoscaling is enabled, poll requests complete asynchronously on the common ForkJoinPool. Subsequent tasks executed by the PollTaskExecutor inherit the ContextClassLoader from these ForkJoinPool threads (AppClassLoader) rather than the thread that started the WorkerFactory (e.g., Spring Boot's LaunchedClassLoader). This fix captures the original ContextClassLoader within ExecutorThreadFactory and PollTaskExecutor, explicitly setting it on all newly created worker threads and virtual threads. Fixes #2795 Co-authored-by: tconley1428 <tconley1428@gmail.com> Co-authored-by: Spencer Judge <spencer@temporal.io>
1 parent da5e8a1 commit b9b0f6b

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

temporal-sdk/src/main/java/io/temporal/internal/worker/ExecutorThreadFactory.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@ class ExecutorThreadFactory implements ThreadFactory {
77
private final String threadPrefix;
88

99
private final Thread.UncaughtExceptionHandler uncaughtExceptionHandler;
10+
private final ClassLoader contextClassLoader;
1011
private final AtomicInteger threadIndex = new AtomicInteger();
1112

1213
public ExecutorThreadFactory(String threadPrefix, Thread.UncaughtExceptionHandler eh) {
1314
this.threadPrefix = threadPrefix;
1415
this.uncaughtExceptionHandler = eh;
16+
this.contextClassLoader = Thread.currentThread().getContextClassLoader();
1517
}
1618

1719
@Override
1820
public Thread newThread(Runnable r) {
1921
Thread result = new Thread(r);
2022
result.setName(threadPrefix + ": " + threadIndex.incrementAndGet());
2123
result.setUncaughtExceptionHandler(uncaughtExceptionHandler);
24+
if (contextClassLoader != null) {
25+
result.setContextClassLoader(contextClassLoader);
26+
}
2227
return result;
2328
}
2429
}

temporal-sdk/src/main/java/io/temporal/internal/worker/PollTaskExecutor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,15 @@ public interface TaskHandler<TT> {
4747
} else if (useVirtualThreads) {
4848
// If virtual threads are enabled, we use a virtual thread executor.
4949
AtomicInteger threadIndex = new AtomicInteger();
50+
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
5051
this.taskExecutor =
5152
VirtualThreadDelegate.newVirtualThreadExecutor(
5253
(t) -> {
5354
t.setName(this.pollThreadNamePrefix + ": " + threadIndex.incrementAndGet());
5455
t.setUncaughtExceptionHandler(pollerOptions.getUncaughtExceptionHandler());
56+
if (contextClassLoader != null) {
57+
t.setContextClassLoader(contextClassLoader);
58+
}
5559
});
5660
} else {
5761
ThreadPoolExecutor threadPoolTaskExecutor =
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.temporal.internal.worker;
2+
3+
import static org.junit.Assert.assertEquals;
4+
5+
import java.util.concurrent.CompletableFuture;
6+
import java.util.concurrent.TimeUnit;
7+
import java.util.concurrent.atomic.AtomicReference;
8+
import org.junit.Test;
9+
10+
public class PollTaskExecutorTest {
11+
12+
@Test
13+
public void testContextClassLoaderInherited() throws Exception {
14+
PollerOptions pollerOptions =
15+
PollerOptions.newBuilder().setPollThreadNamePrefix("test").build();
16+
17+
ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
18+
ClassLoader testClassLoader = new ClassLoader() {};
19+
Thread.currentThread().setContextClassLoader(testClassLoader);
20+
21+
AtomicReference<ClassLoader> executedClassLoader = new AtomicReference<>();
22+
23+
try {
24+
PollTaskExecutor<String> executor =
25+
new PollTaskExecutor<>(
26+
"namespace",
27+
"taskQueue",
28+
"identity",
29+
new PollTaskExecutor.TaskHandler<String>() {
30+
@Override
31+
public void handle(String task) {
32+
executedClassLoader.set(Thread.currentThread().getContextClassLoader());
33+
}
34+
35+
@Override
36+
public Throwable wrapFailure(String task, Throwable failure) {
37+
return failure;
38+
}
39+
},
40+
pollerOptions,
41+
1,
42+
false);
43+
44+
// Execute on a different thread with a different context class loader to simulate
45+
// ForkJoinPool
46+
CompletableFuture<Void> future = new CompletableFuture<>();
47+
Thread triggerThread =
48+
new Thread(
49+
() -> {
50+
Thread.currentThread().setContextClassLoader(null);
51+
executor.process("task");
52+
future.complete(null);
53+
});
54+
triggerThread.start();
55+
future.get(5, TimeUnit.SECONDS);
56+
57+
// Wait for task completion
58+
executor.shutdown(new ShutdownManager(), false).get(5, TimeUnit.SECONDS);
59+
60+
assertEquals(testClassLoader, executedClassLoader.get());
61+
} finally {
62+
Thread.currentThread().setContextClassLoader(originalClassLoader);
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)