Skip to content

Commit 26514c1

Browse files
committed
Merge branch 'main' into support-py39
2 parents 861fc6f + b4b4565 commit 26514c1

28 files changed

Lines changed: 3742 additions & 952 deletions

File tree

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,3 @@ jobs:
102102
uv venv /tmp/test-install
103103
uv pip install dist/*.whl --python /tmp/test-install/bin/python
104104
/tmp/test-install/bin/python -c "import drift; print('Package imported successfully')"
105-

drift/core/communication/communicator.py

Lines changed: 227 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
SdkMessage,
2727
SendAlertRequest,
2828
SendInboundSpanForReplayRequest,
29+
SetTimeTravelResponse,
2930
UnpatchedDependencyAlert,
3031
span_to_proto,
3132
)
@@ -78,6 +79,12 @@ def __init__(self, config: CommunicatorConfig | None = None) -> None:
7879
self._incoming_buffer = bytearray()
7980
self._pending_requests: dict[str, dict[str, Any]] = {}
8081
self._lock = threading.Lock()
82+
self._background_reader_thread: threading.Thread | None = None
83+
self._stop_background_reader = threading.Event()
84+
# Response routing: background reader stores responses here, callers wait on events
85+
self._response_events: dict[str, threading.Event] = {}
86+
self._response_data: dict[str, CliMessage] = {}
87+
self._response_lock = threading.Lock() # Protects response_events and response_data
8188

8289
@property
8390
def is_connected(self) -> bool:
@@ -318,6 +325,9 @@ def connect_sync(
318325
if response.success:
319326
logger.debug("CLI acknowledged connection successfully")
320327
self._connected = True
328+
329+
# Start background reader for CLI-initiated messages (like SetTimeTravel)
330+
self._start_background_reader()
321331
else:
322332
error_msg = response.error or "Unknown error"
323333
raise ConnectionError(f"CLI rejected connection: {error_msg}")
@@ -383,11 +393,24 @@ async def request_mock_async(self, mock_request: MockRequestInput) -> MockRespon
383393
f"[ProtobufCommunicator] Creating mock request with requestId: {request_id}, testId: {mock_request.test_id}"
384394
)
385395

386-
# Send and wait for response
387-
await self._send_protobuf_message(sdk_message)
388-
response = await self._receive_response(request_id)
396+
# Pre-register event BEFORE sending message to avoid race condition where
397+
# CLI responds before _wait_for_response registers the event
398+
if self._background_reader_thread and self._background_reader_thread.is_alive():
399+
with self._response_lock:
400+
self._response_events[request_id] = threading.Event()
389401

390-
return response
402+
try:
403+
# Send and wait for response
404+
await self._send_protobuf_message(sdk_message)
405+
response = await self._receive_response(request_id)
406+
return response
407+
except Exception:
408+
# Clean up pre-registered event on failure
409+
if self._background_reader_thread and self._background_reader_thread.is_alive():
410+
with self._response_lock:
411+
self._response_events.pop(request_id, None)
412+
self._response_data.pop(request_id, None)
413+
raise
391414

392415
def request_mock_sync(self, mock_request: MockRequestInput) -> MockResponseOutput:
393416
"""Request mocked response data from CLI (synchronous).
@@ -526,16 +549,27 @@ async def _send_protobuf_message(self, message: SdkMessage) -> None:
526549
# our own socket operations as unpatched dependencies
527550
context_token = calling_library_context.set("ProtobufCommunicator")
528551
try:
529-
# Send synchronously (socket is blocking for sends)
530-
self._socket.sendall(full_message)
552+
# Acquire lock to prevent concurrent sends from background reader thread
553+
# (e.g., _send_message_sync sending SetTimeTravel responses)
554+
with self._lock:
555+
self._socket.sendall(full_message)
531556
finally:
532557
calling_library_context.reset(context_token)
533558

534559
async def _receive_response(self, request_id: str) -> MockResponseOutput:
535-
"""Receive and parse a response for a specific request ID."""
560+
"""Receive and parse a response for a specific request ID.
561+
562+
If the background reader is running, waits on an event for the response.
563+
Otherwise, reads directly from the socket (for async-only connections).
564+
"""
536565
if not self._socket:
537566
raise ConnectionError("Socket not initialized")
538567

568+
# If background reader is running, wait on event instead of reading socket
569+
if self._background_reader_thread and self._background_reader_thread.is_alive():
570+
return await self._wait_for_response_async(request_id)
571+
572+
# No background reader - read directly from socket (async connect path)
539573
self._socket.settimeout(self.config.request_timeout)
540574

541575
try:
@@ -571,6 +605,51 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput:
571605
except TimeoutError as e:
572606
raise TimeoutError(f"Request timed out: {e}") from e
573607

608+
def _wait_for_response(self, request_id: str) -> MockResponseOutput:
    """Block until the background reader delivers the response for ``request_id``.

    An event for ``request_id`` is normally registered before the request is
    sent, so that a response arriving early is not lost. If none is found,
    one is registered here as a fallback.

    Args:
        request_id: Identifier of the request whose response is awaited.

    Returns:
        Whatever ``_handle_cli_message`` returns for the stored CLI message.

    Raises:
        TimeoutError: The event was not set within ``config.request_timeout``.
        ConnectionError: The event was set but no response was stored under
            ``request_id``.
    """
    with self._response_lock:
        event = self._response_events.get(request_id)
        # Test identity, not truthiness: the fallback is meant for exactly the
        # "nothing registered" case.
        if event is None:
            event = threading.Event()
            self._response_events[request_id] = event

    try:
        # Event.wait returns False when the timeout elapses without a signal
        if not event.wait(timeout=self.config.request_timeout):
            raise TimeoutError(f"Request timed out waiting for response: {request_id}")

        with self._response_lock:
            cli_message = self._response_data.pop(request_id, None)

        if cli_message is None:
            raise ConnectionError(f"Response was signaled but not found: {request_id}")

        return self._handle_cli_message(cli_message)
    finally:
        # Always drop the registration; the data pop covers a response that
        # was stored but never consumed (e.g. after a timeout).
        with self._response_lock:
            self._response_events.pop(request_id, None)
            self._response_data.pop(request_id, None)
643+
async def _wait_for_response_async(self, request_id: str) -> MockResponseOutput:
    """Await the response for ``request_id`` without stalling the event loop.

    The blocking ``_wait_for_response`` call is handed to
    ``asyncio.to_thread`` so the wait happens on a worker thread while other
    tasks keep running.
    """
    from asyncio import to_thread

    response = await to_thread(self._wait_for_response, request_id)
    return response
652+
574653
def _recv_exact(self, n: int) -> bytes | None:
575654
"""Receive exactly n bytes from socket."""
576655
if self._socket is None:
@@ -761,6 +840,13 @@ def _clean_span(self, data: Any) -> Any:
761840

762841
def _cleanup(self) -> None:
763842
"""Clean up resources."""
843+
844+
# Stop background reader thread
845+
self._stop_background_reader.set()
846+
if self._background_reader_thread and self._background_reader_thread.is_alive():
847+
self._background_reader_thread.join(timeout=1.0)
848+
self._background_reader_thread = None
849+
764850
self._connected = False
765851
self._session_id = None
766852
self._incoming_buffer.clear()
@@ -774,3 +860,137 @@ def _cleanup(self) -> None:
774860
self._socket = None
775861

776862
self._pending_requests.clear()
863+
864+
# Clean up response routing data and signal any waiting threads
865+
with self._response_lock:
866+
# Signal all waiting threads so they don't hang
867+
for event in self._response_events.values():
868+
event.set()
869+
self._response_events.clear()
870+
self._response_data.clear()
871+
872+
# ========== Background Reader for CLI-initiated Messages ==========
873+
874+
def _start_background_reader(self) -> None:
    """Launch the daemon thread that reads messages arriving from the CLI.

    Does nothing when a reader thread already exists and is still alive.
    """
    current = self._background_reader_thread
    if current and current.is_alive():
        return

    # Reset the stop flag before the new thread begins checking it
    self._stop_background_reader.clear()

    reader = threading.Thread(
        name="CLI-Message-Reader",
        target=self._background_read_loop,
        daemon=True,
    )
    self._background_reader_thread = reader
    reader.start()
    logger.debug("Started background reader thread for CLI-initiated messages")
887+
888+
def _background_read_loop(self) -> None:
    """Read length-prefixed CLI messages until stopped or disconnected.

    Each iteration reads a 4-byte big-endian length, then that many bytes,
    and parses them as a ``CliMessage``. ``SET_TIME_TRAVEL`` messages are
    handled inline. Any other message carrying a ``request_id`` is stored in
    ``_response_data`` and its registered event is set, provided a caller
    registered one; otherwise it is only logged.

    The loop ends when the stop event is set, the socket is gone,
    ``_recv_exact`` returns a falsy value, or a non-timeout error occurs.
    """
    # Function-scope import keeps this block self-contained.
    import socket

    # socket.timeout only became an alias of TimeoutError in Python 3.10; on
    # 3.9 it is a separate OSError subclass, so both are named here.
    timeout_errors = (TimeoutError, socket.timeout)

    while not self._stop_background_reader.is_set():
        if not self._socket:
            break

        try:
            # Short timeout so the stop event is re-checked periodically
            self._socket.settimeout(0.5)

            try:
                length_data = self._recv_exact(4)
            except timeout_errors:
                continue  # No data available, check stop event and retry
            # Any other error (reset connection, closed descriptor, ...) is
            # deliberately NOT retried here: it would fail again immediately
            # and spin this thread. It propagates to the handler below, which
            # logs it and ends the loop.

            if not length_data:
                # Falsy result is treated as "connection closed"
                break

            length = struct.unpack(">I", length_data)[0]

            # Longer timeout for the message body
            self._socket.settimeout(5.0)
            message_data = self._recv_exact(length)
            if not message_data:
                # Falsy result is treated as "connection closed"
                break

            cli_message = CliMessage().parse(message_data)
            logger.debug(f"Background reader received message type: {cli_message.type}")

            # CLI-initiated request: handle it here and send the reply
            if cli_message.type == MessageType.SET_TIME_TRAVEL:
                self._handle_set_time_travel_sync(cli_message)
                continue

            # Everything else is routed to a waiting caller by request_id
            request_id = cli_message.request_id
            if request_id:
                with self._response_lock:
                    waiter = self._response_events.get(request_id)
                    if waiter is not None:
                        # Store the response before signaling so the waiter
                        # always finds it
                        self._response_data[request_id] = cli_message
                        waiter.set()
                        logger.debug(f"Background reader routed response for request_id: {request_id}")
                    else:
                        # No one waiting for this response (possibly timed out)
                        logger.debug(f"Background reader received response with no waiter: {request_id}")

        except timeout_errors:
            # NOTE(review): a timeout during the body read lands here after
            # the length prefix was already consumed; whether the stream stays
            # aligned depends on _recv_exact — confirm.
            continue
        except Exception as e:
            if not self._stop_background_reader.is_set():
                logger.debug(f"Background reader error: {e}")
            break

    logger.debug("Background reader thread stopped")
949+
950+
def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None:
    """Apply a SetTimeTravel request from the CLI and reply with the outcome.

    Returns without replying when the message carries no
    ``set_time_travel_request``. Otherwise ``start_time_travel`` is called
    with the requested timestamp and trace id, and a ``SetTimeTravelResponse``
    echoing the incoming ``request_id`` is sent back. Failures while applying
    or while sending are logged, not raised.
    """
    request = cli_message.set_time_travel_request
    if not request:
        return

    logger.debug(
        f"Received SetTimeTravel request: timestamp={request.timestamp_seconds}, "
        f"traceId={request.trace_id}, source={request.timestamp_source}"
    )

    try:
        # Imported lazily, inside the try, so an import failure is reported
        # to the CLI like any other error
        from drift.instrumentation.datetime.instrumentation import start_time_travel

        started = start_time_travel(request.timestamp_seconds, request.trace_id)
        if started:
            failure_text = ""
        else:
            failure_text = "time-machine library not available or failed to start"
        response = SetTimeTravelResponse(success=started, error=failure_text)
    except Exception as e:
        logger.error(f"Failed to set time travel: {e}")
        response = SetTimeTravelResponse(success=False, error=str(e))

    # Reply on the same request_id so the CLI can match it to its request
    reply_message = SdkMessage(
        type=MessageType.SET_TIME_TRAVEL,
        request_id=cli_message.request_id,
        set_time_travel_response=response,
    )

    try:
        self._send_message_sync(reply_message)
    except Exception as e:
        logger.error(f"Failed to send SetTimeTravel response: {e}")
    else:
        logger.debug(f"Sent SetTimeTravel response: success={response.success}")
986+
987+
def _send_message_sync(self, message: SdkMessage) -> None:
    """Serialize ``message`` and write it to the main socket as one frame.

    The frame is a 4-byte big-endian length followed by ``bytes(message)``.
    The write happens while holding ``_lock``.

    Raises:
        ConnectionError: No socket is currently set.
    """
    if not self._socket:
        raise ConnectionError("Not connected to CLI")

    payload = bytes(message)
    frame = struct.pack(">I", len(payload)) + payload

    with self._lock:
        self._socket.sendall(frame)

drift/core/communication/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
"CliMessage",
1717
"InstrumentationVersionMismatchAlert",
1818
"MessageType",
19+
"Runtime",
1920
"SdkMessage",
2021
"SendAlertRequest",
2122
"SendInboundSpanForReplayRequest",
23+
"SetTimeTravelRequest",
24+
"SetTimeTravelResponse",
2225
"UnpatchedDependencyAlert",
2326
# Aliases
2427
"SDKMessageType",
@@ -43,9 +46,12 @@
4346
CliMessage,
4447
InstrumentationVersionMismatchAlert,
4548
MessageType,
49+
Runtime,
4650
SdkMessage,
4751
SendAlertRequest,
4852
SendInboundSpanForReplayRequest,
53+
SetTimeTravelRequest,
54+
SetTimeTravelResponse,
4955
UnpatchedDependencyAlert,
5056
)
5157
from tusk.drift.core.v1 import (
@@ -136,6 +142,9 @@ class ConnectRequest:
136142
metadata: dict[str, str] = field(default_factory=dict)
137143
"""Additional metadata."""
138144

145+
runtime: Runtime = Runtime.PYTHON
146+
"""SDK runtime environment (node, python)."""
147+
139148
def to_proto(self) -> ProtoConnectRequest:
140149
"""Convert to protobuf message."""
141150
from betterproto.lib.google.protobuf import Struct
@@ -150,6 +159,7 @@ def to_proto(self) -> ProtoConnectRequest:
150159
sdk_version=self.sdk_version,
151160
min_cli_version=self.min_cli_version,
152161
metadata=metadata_struct,
162+
runtime=self.runtime,
153163
)
154164

155165

drift/instrumentation/aiohttp/e2e-tests/src/test_requests.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,6 @@
11
"""Execute test requests against the Flask app to exercise the aiohttp instrumentation."""
22

3-
import time
4-
5-
import requests
6-
7-
BASE_URL = "http://localhost:8000"
8-
9-
10-
def make_request(method, endpoint, **kwargs):
11-
"""Make HTTP request and log result."""
12-
url = f"{BASE_URL}{endpoint}"
13-
print(f"-> {method} {endpoint}")
14-
15-
# Set default timeout if not provided
16-
kwargs.setdefault("timeout", 30)
17-
response = requests.request(method, url, **kwargs)
18-
print(f" Status: {response.status_code}")
19-
time.sleep(0.5) # Small delay between requests
20-
return response
21-
3+
from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary
224

235
if __name__ == "__main__":
246
print("Starting test request sequence for aiohttp instrumentation...\n")
@@ -106,4 +88,4 @@ def make_request(method, endpoint, **kwargs):
10688
# POST with bytes body
10789
make_request("POST", "/test/post-bytes")
10890

109-
print("\nAll requests completed successfully")
91+
print_request_summary()

0 commit comments

Comments
 (0)