2626 SdkMessage ,
2727 SendAlertRequest ,
2828 SendInboundSpanForReplayRequest ,
29+ SetTimeTravelResponse ,
2930 UnpatchedDependencyAlert ,
3031 span_to_proto ,
3132)
@@ -78,6 +79,12 @@ def __init__(self, config: CommunicatorConfig | None = None) -> None:
7879 self ._incoming_buffer = bytearray ()
7980 self ._pending_requests : dict [str , dict [str , Any ]] = {}
8081 self ._lock = threading .Lock ()
82+ self ._background_reader_thread : threading .Thread | None = None
83+ self ._stop_background_reader = threading .Event ()
84+ # Response routing: background reader stores responses here, callers wait on events
85+ self ._response_events : dict [str , threading .Event ] = {}
86+ self ._response_data : dict [str , CliMessage ] = {}
87+ self ._response_lock = threading .Lock () # Protects response_events and response_data
8188
8289 @property
8390 def is_connected (self ) -> bool :
@@ -318,6 +325,9 @@ def connect_sync(
318325 if response .success :
319326 logger .debug ("CLI acknowledged connection successfully" )
320327 self ._connected = True
328+
329+ # Start background reader for CLI-initiated messages (like SetTimeTravel)
330+ self ._start_background_reader ()
321331 else :
322332 error_msg = response .error or "Unknown error"
323333 raise ConnectionError (f"CLI rejected connection: { error_msg } " )
@@ -383,11 +393,24 @@ async def request_mock_async(self, mock_request: MockRequestInput) -> MockRespon
383393 f"[ProtobufCommunicator] Creating mock request with requestId: { request_id } , testId: { mock_request .test_id } "
384394 )
385395
386- # Send and wait for response
387- await self ._send_protobuf_message (sdk_message )
388- response = await self ._receive_response (request_id )
396+ # Pre-register event BEFORE sending message to avoid race condition where
397+ # CLI responds before _wait_for_response registers the event
398+ if self ._background_reader_thread and self ._background_reader_thread .is_alive ():
399+ with self ._response_lock :
400+ self ._response_events [request_id ] = threading .Event ()
389401
390- return response
402+ try :
403+ # Send and wait for response
404+ await self ._send_protobuf_message (sdk_message )
405+ response = await self ._receive_response (request_id )
406+ return response
407+ except Exception :
408+ # Clean up pre-registered event on failure
409+ if self ._background_reader_thread and self ._background_reader_thread .is_alive ():
410+ with self ._response_lock :
411+ self ._response_events .pop (request_id , None )
412+ self ._response_data .pop (request_id , None )
413+ raise
391414
392415 def request_mock_sync (self , mock_request : MockRequestInput ) -> MockResponseOutput :
393416 """Request mocked response data from CLI (synchronous).
@@ -526,16 +549,27 @@ async def _send_protobuf_message(self, message: SdkMessage) -> None:
526549 # our own socket operations as unpatched dependencies
527550 context_token = calling_library_context .set ("ProtobufCommunicator" )
528551 try :
529- # Send synchronously (socket is blocking for sends)
530- self ._socket .sendall (full_message )
552+ # Acquire lock to prevent concurrent sends from background reader thread
553+ # (e.g., _send_message_sync sending SetTimeTravel responses)
554+ with self ._lock :
555+ self ._socket .sendall (full_message )
531556 finally :
532557 calling_library_context .reset (context_token )
533558
534559 async def _receive_response (self , request_id : str ) -> MockResponseOutput :
535- """Receive and parse a response for a specific request ID."""
560+ """Receive and parse a response for a specific request ID.
561+
562+ If the background reader is running, waits on an event for the response.
563+ Otherwise, reads directly from the socket (for async-only connections).
564+ """
536565 if not self ._socket :
537566 raise ConnectionError ("Socket not initialized" )
538567
568+ # If background reader is running, wait on event instead of reading socket
569+ if self ._background_reader_thread and self ._background_reader_thread .is_alive ():
570+ return await self ._wait_for_response_async (request_id )
571+
572+ # No background reader - read directly from socket (async connect path)
539573 self ._socket .settimeout (self .config .request_timeout )
540574
541575 try :
@@ -571,6 +605,51 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput:
571605 except TimeoutError as e :
572606 raise TimeoutError (f"Request timed out: { e } " ) from e
573607
def _wait_for_response(self, request_id: str) -> MockResponseOutput:
    """Block until the background reader delivers the response for ``request_id``.

    The event for ``request_id`` is normally registered before the request is
    sent; if none is found, one is registered here as a fallback. Once the
    event fires, the stored ``CliMessage`` is removed from the response table
    and passed to ``_handle_cli_message``, whose result is returned.

    Raises:
        TimeoutError: the event was not set within ``config.request_timeout``.
        ConnectionError: the event was set but no response was stored for
            this id.
    """
    with self._response_lock:
        signal = self._response_events.get(request_id)
        if signal is None:
            # Fallback path: nobody pre-registered an event for this request.
            signal = threading.Event()
            self._response_events[request_id] = signal

    try:
        delivered = signal.wait(timeout=self.config.request_timeout)
        if not delivered:
            raise TimeoutError(f"Request timed out waiting for response: {request_id}")

        with self._response_lock:
            payload = self._response_data.pop(request_id, None)

        if payload is None:
            raise ConnectionError(f"Response was signaled but not found: {request_id}")

        return self._handle_cli_message(payload)
    finally:
        # Always drop both registrations, including any response that was
        # stored around the moment a timeout fired.
        with self._response_lock:
            for registry in (self._response_events, self._response_data):
                registry.pop(request_id, None)
642+
async def _wait_for_response_async(self, request_id: str) -> MockResponseOutput:
    """Await the response for ``request_id`` without blocking the event loop.

    Delegates to the blocking ``_wait_for_response`` via
    ``asyncio.to_thread()``, so the wait happens on a worker thread while
    other coroutines keep running. Returns whatever ``_wait_for_response``
    returns and propagates whatever it raises.
    """
    from asyncio import to_thread

    blocking_wait = self._wait_for_response
    return await to_thread(blocking_wait, request_id)
652+
574653 def _recv_exact (self , n : int ) -> bytes | None :
575654 """Receive exactly n bytes from socket."""
576655 if self ._socket is None :
@@ -761,6 +840,13 @@ def _clean_span(self, data: Any) -> Any:
761840
762841 def _cleanup (self ) -> None :
763842 """Clean up resources."""
843+
844+ # Stop background reader thread
845+ self ._stop_background_reader .set ()
846+ if self ._background_reader_thread and self ._background_reader_thread .is_alive ():
847+ self ._background_reader_thread .join (timeout = 1.0 )
848+ self ._background_reader_thread = None
849+
764850 self ._connected = False
765851 self ._session_id = None
766852 self ._incoming_buffer .clear ()
@@ -774,3 +860,137 @@ def _cleanup(self) -> None:
774860 self ._socket = None
775861
776862 self ._pending_requests .clear ()
863+
864+ # Clean up response routing data and signal any waiting threads
865+ with self ._response_lock :
866+ # Signal all waiting threads so they don't hang
867+ for event in self ._response_events .values ():
868+ event .set ()
869+ self ._response_events .clear ()
870+ self ._response_data .clear ()
871+
872+ # ========== Background Reader for CLI-initiated Messages ==========
873+
def _start_background_reader(self) -> None:
    """Launch the daemon thread that runs ``_background_read_loop``.

    Does nothing when a reader thread already exists and is still alive.
    Otherwise the stop event is cleared, a new daemon thread named
    "CLI-Message-Reader" is stored on the instance, and the thread is started.
    """
    current = self._background_reader_thread
    if current is not None and current.is_alive():
        return

    # Clear any earlier stop request so the new loop does not exit at once.
    self._stop_background_reader.clear()

    reader = threading.Thread(
        name="CLI-Message-Reader",
        target=self._background_read_loop,
        daemon=True,
    )
    self._background_reader_thread = reader
    reader.start()
    logger.debug("Started background reader thread for CLI-initiated messages")
887+
def _background_read_loop(self) -> None:
    """Read length-prefixed CLI messages until stopped or the connection ends.

    This is the target of the background reader thread. Each frame is a
    4-byte big-endian length followed by a serialized ``CliMessage``:

    * ``SET_TIME_TRAVEL`` messages are handled in place via
      ``_handle_set_time_travel_sync``.
    * Any other message with a ``request_id`` is stored in
      ``_response_data`` and its registered event is set; if no event is
      registered for that id the message is dropped with a debug log.
    * Messages without a ``request_id`` are ignored.

    The loop ends when the stop event is set, the socket is gone,
    ``_recv_exact`` returns no data (treated as a closed connection), or any
    non-timeout error occurs.
    """
    import socket  # local import: only needed for the timeout alias below

    # socket.timeout is an alias of TimeoutError on Python 3.10+, and a
    # separate OSError subclass on older versions; accept both as "idle".
    idle_timeouts = (TimeoutError, socket.timeout)

    while not self._stop_background_reader.is_set():
        if not self._socket:
            break

        try:
            # Short timeout so the stop event is re-checked regularly.
            self._socket.settimeout(0.5)

            try:
                length_data = self._recv_exact(4)
            except idle_timeouts:
                continue  # No data available, check stop event and retry
            # Any other error (reset, closed descriptor, ...) is NOT retried:
            # retrying a failing socket with no delay would spin this thread
            # forever. It reaches the outer handler, which logs and stops.

            if not length_data:
                # No data means the connection was closed.
                break

            length = struct.unpack(">I", length_data)[0]

            # Longer timeout for the message body.
            self._socket.settimeout(5.0)
            message_data = self._recv_exact(length)
            if message_data is None:
                # None means the connection was closed. An empty bytes value
                # is a valid zero-length body (presumably what
                # _recv_exact(0) returns - TODO confirm) and is still parsed.
                break

            cli_message = CliMessage().parse(message_data)
            logger.debug(f"Background reader received message type: {cli_message.type}")

            # CLI-initiated request: answer it here, nobody is waiting on it.
            if cli_message.type == MessageType.SET_TIME_TRAVEL:
                self._handle_set_time_travel_sync(cli_message)
                continue

            # Route responses to waiting callers by request_id.
            request_id = cli_message.request_id
            if request_id:
                with self._response_lock:
                    waiter = self._response_events.get(request_id)
                    if waiter is not None:
                        # Store the response before signalling so the waiter
                        # finds it as soon as it wakes up.
                        self._response_data[request_id] = cli_message
                        waiter.set()
                        logger.debug(f"Background reader routed response for request_id: {request_id}")
                    else:
                        # No one waiting for this response (possibly timed out)
                        logger.debug(f"Background reader received response with no waiter: {request_id}")

        except TimeoutError:
            # NOTE(review): a timeout while reading the body lands here after
            # the length prefix was already consumed; whether framing stays
            # intact depends on _recv_exact - confirm.
            continue
        except Exception as e:
            if not self._stop_background_reader.is_set():
                logger.debug(f"Background reader error: {e}")
            break

    logger.debug("Background reader thread stopped")
949+
def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None:
    """Apply a SetTimeTravel request from the CLI and send back the outcome.

    Returns without replying when the message carries no
    ``set_time_travel_request``. Otherwise ``start_time_travel`` is called
    with the requested timestamp and trace id, and a ``SetTimeTravelResponse``
    with the same ``request_id`` is sent via ``_send_message_sync``. Errors
    from starting time travel or from sending the reply are logged and not
    raised.
    """
    travel_request = cli_message.set_time_travel_request
    if not travel_request:
        return

    logger.debug(
        f"Received SetTimeTravel request: timestamp={travel_request.timestamp_seconds}, "
        f"traceId={travel_request.trace_id}, source={travel_request.timestamp_source}"
    )

    try:
        # Imported lazily, inside the try, so an import failure is reported
        # to the CLI as an unsuccessful response.
        from drift.instrumentation.datetime.instrumentation import start_time_travel

        started = start_time_travel(travel_request.timestamp_seconds, travel_request.trace_id)
        failure_text = "" if started else "time-machine library not available or failed to start"
        reply = SetTimeTravelResponse(success=started, error=failure_text)
    except Exception as exc:
        logger.error(f"Failed to set time travel: {exc}")
        reply = SetTimeTravelResponse(success=False, error=str(exc))

    # Echo the request id so the CLI can match the reply to its request.
    outgoing = SdkMessage(
        type=MessageType.SET_TIME_TRAVEL,
        request_id=cli_message.request_id,
        set_time_travel_response=reply,
    )

    try:
        self._send_message_sync(outgoing)
        logger.debug(f"Sent SetTimeTravel response: success={reply.success}")
    except Exception as exc:
        logger.error(f"Failed to send SetTimeTravel response: {exc}")
986+
def _send_message_sync(self, message: SdkMessage) -> None:
    """Serialize ``message`` and write it to the main socket as one frame.

    The frame is a 4-byte big-endian length prefix followed by
    ``bytes(message)``. The write is done while holding ``self._lock``.

    Raises:
        ConnectionError: there is no socket to send on.
    """
    if not self._socket:
        raise ConnectionError("Not connected to CLI")

    payload = bytes(message)
    frame = struct.pack(">I", len(payload)) + payload

    with self._lock:
        self._socket.sendall(frame)
0 commit comments