88
99package org .opensearch .arrow .flight .transport ;
1010
11+ import org .apache .arrow .memory .BufferAllocator ;
12+ import org .apache .arrow .vector .VectorSchemaRoot ;
1113import org .opensearch .Version ;
1214import org .opensearch .common .util .concurrent .ThreadContext ;
1315import org .opensearch .core .transport .TransportResponse ;
2527import java .util .concurrent .ExecutorService ;
2628import java .util .concurrent .Executors ;
2729import java .util .concurrent .TimeUnit ;
30+ import java .util .concurrent .atomic .AtomicReference ;
2831
2932import static org .mockito .ArgumentMatchers .any ;
3033import static org .mockito .ArgumentMatchers .anyLong ;
@@ -54,6 +57,8 @@ public void setUp() throws Exception {
5457
5558 mockFlightChannel = mock (FlightServerChannel .class );
5659 when (mockFlightChannel .getExecutor ()).thenReturn (executor );
60+ when (mockFlightChannel .getAllocator ()).thenReturn (mock (BufferAllocator .class ));
61+ when (mockFlightChannel .getRoot ()).thenReturn (mock (VectorSchemaRoot .class ));
5762
5863 mockListener = mock (TransportMessageListener .class );
5964 handler .setMessageListener (mockListener );
@@ -72,11 +77,7 @@ public void testSendResponseBatchPreservesCallerThreadContext() throws Exception
7277 ThreadContext threadContext = threadPool .getThreadContext ();
7378 threadContext .putHeader (HEADER_KEY , HEADER_VALUE );
7479
75- CountDownLatch latch = new CountDownLatch (1 );
76- doAnswer (invocation -> {
77- latch .countDown ();
78- return null ;
79- }).when (mockListener ).onResponseSent (anyLong (), anyString (), any (TransportResponse .class ));
80+ doAnswer (invocation -> null ).when (mockListener ).onResponseSent (anyLong (), anyString (), any (TransportResponse .class ));
8081
8182 handler .sendResponseBatch (
8283 Version .CURRENT ,
@@ -140,20 +141,15 @@ public void testSendResponseBatchPropagatesContextToExecutorThread() throws Exce
140141 threadContext .putHeader (HEADER_KEY , HEADER_VALUE );
141142
142143 CountDownLatch latch = new CountDownLatch (1 );
144+ AtomicReference <String > capturedHeader = new AtomicReference <>();
143145
144- // Use a mock executor that runs the preserveContext-wrapped runnable
145- ExecutorService mockExecutor = mock ( ExecutorService . class );
146+ // Capture the thread context header inside onResponseSent, which runs
147+ // within the preserveContext wrapper on the executor thread
146148 doAnswer (invocation -> {
147- Runnable command = invocation .getArgument (0 );
148- executor .execute (() -> {
149- command .run ();
150- // After the preserveContext wrapper runs, capture the header
151- // The wrapper stashes the executor thread context, restores caller's, runs, then restores executor's
152- latch .countDown ();
153- });
149+ capturedHeader .set (threadPool .getThreadContext ().getHeader (HEADER_KEY ));
150+ latch .countDown ();
154151 return null ;
155- }).when (mockExecutor ).execute (any (Runnable .class ));
156- when (mockFlightChannel .getExecutor ()).thenReturn (mockExecutor );
152+ }).when (mockListener ).onResponseSent (anyLong (), anyString (), any (TransportResponse .class ));
157153
158154 handler .sendResponseBatch (
159155 Version .CURRENT ,
@@ -168,6 +164,7 @@ public void testSendResponseBatchPropagatesContextToExecutorThread() throws Exce
168164 );
169165
170166 assertTrue ("Executor task should complete" , latch .await (5 , TimeUnit .SECONDS ));
167+ assertEquals ("Context should be propagated to executor thread" , HEADER_VALUE , capturedHeader .get ());
171168 }
172169
173170 public void testMultipleBatchesMaintainCallerContext () throws Exception {
0 commit comments