Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,3 @@ jobs:
uv venv /tmp/test-install
Comment thread
sohankshirsagar marked this conversation as resolved.
uv pip install dist/*.whl --python /tmp/test-install/bin/python
/tmp/test-install/bin/python -c "import drift; print('Package imported successfully')"

234 changes: 227 additions & 7 deletions drift/core/communication/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
span_to_proto,
)
Expand Down Expand Up @@ -78,6 +79,12 @@ def __init__(self, config: CommunicatorConfig | None = None) -> None:
self._incoming_buffer = bytearray()
self._pending_requests: dict[str, dict[str, Any]] = {}
self._lock = threading.Lock()
self._background_reader_thread: threading.Thread | None = None
self._stop_background_reader = threading.Event()
# Response routing: background reader stores responses here, callers wait on events
self._response_events: dict[str, threading.Event] = {}
self._response_data: dict[str, CliMessage] = {}
self._response_lock = threading.Lock() # Protects response_events and response_data

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -318,6 +325,9 @@ def connect_sync(
if response.success:
logger.debug("CLI acknowledged connection successfully")
self._connected = True

# Start background reader for CLI-initiated messages (like SetTimeTravel)
self._start_background_reader()
Comment thread
sohankshirsagar marked this conversation as resolved.
else:
error_msg = response.error or "Unknown error"
raise ConnectionError(f"CLI rejected connection: {error_msg}")
Expand Down Expand Up @@ -383,11 +393,24 @@ async def request_mock_async(self, mock_request: MockRequestInput) -> MockRespon
f"[ProtobufCommunicator] Creating mock request with requestId: {request_id}, testId: {mock_request.test_id}"
)

# Send and wait for response
await self._send_protobuf_message(sdk_message)
response = await self._receive_response(request_id)
# Pre-register event BEFORE sending message to avoid race condition where
# CLI responds before _wait_for_response registers the event
if self._background_reader_thread and self._background_reader_thread.is_alive():
with self._response_lock:
self._response_events[request_id] = threading.Event()

return response
try:
# Send and wait for response
await self._send_protobuf_message(sdk_message)
response = await self._receive_response(request_id)
return response
except Exception:
# Clean up pre-registered event on failure
if self._background_reader_thread and self._background_reader_thread.is_alive():
with self._response_lock:
self._response_events.pop(request_id, None)
self._response_data.pop(request_id, None)
raise
Comment thread
sohankshirsagar marked this conversation as resolved.

def request_mock_sync(self, mock_request: MockRequestInput) -> MockResponseOutput:
"""Request mocked response data from CLI (synchronous).
Expand Down Expand Up @@ -526,16 +549,27 @@ async def _send_protobuf_message(self, message: SdkMessage) -> None:
# our own socket operations as unpatched dependencies
context_token = calling_library_context.set("ProtobufCommunicator")
try:
# Send synchronously (socket is blocking for sends)
self._socket.sendall(full_message)
# Acquire lock to prevent concurrent sends from background reader thread
# (e.g., _send_message_sync sending SetTimeTravel responses)
with self._lock:
self._socket.sendall(full_message)
finally:
calling_library_context.reset(context_token)

async def _receive_response(self, request_id: str) -> MockResponseOutput:
"""Receive and parse a response for a specific request ID."""
"""Receive and parse a response for a specific request ID.

If the background reader is running, waits on an event for the response.
Otherwise, reads directly from the socket (for async-only connections).
"""
if not self._socket:
raise ConnectionError("Socket not initialized")

# If background reader is running, wait on event instead of reading socket
if self._background_reader_thread and self._background_reader_thread.is_alive():
return await self._wait_for_response_async(request_id)

# No background reader - read directly from socket (async connect path)
self._socket.settimeout(self.config.request_timeout)

try:
Expand Down Expand Up @@ -571,6 +605,51 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput:
except TimeoutError as e:
raise TimeoutError(f"Request timed out: {e}") from e

def _wait_for_response(self, request_id: str) -> MockResponseOutput:
"""Wait for a response from the background reader thread.

Uses a pre-registered event for the request_id (registered before sending
the message to avoid race conditions), waits for the background reader
to signal it, then retrieves the response.
"""
# Use pre-registered event, or create one as fallback
with self._response_lock:
event = self._response_events.get(request_id)
if not event:
# Fallback: register now (shouldn't happen in normal flow)
event = threading.Event()
self._response_events[request_id] = event

try:
# Wait for the background reader to signal us
if not event.wait(timeout=self.config.request_timeout):
raise TimeoutError(f"Request timed out waiting for response: {request_id}")

# Retrieve the response
with self._response_lock:
cli_message = self._response_data.pop(request_id, None)

if cli_message is None:
raise ConnectionError(f"Response was signaled but not found: {request_id}")

return self._handle_cli_message(cli_message)

finally:
# Clean up the event registration
with self._response_lock:
self._response_events.pop(request_id, None)
self._response_data.pop(request_id, None) # In case of timeout

async def _wait_for_response_async(self, request_id: str) -> MockResponseOutput:
"""Async version of _wait_for_response that doesn't block the event loop.

Uses asyncio.to_thread() to run the blocking Event.wait() in a thread pool,
allowing other async tasks to run while waiting for the response.
"""
import asyncio

return await asyncio.to_thread(self._wait_for_response, request_id)

def _recv_exact(self, n: int) -> bytes | None:
"""Receive exactly n bytes from socket."""
if self._socket is None:
Expand Down Expand Up @@ -761,6 +840,13 @@ def _clean_span(self, data: Any) -> Any:

def _cleanup(self) -> None:
"""Clean up resources."""

# Stop background reader thread
self._stop_background_reader.set()
if self._background_reader_thread and self._background_reader_thread.is_alive():
self._background_reader_thread.join(timeout=1.0)
self._background_reader_thread = None

Comment thread
sohankshirsagar marked this conversation as resolved.
self._connected = False
self._session_id = None
self._incoming_buffer.clear()
Expand All @@ -774,3 +860,137 @@ def _cleanup(self) -> None:
self._socket = None

self._pending_requests.clear()

# Clean up response routing data and signal any waiting threads
with self._response_lock:
# Signal all waiting threads so they don't hang
for event in self._response_events.values():
event.set()
self._response_events.clear()
self._response_data.clear()

# ========== Background Reader for CLI-initiated Messages ==========

def _start_background_reader(self) -> None:
"""Start background thread to read CLI-initiated messages."""
if self._background_reader_thread and self._background_reader_thread.is_alive():
return

self._stop_background_reader.clear()
self._background_reader_thread = threading.Thread(
target=self._background_read_loop,
daemon=True,
name="CLI-Message-Reader",
)
self._background_reader_thread.start()
logger.debug("Started background reader thread for CLI-initiated messages")

def _background_read_loop(self) -> None:
"""Background loop to read and handle CLI-initiated messages."""
while not self._stop_background_reader.is_set():
if not self._socket:
break

try:
# Set a short timeout so we can check the stop event periodically
self._socket.settimeout(0.5)

# Try to read length prefix
try:
length_data = self._recv_exact(4)
except TimeoutError:
continue # No data available, check stop event and retry
except Exception:
continue
Comment thread
sohankshirsagar marked this conversation as resolved.

if not length_data:
# None means connection closed (recv returned empty bytes)
break

length = struct.unpack(">I", length_data)[0]

# Read message data
self._socket.settimeout(5.0) # Longer timeout for message body
message_data = self._recv_exact(length)
if not message_data:
# None means connection closed (recv returned empty bytes)
break

# Parse message
cli_message = CliMessage().parse(message_data)
logger.debug(f"Background reader received message type: {cli_message.type}")

# Handle CLI-initiated messages (no request_id, or special types)
if cli_message.type == MessageType.SET_TIME_TRAVEL:
self._handle_set_time_travel_sync(cli_message)
continue

# Route responses to waiting callers by request_id
request_id = cli_message.request_id
if request_id:
with self._response_lock:
if request_id in self._response_events:
# Store response and signal the waiting caller
self._response_data[request_id] = cli_message
self._response_events[request_id].set()
logger.debug(f"Background reader routed response for request_id: {request_id}")
else:
# No one waiting for this response (possibly timed out)
logger.debug(f"Background reader received response with no waiter: {request_id}")

except TimeoutError:
continue # Normal timeout, just retry
except Exception as e:
if not self._stop_background_reader.is_set():
logger.debug(f"Background reader error: {e}")
break

logger.debug("Background reader thread stopped")

def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None:
"""Handle SetTimeTravel request from CLI and send response."""
request = cli_message.set_time_travel_request
if not request:
return
Comment thread
sohankshirsagar marked this conversation as resolved.

logger.debug(
f"Received SetTimeTravel request: timestamp={request.timestamp_seconds}, "
f"traceId={request.trace_id}, source={request.timestamp_source}"
)

try:
from drift.instrumentation.datetime.instrumentation import start_time_travel

success = start_time_travel(request.timestamp_seconds, request.trace_id)

response = SetTimeTravelResponse(
success=success,
error="" if success else "time-machine library not available or failed to start",
)
except Exception as e:
logger.error(f"Failed to set time travel: {e}")
response = SetTimeTravelResponse(success=False, error=str(e))

# Send response back to CLI
sdk_message = SdkMessage(
type=MessageType.SET_TIME_TRAVEL,
request_id=cli_message.request_id,
set_time_travel_response=response,
)

try:
self._send_message_sync(sdk_message)
logger.debug(f"Sent SetTimeTravel response: success={response.success}")
except Exception as e:
logger.error(f"Failed to send SetTimeTravel response: {e}")

def _send_message_sync(self, message: SdkMessage) -> None:
"""Send a message synchronously on the main socket."""
if not self._socket:
raise ConnectionError("Not connected to CLI")

message_bytes = bytes(message)
length_prefix = struct.pack(">I", len(message_bytes))

with self._lock:
self._socket.sendall(length_prefix + message_bytes)
Comment thread
sohankshirsagar marked this conversation as resolved.
Comment thread
sohankshirsagar marked this conversation as resolved.
10 changes: 10 additions & 0 deletions drift/core/communication/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
"CliMessage",
"InstrumentationVersionMismatchAlert",
"MessageType",
"Runtime",
"SdkMessage",
"SendAlertRequest",
"SendInboundSpanForReplayRequest",
"SetTimeTravelRequest",
"SetTimeTravelResponse",
"UnpatchedDependencyAlert",
# Aliases
"SDKMessageType",
Expand All @@ -43,9 +46,12 @@
CliMessage,
InstrumentationVersionMismatchAlert,
MessageType,
Runtime,
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
)
from tusk.drift.core.v1 import (
Expand Down Expand Up @@ -136,6 +142,9 @@ class ConnectRequest:
metadata: dict[str, str] = field(default_factory=dict)
"""Additional metadata."""

runtime: Runtime = Runtime.PYTHON
"""SDK runtime environment (node, python)."""

def to_proto(self) -> ProtoConnectRequest:
"""Convert to protobuf message."""
from betterproto.lib.google.protobuf import Struct
Expand All @@ -150,6 +159,7 @@ def to_proto(self) -> ProtoConnectRequest:
sdk_version=self.sdk_version,
min_cli_version=self.min_cli_version,
metadata=metadata_struct,
runtime=self.runtime,
)


Expand Down
1 change: 1 addition & 0 deletions drift/instrumentation/requests/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ def _try_get_mock(
input_value=input_value,
kind=SpanKind.CLIENT,
input_schema_merges=input_schema_merges,
is_pre_app_start=not sdk.app_ready,
)

if not mock_response_output or not mock_response_output.found:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"protobuf>=6.0",
"PyYAML>=6.0",
"requests>=2.32.5",
"tusk-drift-schemas>=0.1.9.dev1",
"tusk-drift-schemas>=0.1.24",
"aiohttp>=3.9.0",
"aiofiles>=23.0.0",
"opentelemetry-api>=1.20.0",
Expand Down
Loading