Skip to content

Commit 6ba0947

Browse files
authored
Fix async poller workflow slot used (#2803)
* mark task slot used in async workflow poll task * slot supplier test with async poll task
1 parent dbd648b commit 6ba0947

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
public class AsyncWorkflowPollTask
3131
implements AsyncPoller.PollTaskAsync<WorkflowTask>, DisableNormalPolling {
3232
private static final Logger log = LoggerFactory.getLogger(AsyncWorkflowPollTask.class);
33-
private final TrackingSlotSupplier<?> slotSupplier;
33+
private final TrackingSlotSupplier<WorkflowSlotInfo> slotSupplier;
3434
private final WorkflowServiceStubs service;
3535
private final Scope metricsScope;
3636
private final Scope pollerMetricScope;
@@ -150,6 +150,7 @@ public CompletableFuture<WorkflowTask> poll(SlotPermit permit)
150150
.inc(1);
151151
return null;
152152
}
153+
slotSupplier.markSlotUsed(new WorkflowSlotInfo(r, pollRequest), permit);
153154
pollerMetricScope
154155
.counter(MetricsType.WORKFLOW_TASK_QUEUE_POLL_SUCCEED_COUNTER)
155156
.inc(1);

temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static org.mockito.ArgumentMatchers.any;
66
import static org.mockito.Mockito.*;
77

8+
import com.google.common.util.concurrent.Futures;
89
import com.google.protobuf.ByteString;
910
import com.uber.m3.tally.RootScopeBuilder;
1011
import com.uber.m3.tally.Scope;
@@ -17,6 +18,8 @@
1718
import io.temporal.serviceclient.WorkflowServiceStubs;
1819
import io.temporal.worker.tuning.*;
1920
import java.util.Objects;
21+
import java.util.concurrent.CompletableFuture;
22+
import java.util.concurrent.ExecutionException;
2023
import java.util.concurrent.atomic.AtomicInteger;
2124
import org.junit.Test;
2225
import org.junit.runner.RunWith;
@@ -124,4 +127,64 @@ public void supplierIsCalledAppropriately() {
124127
assertEquals(1, trackingSS.getUsedSlots().size());
125128
}
126129
}
130+
131+
@Test
132+
public void asyncPollerSupplierIsCalledAppropriately() throws Exception {
133+
WorkflowServiceStubs client = mock(WorkflowServiceStubs.class);
134+
when(client.getServerCapabilities())
135+
.thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build());
136+
137+
WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub =
138+
mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class);
139+
when(client.futureStub()).thenReturn(futureStub);
140+
when(futureStub.withOption(any(), any())).thenReturn(futureStub);
141+
142+
SlotSupplier<WorkflowSlotInfo> mockSupplier = mock(SlotSupplier.class);
143+
Scope metricsScope =
144+
new RootScopeBuilder()
145+
.reporter(reporter)
146+
.reportEvery(com.uber.m3.util.Duration.ofMillis(1));
147+
TrackingSlotSupplier<WorkflowSlotInfo> trackingSS =
148+
new TrackingSlotSupplier<>(mockSupplier, metricsScope);
149+
150+
PollWorkflowTaskQueueResponse pollResponse =
151+
PollWorkflowTaskQueueResponse.newBuilder()
152+
.setTaskToken(ByteString.copyFrom("token", UTF_8))
153+
.setWorkflowExecution(
154+
WorkflowExecution.newBuilder().setWorkflowId(WORKFLOW_ID).setRunId(RUN_ID).build())
155+
.setWorkflowType(WorkflowType.newBuilder().setName(WORKFLOW_TYPE).build())
156+
.build();
157+
158+
if (throwOnPoll) {
159+
when(futureStub.pollWorkflowTaskQueue(any()))
160+
.thenReturn(Futures.immediateFailedFuture(new RuntimeException("Poll failed")));
161+
} else {
162+
when(futureStub.pollWorkflowTaskQueue(any()))
163+
.thenReturn(Futures.immediateFuture(pollResponse));
164+
}
165+
166+
AsyncWorkflowPollTask pollTask =
167+
new AsyncWorkflowPollTask(
168+
client,
169+
"default",
170+
TASK_QUEUE,
171+
null,
172+
"",
173+
new WorkerVersioningOptions("", false, null),
174+
trackingSS,
175+
metricsScope,
176+
() -> GetSystemInfoResponse.Capabilities.newBuilder().build());
177+
178+
SlotPermit permit = new SlotPermit();
179+
180+
CompletableFuture<WorkflowTask> future = pollTask.poll(permit);
181+
if (throwOnPoll) {
182+
assertThrows(ExecutionException.class, future::get);
183+
assertEquals(0, trackingSS.getUsedSlots().size());
184+
} else {
185+
WorkflowTask task = future.get();
186+
assertNotNull(task);
187+
assertEquals(1, trackingSS.getUsedSlots().size());
188+
}
189+
}
127190
}

0 commit comments

Comments
 (0)