Skip to content

Commit a06b2f6

Browse files
committed
Merge branch 'main' into feature/datafusion-dsl-query-agg
2 parents ea72060 + bd501c6 commit a06b2f6

1 file changed

Lines changed: 13 additions & 16 deletions

File tree

plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightOutboundHandlerTests.java

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
package org.opensearch.arrow.flight.transport;
1010

11+
import org.apache.arrow.memory.BufferAllocator;
12+
import org.apache.arrow.vector.VectorSchemaRoot;
1113
import org.opensearch.Version;
1214
import org.opensearch.common.util.concurrent.ThreadContext;
1315
import org.opensearch.core.transport.TransportResponse;
@@ -25,6 +27,7 @@
2527
import java.util.concurrent.ExecutorService;
2628
import java.util.concurrent.Executors;
2729
import java.util.concurrent.TimeUnit;
30+
import java.util.concurrent.atomic.AtomicReference;
2831

2932
import static org.mockito.ArgumentMatchers.any;
3033
import static org.mockito.ArgumentMatchers.anyLong;
@@ -54,6 +57,8 @@ public void setUp() throws Exception {
5457

5558
mockFlightChannel = mock(FlightServerChannel.class);
5659
when(mockFlightChannel.getExecutor()).thenReturn(executor);
60+
when(mockFlightChannel.getAllocator()).thenReturn(mock(BufferAllocator.class));
61+
when(mockFlightChannel.getRoot()).thenReturn(mock(VectorSchemaRoot.class));
5762

5863
mockListener = mock(TransportMessageListener.class);
5964
handler.setMessageListener(mockListener);
@@ -72,11 +77,7 @@ public void testSendResponseBatchPreservesCallerThreadContext() throws Exception
7277
ThreadContext threadContext = threadPool.getThreadContext();
7378
threadContext.putHeader(HEADER_KEY, HEADER_VALUE);
7479

75-
CountDownLatch latch = new CountDownLatch(1);
76-
doAnswer(invocation -> {
77-
latch.countDown();
78-
return null;
79-
}).when(mockListener).onResponseSent(anyLong(), anyString(), any(TransportResponse.class));
80+
doAnswer(invocation -> null).when(mockListener).onResponseSent(anyLong(), anyString(), any(TransportResponse.class));
8081

8182
handler.sendResponseBatch(
8283
Version.CURRENT,
@@ -140,20 +141,15 @@ public void testSendResponseBatchPropagatesContextToExecutorThread() throws Exce
140141
threadContext.putHeader(HEADER_KEY, HEADER_VALUE);
141142

142143
CountDownLatch latch = new CountDownLatch(1);
144+
AtomicReference<String> capturedHeader = new AtomicReference<>();
143145

144-
// Use a mock executor that runs the preserveContext-wrapped runnable
145-
ExecutorService mockExecutor = mock(ExecutorService.class);
146+
// Capture the thread context header inside onResponseSent, which runs
147+
// within the preserveContext wrapper on the executor thread
146148
doAnswer(invocation -> {
147-
Runnable command = invocation.getArgument(0);
148-
executor.execute(() -> {
149-
command.run();
150-
// After the preserveContext wrapper runs, capture the header
151-
// The wrapper stashes the executor thread context, restores caller's, runs, then restores executor's
152-
latch.countDown();
153-
});
149+
capturedHeader.set(threadPool.getThreadContext().getHeader(HEADER_KEY));
150+
latch.countDown();
154151
return null;
155-
}).when(mockExecutor).execute(any(Runnable.class));
156-
when(mockFlightChannel.getExecutor()).thenReturn(mockExecutor);
152+
}).when(mockListener).onResponseSent(anyLong(), anyString(), any(TransportResponse.class));
157153

158154
handler.sendResponseBatch(
159155
Version.CURRENT,
@@ -168,6 +164,7 @@ public void testSendResponseBatchPropagatesContextToExecutorThread() throws Exce
168164
);
169165

170166
assertTrue("Executor task should complete", latch.await(5, TimeUnit.SECONDS));
167+
assertEquals("Context should be propagated to executor thread", HEADER_VALUE, capturedHeader.get());
171168
}
172169

173170
public void testMultipleBatchesMaintainCallerContext() throws Exception {

0 commit comments

Comments
 (0)