Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 5 additions & 156 deletions drift/core/communication/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,121 +122,6 @@ def _get_stack_trace(self) -> str:

# ========== Connection Methods ==========

async def connect(
self,
connection_info: dict[str, Any] | None = None,
service_id: str = "",
) -> None:
"""Connect to the CLI and perform handshake.

Args:
connection_info: Dict with 'socketPath' or 'host'/'port'
service_id: Service identifier for the connection

Raises:
ConnectionError: If connection fails
TimeoutError: If connection times out
"""
# Determine address
if connection_info:
if "socketPath" in connection_info:
address: tuple[str, int] | str = connection_info["socketPath"]
else:
address = (connection_info["host"], connection_info["port"])
else:
address = self._get_socket_address()

# Set calling_library_context to prevent socket instrumentation from flagging
# our own socket operations as unpatched dependencies
context_token = calling_library_context.set("ProtobufCommunicator")
try:
# Create appropriate socket type
if isinstance(address, str):
# Unix socket
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
logger.debug(f"Connecting to Unix socket: {address}")
else:
# TCP socket
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
logger.debug(f"Connecting to TCP: {address}")

self._socket.settimeout(self.config.connect_timeout)
self._socket.connect(address)

conn_type = "Unix socket" if isinstance(address, str) else "TCP"
logger.debug(f"Connected to CLI via protobuf ({conn_type})")

# Send connect message
await self._send_connect_message(service_id)

self._connected = True

except TimeoutError as e:
self._cleanup()
raise TimeoutError(f"Connection timed out: {e}") from e
except OSError as e:
self._cleanup()
raise ConnectionError(f"Socket error: {e}") from e
finally:
calling_library_context.reset(context_token)

async def _send_connect_message(self, service_id: str) -> None:
"""Send the initial connection message to CLI and wait for acknowledgement."""
connect_request = ConnectRequest(
service_id=service_id,
sdk_version=SDK_VERSION,
min_cli_version=MIN_CLI_VERSION,
)

request_id = self._generate_request_id()
sdk_message = SdkMessage(
type=MessageType.SDK_CONNECT,
request_id=request_id,
connect_request=connect_request.to_proto(),
)

await self._send_protobuf_message(sdk_message)

# Wait for connect response from CLI
await self._receive_connect_response(request_id)

async def _receive_connect_response(self, request_id: str) -> None:
"""Wait for and handle the connect response from CLI."""
if not self._socket:
raise ConnectionError("Socket not initialized")

self._socket.settimeout(self.config.connect_timeout)

try:
# Read length prefix
length_data = self._recv_exact(4)
if not length_data:
raise ConnectionError("Connection closed by CLI")

length = struct.unpack(">I", length_data)[0]

# Read message data
message_data = self._recv_exact(length)
if not message_data:
raise ConnectionError("Connection closed by CLI")

cli_message = CliMessage().parse(message_data)

logger.debug(f"Received connect response: type={cli_message.type}, requestId={cli_message.request_id}")

if cli_message.connect_response:
response = cli_message.connect_response
if response.success:
logger.debug("CLI acknowledged connection successfully")
else:
error_msg = response.error or "Unknown error"
raise ConnectionError(f"CLI rejected connection: {error_msg}")
else:
raise ConnectionError(f"Expected connect response but got message type: {cli_message.type}")

except TimeoutError as e:
raise TimeoutError(f"Timeout waiting for connect response: {e}") from e

def connect_sync(
self,
connection_info: dict[str, Any] | None = None,
Expand Down Expand Up @@ -343,7 +228,7 @@ def connect_sync(
finally:
calling_library_context.reset(context_token)

async def disconnect(self) -> None:
def disconnect(self) -> None:
"""Disconnect from CLI."""
self._cleanup()
logger.debug("Disconnected from CLI")
Expand Down Expand Up @@ -559,51 +444,15 @@ async def _send_protobuf_message(self, message: SdkMessage) -> None:
async def _receive_response(self, request_id: str) -> MockResponseOutput:
"""Receive and parse a response for a specific request ID.

If the background reader is running, waits on an event for the response.
Otherwise, reads directly from the socket (for async-only connections).
Waits on an event for the background reader to deliver the response.
"""
if not self._socket:
raise ConnectionError("Socket not initialized")

# If background reader is running, wait on event instead of reading socket
if self._background_reader_thread and self._background_reader_thread.is_alive():
return await self._wait_for_response_async(request_id)

# No background reader - read directly from socket (async connect path)
self._socket.settimeout(self.config.request_timeout)

try:
while True:
# Read length prefix
length_data = self._recv_exact(4)
if not length_data:
raise ConnectionError("Connection closed by CLI")
if not self._background_reader_thread or not self._background_reader_thread.is_alive():
raise ConnectionError("Background reader is not running - connection may have been closed")

length = struct.unpack(">I", length_data)[0]

# Read message data
message_data = self._recv_exact(length)
if not message_data:
raise ConnectionError("Connection closed by CLI")

cli_message = CliMessage().parse(message_data)

logger.debug(f"Received CLI message type: {cli_message.type}, requestId: {cli_message.request_id}")

if cli_message.request_id == request_id:
return self._handle_cli_message(cli_message)

if cli_message.connect_response:
response = cli_message.connect_response
if response.success:
logger.debug("CLI acknowledged connection")
# Note: session_id is not in the protobuf schema
else:
logger.error(f"CLI rejected connection: {response.error}")
continue

except TimeoutError as e:
raise TimeoutError(f"Request timed out: {e}") from e
return await self._wait_for_response_async(request_id)

def _wait_for_response(self, request_id: str) -> MockResponseOutput:
"""Wait for a response from the background reader thread.
Expand Down
2 changes: 1 addition & 1 deletion drift/core/drift_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def shutdown(self) -> None:

if self.communicator:
try:
asyncio.run(self.communicator.disconnect())
self.communicator.disconnect()
except Exception as e:
logger.error(f"Error disconnecting from CLI: {e}")

Expand Down