Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions drift/core/communication/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
MessageType,
MockRequestInput,
MockResponseOutput,
Runtime,
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
span_to_proto,
)
Expand Down Expand Up @@ -78,6 +81,8 @@ def __init__(self, config: CommunicatorConfig | None = None) -> None:
self._incoming_buffer = bytearray()
self._pending_requests: dict[str, dict[str, Any]] = {}
self._lock = threading.Lock()
self._background_reader_thread: threading.Thread | None = None
self._stop_background_reader = threading.Event()

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -321,6 +326,9 @@ def connect_sync(
self._connected = True
logger.info(f"[CONNECT_SYNC] Connection successful! Socket is: {self._socket}")
logger.info(f"[CONNECT_SYNC] _connected={self._connected}, is_connected={self.is_connected}")

# Start background reader for CLI-initiated messages (like SetTimeTravel)
self._start_background_reader()
Comment thread
sohankshirsagar marked this conversation as resolved.
else:
error_msg = response.error or "Unknown error"
raise ConnectionError(f"CLI rejected connection: {error_msg}")
Expand Down Expand Up @@ -770,6 +778,12 @@ def _cleanup(self) -> None:
logger.warning("[CLEANUP] _cleanup() called! Stack trace:")
logger.warning("".join(traceback.format_stack()))

# Stop background reader thread
self._stop_background_reader.set()
if self._background_reader_thread and self._background_reader_thread.is_alive():
self._background_reader_thread.join(timeout=1.0)
self._background_reader_thread = None

Comment thread
sohankshirsagar marked this conversation as resolved.
self._connected = False
self._session_id = None
self._incoming_buffer.clear()
Expand All @@ -783,3 +797,116 @@ def _cleanup(self) -> None:
self._socket = None

self._pending_requests.clear()

# ========== Background Reader for CLI-initiated Messages ==========

def _start_background_reader(self) -> None:
"""Start background thread to read CLI-initiated messages."""
if self._background_reader_thread and self._background_reader_thread.is_alive():
return

self._stop_background_reader.clear()
self._background_reader_thread = threading.Thread(
target=self._background_read_loop,
daemon=True,
name="CLI-Message-Reader",
)
self._background_reader_thread.start()
logger.debug("Started background reader thread for CLI-initiated messages")

def _background_read_loop(self) -> None:
"""Background loop to read and handle CLI-initiated messages."""
while not self._stop_background_reader.is_set():
if not self._socket:
break

try:
# Set a short timeout so we can check the stop event periodically
self._socket.settimeout(0.5)

# Try to read length prefix
try:
length_data = self._recv_exact(4)
except socket.timeout:
continue # No data available, check stop event and retry
except Exception:
continue
Comment thread
sohankshirsagar marked this conversation as resolved.

if not length_data:
continue
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated

length = struct.unpack(">I", length_data)[0]

# Read message data
self._socket.settimeout(5.0) # Longer timeout for message body
message_data = self._recv_exact(length)
if not message_data:
continue

# Parse message
cli_message = CliMessage().parse(message_data)
logger.debug(f"Background reader received message type: {cli_message.type}")

# Handle CLI-initiated messages based on message type
if cli_message.type == MessageType.SET_TIME_TRAVEL:
self._handle_set_time_travel_sync(cli_message)
else:
# Other message types (responses to SDK requests) are handled elsewhere
logger.debug(f"Background reader ignoring message type: {cli_message.type}")
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated

except socket.timeout:
continue # Normal timeout, just retry
except Exception as e:
if not self._stop_background_reader.is_set():
logger.debug(f"Background reader error: {e}")
break

logger.debug("Background reader thread stopped")

def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None:
"""Handle SetTimeTravel request from CLI and send response."""
request = cli_message.set_time_travel_request
if not request:
return
Comment thread
sohankshirsagar marked this conversation as resolved.

logger.debug(
f"Received SetTimeTravel request: timestamp={request.timestamp_seconds}, "
f"traceId={request.trace_id}, source={request.timestamp_source}"
)

try:
from drift.instrumentation.datetime.instrumentation import start_time_travel

success = start_time_travel(request.timestamp_seconds, request.trace_id)

response = SetTimeTravelResponse(
success=success,
error="" if success else "time-machine library not available or failed to start",
)
except Exception as e:
logger.error(f"Failed to set time travel: {e}")
response = SetTimeTravelResponse(success=False, error=str(e))

# Send response back to CLI
sdk_message = SdkMessage(
type=MessageType.SET_TIME_TRAVEL,
request_id=cli_message.request_id,
set_time_travel_response=response,
)

try:
self._send_message_sync(sdk_message)
logger.debug(f"Sent SetTimeTravel response: success={response.success}")
except Exception as e:
logger.error(f"Failed to send SetTimeTravel response: {e}")

def _send_message_sync(self, message: SdkMessage) -> None:
"""Send a message synchronously on the main socket."""
if not self._socket:
raise ConnectionError("Not connected to CLI")

message_bytes = bytes(message)
length_prefix = struct.pack(">I", len(message_bytes))

with self._lock:
self._socket.sendall(length_prefix + message_bytes)
Comment thread
sohankshirsagar marked this conversation as resolved.
Comment thread
sohankshirsagar marked this conversation as resolved.
10 changes: 10 additions & 0 deletions drift/core/communication/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
"CliMessage",
"InstrumentationVersionMismatchAlert",
"MessageType",
"Runtime",
"SdkMessage",
"SendAlertRequest",
"SendInboundSpanForReplayRequest",
"SetTimeTravelRequest",
"SetTimeTravelResponse",
"UnpatchedDependencyAlert",
# Aliases
"SDKMessageType",
Expand All @@ -43,9 +46,12 @@
CliMessage,
InstrumentationVersionMismatchAlert,
MessageType,
Runtime,
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
)
from tusk.drift.core.v1 import (
Expand Down Expand Up @@ -136,6 +142,9 @@ class ConnectRequest:
metadata: dict[str, str] = field(default_factory=dict)
"""Additional metadata."""

runtime: Runtime = Runtime.PYTHON
"""SDK runtime environment (node, python)."""

def to_proto(self) -> ProtoConnectRequest:
"""Convert to protobuf message."""
from betterproto.lib.google.protobuf import Struct
Expand All @@ -150,6 +159,7 @@ def to_proto(self) -> ProtoConnectRequest:
sdk_version=self.sdk_version,
min_cli_version=self.min_cli_version,
metadata=metadata_struct,
runtime=self.runtime,
)


Expand Down
7 changes: 7 additions & 0 deletions drift/core/drift_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,13 @@ def _init_auto_instrumentations(self) -> None:
logger.debug("Socket instrumentation initialized (REPLAY mode - unpatched dependency detection)")
except Exception as e:
logger.debug(f"Socket instrumentation initialization failed: {e}")

# try:
# from ..instrumentation.kinde import KindeInstrumentation
# _ = KindeInstrumentation(enabled=True)
# logger.debug("Kinde instrumentation initialized (REPLAY mode - auth token validation)")
# except Exception as e:
# logger.debug(f"Kinde instrumentation initialization failed: {e}")

def create_env_vars_snapshot(self) -> None:
"""Create a span capturing all environment variables.
Expand Down
2 changes: 2 additions & 0 deletions drift/instrumentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from .base import InstrumentationBase
from .django import DjangoInstrumentation
from .kinde import KindeInstrumentation
from .registry import install_hooks, patch_instances_via_gc, register_patch
from .wsgi import WsgiInstrumentation

__all__ = [
"InstrumentationBase",
"DjangoInstrumentation",
"KindeInstrumentation",
"WsgiInstrumentation",
"register_patch",
"install_hooks",
Expand Down
5 changes: 5 additions & 0 deletions drift/instrumentation/kinde/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Kinde SDK instrumentation for bypassing authentication in replay mode."""

from .instrumentation import KindeInstrumentation

__all__ = ["KindeInstrumentation"]
71 changes: 71 additions & 0 deletions drift/instrumentation/kinde/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Instrumentation for Kinde SDK authentication.

Patches is_authenticated() to always return True in REPLAY mode,
allowing tests to bypass authentication checks during replay.
"""

from __future__ import annotations

import logging
from typing import Any

from ..base import InstrumentationBase

logger = logging.getLogger(__name__)


class KindeInstrumentation(InstrumentationBase):
"""Instrumentation for the Kinde SDK authentication library.

Patches OAuth.is_authenticated() to:
- Return True in REPLAY mode (bypass authentication)
- Call the original method in RECORD and DISABLED modes

Since SmartOAuth and AsyncOAuth delegate to OAuth.is_authenticated(),
patching OAuth covers all authentication entry points.
"""

def __init__(self, enabled: bool = True) -> None:
super().__init__(
name="KindeInstrumentation",
module_name="kinde_sdk.auth.oauth",
supported_versions="*",
enabled=enabled,
)

def patch(self, module: Any) -> None:
"""Patch the kinde_sdk.auth.oauth module.

Patches OAuth.is_authenticated() to return True in REPLAY mode.
"""
if not hasattr(module, "OAuth"):
logger.warning("kinde_sdk.auth.oauth.OAuth not found, skipping instrumentation")
return

original_is_authenticated = module.OAuth.is_authenticated

def patched_is_authenticated(oauth_self) -> bool:
"""Patched is_authenticated method.

Args:
oauth_self: OAuth instance

Returns:
True in REPLAY mode, otherwise delegates to original method
"""
# Lazy imports to avoid circular dependency
from ...core.drift_sdk import TuskDrift
from ...core.types import TuskDriftMode

sdk = TuskDrift.get_instance()

# In REPLAY mode, always return True to bypass authentication
if sdk.mode == TuskDriftMode.REPLAY:
logger.debug("[KindeInstrumentation] REPLAY mode: returning True for is_authenticated")
return True

# In RECORD or DISABLED mode, call the original method
return original_is_authenticated(oauth_self)

module.OAuth.is_authenticated = patched_is_authenticated
logger.info("kinde_sdk.auth.oauth.OAuth.is_authenticated instrumented")
1 change: 1 addition & 0 deletions drift/instrumentation/requests/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def _try_get_mock(
input_value=input_value,
kind=SpanKind.CLIENT,
input_schema_merges=input_schema_merges,
is_pre_app_start=not sdk.app_ready,
)

if not mock_response_output or not mock_response_output.found:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"protobuf>=6.0",
"PyYAML>=6.0",
"requests>=2.32.5",
"tusk-drift-schemas>=0.1.9.dev1",
"tusk-drift-schemas @ file:///Users/sohankshirsagar/Desktop/Playground/tusk-drift-container/tusk-drift-schemas",
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
"aiohttp>=3.9.0",
"aiofiles>=23.0.0",
"opentelemetry-api>=1.20.0",
Expand Down
Loading
Loading