|
5 | 5 | import static org.mockito.ArgumentMatchers.any; |
6 | 6 | import static org.mockito.Mockito.*; |
7 | 7 |
|
| 8 | +import com.google.common.util.concurrent.Futures; |
8 | 9 | import com.google.protobuf.ByteString; |
9 | 10 | import com.uber.m3.tally.RootScopeBuilder; |
10 | 11 | import com.uber.m3.tally.Scope; |
|
17 | 18 | import io.temporal.serviceclient.WorkflowServiceStubs; |
18 | 19 | import io.temporal.worker.tuning.*; |
19 | 20 | import java.util.Objects; |
| 21 | +import java.util.concurrent.CompletableFuture; |
| 22 | +import java.util.concurrent.ExecutionException; |
20 | 23 | import java.util.concurrent.atomic.AtomicInteger; |
21 | 24 | import org.junit.Test; |
22 | 25 | import org.junit.runner.RunWith; |
@@ -124,4 +127,64 @@ public void supplierIsCalledAppropriately() { |
124 | 127 | assertEquals(1, trackingSS.getUsedSlots().size()); |
125 | 128 | } |
126 | 129 | } |
| 130 | + |
| 131 | + @Test |
| 132 | + public void asyncPollerSupplierIsCalledAppropriately() throws Exception { |
| 133 | + WorkflowServiceStubs client = mock(WorkflowServiceStubs.class); |
| 134 | + when(client.getServerCapabilities()) |
| 135 | + .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); |
| 136 | + |
| 137 | + WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = |
| 138 | + mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); |
| 139 | + when(client.futureStub()).thenReturn(futureStub); |
| 140 | + when(futureStub.withOption(any(), any())).thenReturn(futureStub); |
| 141 | + |
| 142 | + SlotSupplier<WorkflowSlotInfo> mockSupplier = mock(SlotSupplier.class); |
| 143 | + Scope metricsScope = |
| 144 | + new RootScopeBuilder() |
| 145 | + .reporter(reporter) |
| 146 | + .reportEvery(com.uber.m3.util.Duration.ofMillis(1)); |
| 147 | + TrackingSlotSupplier<WorkflowSlotInfo> trackingSS = |
| 148 | + new TrackingSlotSupplier<>(mockSupplier, metricsScope); |
| 149 | + |
| 150 | + PollWorkflowTaskQueueResponse pollResponse = |
| 151 | + PollWorkflowTaskQueueResponse.newBuilder() |
| 152 | + .setTaskToken(ByteString.copyFrom("token", UTF_8)) |
| 153 | + .setWorkflowExecution( |
| 154 | + WorkflowExecution.newBuilder().setWorkflowId(WORKFLOW_ID).setRunId(RUN_ID).build()) |
| 155 | + .setWorkflowType(WorkflowType.newBuilder().setName(WORKFLOW_TYPE).build()) |
| 156 | + .build(); |
| 157 | + |
| 158 | + if (throwOnPoll) { |
| 159 | + when(futureStub.pollWorkflowTaskQueue(any())) |
| 160 | + .thenReturn(Futures.immediateFailedFuture(new RuntimeException("Poll failed"))); |
| 161 | + } else { |
| 162 | + when(futureStub.pollWorkflowTaskQueue(any())) |
| 163 | + .thenReturn(Futures.immediateFuture(pollResponse)); |
| 164 | + } |
| 165 | + |
| 166 | + AsyncWorkflowPollTask pollTask = |
| 167 | + new AsyncWorkflowPollTask( |
| 168 | + client, |
| 169 | + "default", |
| 170 | + TASK_QUEUE, |
| 171 | + null, |
| 172 | + "", |
| 173 | + new WorkerVersioningOptions("", false, null), |
| 174 | + trackingSS, |
| 175 | + metricsScope, |
| 176 | + () -> GetSystemInfoResponse.Capabilities.newBuilder().build()); |
| 177 | + |
| 178 | + SlotPermit permit = new SlotPermit(); |
| 179 | + |
| 180 | + CompletableFuture<WorkflowTask> future = pollTask.poll(permit); |
| 181 | + if (throwOnPoll) { |
| 182 | + assertThrows(ExecutionException.class, future::get); |
| 183 | + assertEquals(0, trackingSS.getUsedSlots().size()); |
| 184 | + } else { |
| 185 | + WorkflowTask task = future.get(); |
| 186 | + assertNotNull(task); |
| 187 | + assertEquals(1, trackingSS.getUsedSlots().size()); |
| 188 | + } |
| 189 | + } |
127 | 190 | } |
0 commit comments