Skip to content

Commit f8ef94c

Browse files
Support Workflow.getInfo from query method body
1 parent 5d64818 commit f8ef94c

7 files changed

Lines changed: 72 additions & 12 deletions

File tree

temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunner.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io.temporal.internal.worker.WorkflowExecutorCache;
44
import io.temporal.workflow.CancellationScope;
5+
import java.util.Optional;
56
import javax.annotation.Nonnull;
67
import javax.annotation.Nullable;
78

@@ -90,4 +91,17 @@ static DeterministicRunner newRunner(
9091
/** Creates a new instance of a workflow callback thread. */
9192
@Nonnull
9293
WorkflowThread newCallbackThread(Runnable runnable, @Nullable String name);
94+
95+
/**
96+
* Retrieve data from runner locals. Returns 1. not found (an empty Optional) 2. found but null
97+
* (an Optional of an empty Optional) 3. found and non-null (an Optional of an Optional of a
98+
* value). The type nesting is because Java Optionals cannot understand "Some null" vs "None",
99+
* which is exactly what we need here.
100+
*
101+
* @param key
102+
* @return one of three cases
103+
* @param <T>
104+
*/
105+
@SuppressWarnings("unchecked")
106+
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key);
93107
}

temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ private boolean areThreadsToBeExecuted() {
586586
* @param <T>
587587
*/
588588
@SuppressWarnings("unchecked")
589-
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
589+
public <T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
590590
if (!runnerLocalMap.containsKey(key)) {
591591
return Optional.empty();
592592
}

temporal-sdk/src/main/java/io/temporal/internal/sync/QueryDispatcher.java

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,33 @@ class QueryDispatcher {
2323

2424
private DynamicQueryHandler dynamicQueryHandler;
2525
private WorkflowInboundCallsInterceptor inboundCallsInterceptor;
26+
private static final ThreadLocal<SyncWorkflowContext> queryHandlerWorkflowContext =
27+
new ThreadLocal<>();
2628

2729
public QueryDispatcher(DataConverter dataConverterWithWorkflowContext) {
2830
this.dataConverterWithWorkflowContext = dataConverterWithWorkflowContext;
2931
}
3032

33+
/**
34+
* @return True if the current thread is executing a query handler.
35+
*/
36+
public static boolean isQueryHandler() {
37+
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
38+
return value != null;
39+
}
40+
41+
/**
42+
* @return The current workflow context if the current thread is executing a query handler.
43+
* @throws IllegalStateException if not in a query handler.
44+
*/
45+
public static SyncWorkflowContext getWorkflowContext() {
46+
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
47+
if (value == null) {
48+
throw new IllegalStateException("Not in a query handler");
49+
}
50+
return value;
51+
}
52+
3153
public void setInboundCallsInterceptor(WorkflowInboundCallsInterceptor inboundCallsInterceptor) {
3254
this.inboundCallsInterceptor = inboundCallsInterceptor;
3355
}
@@ -51,7 +73,11 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
5173
return new WorkflowInboundCallsInterceptor.QueryOutput(result);
5274
}
5375

54-
public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
76+
public Optional<Payloads> handleQuery(
77+
SyncWorkflowContext replayContext,
78+
String queryName,
79+
Header header,
80+
Optional<Payloads> input) {
5581
WorkflowOutboundCallsInterceptor.RegisterQueryInput handler = queryCallbacks.get(queryName);
5682
Object[] args;
5783
if (queryName.startsWith(TEMPORAL_RESERVED_PREFIX)) {
@@ -69,11 +95,18 @@ public Optional<Payloads> handleQuery(String queryName, Header header, Optional<
6995
dataConverterWithWorkflowContext.fromPayloads(
7096
input, handler.getArgTypes(), handler.getGenericArgTypes());
7197
}
72-
Object result =
73-
inboundCallsInterceptor
74-
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
75-
.getResult();
76-
return dataConverterWithWorkflowContext.toPayloads(result);
98+
try {
99+
replayContext.setReadOnly(true);
100+
queryHandlerWorkflowContext.set(replayContext);
101+
Object result =
102+
inboundCallsInterceptor
103+
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
104+
.getResult();
105+
return dataConverterWithWorkflowContext.toPayloads(result);
106+
} finally {
107+
replayContext.setReadOnly(false);
108+
queryHandlerWorkflowContext.set(null);
109+
}
77110
}
78111

79112
public void registerQueryHandlers(WorkflowOutboundCallsInterceptor.RegisterQueryInput request) {

temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ public RunnerLocalInternal(boolean useCaching) {
1616
}
1717

1818
public T get(Supplier<? extends T> supplier) {
19-
Optional<Optional<T>> result =
20-
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
19+
Optional<Optional<T>> result;
20+
// Query handlers are special in that they are executing in a different context
21+
// than the main workflow execution threads. We need to fetch the runner local from the
22+
// correct context based on whether we are in a query handler or not.
23+
if (QueryDispatcher.isQueryHandler()) {
24+
result = QueryDispatcher.getWorkflowContext().getRunner().getRunnerLocal(this);
25+
} else {
26+
result = DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
27+
}
2128
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
2229
if (!result.isPresent() && useCaching) {
2330
// This is the first time we've tried fetching this, and caching is enabled. Store it.

temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
349349
}
350350

351351
public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
352-
return queryDispatcher.handleQuery(queryName, header, input);
352+
return queryDispatcher.handleQuery(this, queryName, header, input);
353353
}
354354

355355
public boolean isEveryHandlerFinished() {

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowInternal.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,12 @@ static WorkflowOutboundCallsInterceptor getWorkflowOutboundInterceptor() {
843843
}
844844

845845
static SyncWorkflowContext getRootWorkflowContext() {
846+
// If we are in a query handler, we need to get the workflow context from the
847+
// QueryDispatcher, otherwise we get it from the current thread's internal context.
848+
// This is necessary because query handlers run in a different context than the main workflow threads.
849+
if (QueryDispatcher.isQueryHandler()) {
850+
return QueryDispatcher.getWorkflowContext();
851+
}
846852
return DeterministicRunnerImpl.currentThreadInternal().getWorkflowContext();
847853
}
848854

temporal-sdk/src/test/java/io/temporal/internal/sync/QueryDispatcherTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void testQuerySuccess() {
4949

5050
// Invoke functionality under test, expect no exceptions for an existing query.
5151
Optional<Payloads> queryResult =
52-
dispatcher.handleQuery("QueryB", Header.empty(), Optional.empty());
52+
dispatcher.handleQuery(null, "QueryB", Header.empty(), Optional.empty());
5353
assertTrue(queryResult.isPresent());
5454
}
5555

@@ -61,7 +61,7 @@ public void testQueryDispatcherException() {
6161
assertThrows(
6262
IllegalArgumentException.class,
6363
() -> {
64-
dispatcher.handleQuery("QueryC", Header.empty(), null);
64+
dispatcher.handleQuery(null, "QueryC", Header.empty(), null);
6565
});
6666
assertEquals("Unknown query type: QueryC, knownTypes=[QueryA, QueryB]", exception.getMessage());
6767
}

0 commit comments

Comments
 (0)