Skip to content

Commit b337ac0

Browse files
TristanAndersonTswizzle
authored andcommitted
adds virtual thread support for activity execution context
> updates CurrentActivityExecutionContext.java to support that mission > adds test for CurrentActivityExecutionContextTest > removed unused code in temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java
1 parent 22a1c1a commit b337ac0

3 files changed

Lines changed: 195 additions & 14 deletions

File tree

temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,56 @@
11
package io.temporal.internal.activity;
22

33
import io.temporal.activity.ActivityExecutionContext;
4+
import java.util.ArrayDeque;
5+
import java.util.Collections;
6+
import java.util.Deque;
7+
import java.util.Map;
8+
import java.util.WeakHashMap;
49

510
/**
6-
* Thread local store of the context object passed to an activity implementation. Avoid using this
7-
* class directly.
11+
* Thread-local / virtual-thread-aware store of the context object passed to an activity
12+
* implementation. Avoid using this class directly.
813
*
9-
* @author fateev
14+
* <p>Uses a per-thread stack so nested sets/unsets are handled correctly. Platform threads use
15+
* ThreadLocal; virtual threads use a WeakHashMap keyed by Thread to avoid leaking memory when
16+
* virtual threads die.
17+
*
18+
* @author fateev (adapted)
1019
*/
11-
final class CurrentActivityExecutionContext {
20+
public final class CurrentActivityExecutionContext {
21+
22+
private static final ThreadLocal<Deque<ActivityExecutionContext>> PLATFORM_STACK =
23+
ThreadLocal.withInitial(ArrayDeque::new);
1224

13-
private static final ThreadLocal<ActivityExecutionContext> CURRENT = new ThreadLocal<>();
25+
private static final Map<Thread, Deque<ActivityExecutionContext>> VIRTUAL_STACKS =
26+
Collections.synchronizedMap(new WeakHashMap<>());
27+
28+
private static Deque<ActivityExecutionContext> getStackForCurrentThread() {
29+
Thread t = Thread.currentThread();
30+
if (isVirtualThread(t)) {
31+
Deque<ActivityExecutionContext> d =
32+
VIRTUAL_STACKS.computeIfAbsent(t, k -> new ArrayDeque<>());
33+
return d;
34+
} else {
35+
return PLATFORM_STACK.get();
36+
}
37+
}
38+
39+
private static boolean isVirtualThread(Thread t) {
40+
try {
41+
t.getClass().getMethod("isVirtual", boolean.class);
42+
return true;
43+
} catch (NoSuchMethodException e) {
44+
return false;
45+
}
46+
}
1447

1548
/**
1649
* This is used by activity implementation to get access to the current ActivityExecutionContext
1750
*/
1851
public static ActivityExecutionContext get() {
19-
ActivityExecutionContext result = CURRENT.get();
52+
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
53+
ActivityExecutionContext result = stack.peek();
2054
if (result == null) {
2155
throw new IllegalStateException(
2256
"ActivityExecutionContext can be used only inside of activity "
@@ -26,21 +60,49 @@ public static ActivityExecutionContext get() {
2660
}
2761

2862
public static boolean isSet() {
29-
return CURRENT.get() != null;
63+
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
64+
return stack.peek() != null;
3065
}
3166

67+
/**
68+
* Pushes the provided context for the current thread. Null context is rejected. We allow nested
69+
* sets (push semantics) to support nested interceptors / wrappers.
70+
*/
3271
public static void set(ActivityExecutionContext context) {
3372
if (context == null) {
3473
throw new IllegalArgumentException("null context");
3574
}
36-
if (CURRENT.get() != null) {
37-
throw new IllegalStateException("current already set");
38-
}
39-
CURRENT.set(context);
75+
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
76+
stack.push(context);
4077
}
4178

79+
/**
80+
* Pops the current context for the thread. If the stack becomes empty, clear the storage for the
81+
* thread to allow GC (remove ThreadLocal or remove map entry for virtual threads).
82+
*/
4283
public static void unset() {
43-
CURRENT.set(null);
84+
Thread t = Thread.currentThread();
85+
if (isVirtualThread(t)) {
86+
synchronized (VIRTUAL_STACKS) {
87+
Deque<ActivityExecutionContext> stack = VIRTUAL_STACKS.get(t);
88+
if (stack == null || stack.isEmpty()) {
89+
return;
90+
}
91+
stack.pop();
92+
if (stack.isEmpty()) {
93+
VIRTUAL_STACKS.remove(t);
94+
}
95+
}
96+
} else {
97+
Deque<ActivityExecutionContext> stack = PLATFORM_STACK.get();
98+
if (stack == null || stack.isEmpty()) {
99+
return;
100+
}
101+
stack.pop();
102+
if (stack.isEmpty()) {
103+
PLATFORM_STACK.remove();
104+
}
105+
}
44106
}
45107

46108
private CurrentActivityExecutionContext() {}

temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
public class WorkflowRunTokenTest {
1212
private static final ObjectWriter ow =
1313
new ObjectMapper().registerModule(new Jdk8Module()).writer();
14-
private static final ObjectReader or =
15-
new ObjectMapper().registerModule(new Jdk8Module()).reader();
1614
private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
1715

1816
@Test
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package io.temporal.internal.activity;
2+
3+
import static org.junit.Assert.*;
4+
5+
import io.temporal.activity.ActivityExecutionContext;
6+
import java.lang.reflect.InvocationHandler;
7+
import java.lang.reflect.Proxy;
8+
import java.util.concurrent.atomic.AtomicReference;
9+
import org.junit.Assume;
10+
import org.junit.Test;
11+
12+
public class CurrentActivityExecutionContextTest {
13+
14+
private static ActivityExecutionContext proxyContext() {
15+
InvocationHandler handler = (proxy, method, args) -> null;
16+
return (ActivityExecutionContext)
17+
Proxy.newProxyInstance(
18+
ActivityExecutionContext.class.getClassLoader(),
19+
new Class[] {ActivityExecutionContext.class},
20+
handler);
21+
}
22+
23+
@Test
24+
public void platformThreadNestedSetUnsetBehavior() {
25+
ActivityExecutionContext ctx1 = proxyContext();
26+
ActivityExecutionContext ctx2 = proxyContext();
27+
28+
assertFalse(CurrentActivityExecutionContext.isSet());
29+
assertThrows(IllegalStateException.class, CurrentActivityExecutionContext::get);
30+
31+
CurrentActivityExecutionContext.set(ctx1);
32+
assertTrue(CurrentActivityExecutionContext.isSet());
33+
assertSame("should return ctx1", ctx1, CurrentActivityExecutionContext.get());
34+
35+
CurrentActivityExecutionContext.set(ctx2);
36+
assertTrue(CurrentActivityExecutionContext.isSet());
37+
assertSame("should return ctx2 (top of stack)", ctx2, CurrentActivityExecutionContext.get());
38+
39+
CurrentActivityExecutionContext.unset();
40+
assertTrue(CurrentActivityExecutionContext.isSet());
41+
assertSame("after popping, should return ctx1", ctx1, CurrentActivityExecutionContext.get());
42+
43+
CurrentActivityExecutionContext.unset();
44+
assertFalse(CurrentActivityExecutionContext.isSet());
45+
assertThrows(
46+
"get() should throw after final unset",
47+
IllegalStateException.class,
48+
CurrentActivityExecutionContext::get);
49+
}
50+
51+
@Test
52+
public void virtualThreadNestedSetUnsetBehavior_ifSupported() throws Exception {
53+
boolean supportsVirtual;
54+
try {
55+
Thread.class.getMethod("startVirtualThread", Runnable.class);
56+
supportsVirtual = true;
57+
} catch (NoSuchMethodException e) {
58+
supportsVirtual = false;
59+
}
60+
61+
Assume.assumeTrue("Virtual threads not supported in this JVM; skipping", supportsVirtual);
62+
63+
AtomicReference<Throwable> failure = new AtomicReference<>(null);
64+
AtomicReference<ActivityExecutionContext> seenAfterFirstSet = new AtomicReference<>(null);
65+
AtomicReference<ActivityExecutionContext> seenAfterSecondSet = new AtomicReference<>(null);
66+
AtomicReference<Boolean> seenIsSetAfterFinalUnset = new AtomicReference<>(null);
67+
68+
Thread vt =
69+
Thread.startVirtualThread(
70+
() -> {
71+
try {
72+
ActivityExecutionContext vctx1 = proxyContext();
73+
ActivityExecutionContext vctx2 = proxyContext();
74+
75+
assertFalse(CurrentActivityExecutionContext.isSet());
76+
try {
77+
CurrentActivityExecutionContext.get();
78+
fail("get() should have thrown when no context is set");
79+
} catch (IllegalStateException expected) {
80+
}
81+
82+
CurrentActivityExecutionContext.set(vctx1);
83+
seenAfterFirstSet.set(CurrentActivityExecutionContext.get());
84+
85+
CurrentActivityExecutionContext.set(vctx2);
86+
seenAfterSecondSet.set(CurrentActivityExecutionContext.get());
87+
88+
CurrentActivityExecutionContext.unset();
89+
ActivityExecutionContext afterPop = CurrentActivityExecutionContext.get();
90+
if (afterPop != vctx1) {
91+
throw new AssertionError("after pop expected vctx1 but got " + afterPop);
92+
}
93+
94+
CurrentActivityExecutionContext.unset();
95+
seenIsSetAfterFinalUnset.set(CurrentActivityExecutionContext.isSet());
96+
try {
97+
CurrentActivityExecutionContext.get();
98+
throw new AssertionError("get() should have thrown after final unset");
99+
} catch (IllegalStateException expected) {
100+
}
101+
} catch (Throwable t) {
102+
failure.set(t);
103+
}
104+
});
105+
106+
vt.join();
107+
108+
if (failure.get() != null) {
109+
Throwable t = failure.get();
110+
if (t instanceof AssertionError) {
111+
throw (AssertionError) t;
112+
} else {
113+
throw new RuntimeException(t);
114+
}
115+
}
116+
117+
assertNotNull("virtual thread did not record first set", seenAfterFirstSet.get());
118+
assertNotNull("virtual thread did not record second (nested) set", seenAfterSecondSet.get());
119+
assertFalse("expected context to be unset at the end", seenIsSetAfterFinalUnset.get());
120+
}
121+
}

0 commit comments

Comments
 (0)