Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ jobs:
- name: Install dependencies
run: uv sync --all-extras

- name: Debug lockfile
run: |
head -100 uv.lock
grep -A10 "asgiref" uv.lock
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated

- name: Check formatting
run: uv run ruff format --check drift/ tests/

Expand Down Expand Up @@ -99,4 +104,3 @@ jobs:
uv venv /tmp/test-install
Comment thread
sohankshirsagar marked this conversation as resolved.
uv pip install dist/*.whl --python /tmp/test-install/bin/python
/tmp/test-install/bin/python -c "import drift; print('Package imported successfully')"

127 changes: 127 additions & 0 deletions drift/core/communication/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
MessageType,
MockRequestInput,
MockResponseOutput,
Runtime,
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
span_to_proto,
)
Expand Down Expand Up @@ -78,6 +81,8 @@ def __init__(self, config: CommunicatorConfig | None = None) -> None:
self._incoming_buffer = bytearray()
self._pending_requests: dict[str, dict[str, Any]] = {}
self._lock = threading.Lock()
self._background_reader_thread: threading.Thread | None = None
self._stop_background_reader = threading.Event()

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -321,6 +326,9 @@ def connect_sync(
self._connected = True
logger.info(f"[CONNECT_SYNC] Connection successful! Socket is: {self._socket}")
logger.info(f"[CONNECT_SYNC] _connected={self._connected}, is_connected={self.is_connected}")

# Start background reader for CLI-initiated messages (like SetTimeTravel)
self._start_background_reader()
Comment thread
sohankshirsagar marked this conversation as resolved.
else:
error_msg = response.error or "Unknown error"
raise ConnectionError(f"CLI rejected connection: {error_msg}")
Expand Down Expand Up @@ -770,6 +778,12 @@ def _cleanup(self) -> None:
logger.warning("[CLEANUP] _cleanup() called! Stack trace:")
logger.warning("".join(traceback.format_stack()))

# Stop background reader thread
self._stop_background_reader.set()
if self._background_reader_thread and self._background_reader_thread.is_alive():
self._background_reader_thread.join(timeout=1.0)
self._background_reader_thread = None

Comment thread
sohankshirsagar marked this conversation as resolved.
self._connected = False
self._session_id = None
self._incoming_buffer.clear()
Expand All @@ -783,3 +797,116 @@ def _cleanup(self) -> None:
self._socket = None

self._pending_requests.clear()

# ========== Background Reader for CLI-initiated Messages ==========

def _start_background_reader(self) -> None:
"""Start background thread to read CLI-initiated messages."""
if self._background_reader_thread and self._background_reader_thread.is_alive():
return

self._stop_background_reader.clear()
self._background_reader_thread = threading.Thread(
target=self._background_read_loop,
daemon=True,
name="CLI-Message-Reader",
)
self._background_reader_thread.start()
logger.debug("Started background reader thread for CLI-initiated messages")

def _background_read_loop(self) -> None:
"""Background loop to read and handle CLI-initiated messages."""
while not self._stop_background_reader.is_set():
if not self._socket:
break

try:
# Set a short timeout so we can check the stop event periodically
self._socket.settimeout(0.5)

# Try to read length prefix
try:
length_data = self._recv_exact(4)
except socket.timeout:
continue # No data available, check stop event and retry
except Exception:
continue
Comment thread
sohankshirsagar marked this conversation as resolved.

if not length_data:
continue
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated

length = struct.unpack(">I", length_data)[0]

# Read message data
self._socket.settimeout(5.0) # Longer timeout for message body
message_data = self._recv_exact(length)
if not message_data:
continue

# Parse message
cli_message = CliMessage().parse(message_data)
logger.debug(f"Background reader received message type: {cli_message.type}")

# Handle CLI-initiated messages based on message type
if cli_message.type == MessageType.SET_TIME_TRAVEL:
self._handle_set_time_travel_sync(cli_message)
else:
# Other message types (responses to SDK requests) are handled elsewhere
logger.debug(f"Background reader ignoring message type: {cli_message.type}")
Comment thread
sohankshirsagar marked this conversation as resolved.
Outdated

except socket.timeout:
continue # Normal timeout, just retry
except Exception as e:
if not self._stop_background_reader.is_set():
logger.debug(f"Background reader error: {e}")
break

logger.debug("Background reader thread stopped")

def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None:
"""Handle SetTimeTravel request from CLI and send response."""
request = cli_message.set_time_travel_request
if not request:
return
Comment thread
sohankshirsagar marked this conversation as resolved.

logger.debug(
f"Received SetTimeTravel request: timestamp={request.timestamp_seconds}, "
f"traceId={request.trace_id}, source={request.timestamp_source}"
)

try:
from drift.instrumentation.datetime.instrumentation import start_time_travel

success = start_time_travel(request.timestamp_seconds, request.trace_id)

response = SetTimeTravelResponse(
success=success,
error="" if success else "time-machine library not available or failed to start",
)
except Exception as e:
logger.error(f"Failed to set time travel: {e}")
response = SetTimeTravelResponse(success=False, error=str(e))

# Send response back to CLI
sdk_message = SdkMessage(
type=MessageType.SET_TIME_TRAVEL,
request_id=cli_message.request_id,
set_time_travel_response=response,
)

try:
self._send_message_sync(sdk_message)
logger.debug(f"Sent SetTimeTravel response: success={response.success}")
except Exception as e:
logger.error(f"Failed to send SetTimeTravel response: {e}")

def _send_message_sync(self, message: SdkMessage) -> None:
"""Send a message synchronously on the main socket."""
if not self._socket:
raise ConnectionError("Not connected to CLI")

message_bytes = bytes(message)
length_prefix = struct.pack(">I", len(message_bytes))

with self._lock:
self._socket.sendall(length_prefix + message_bytes)
Comment thread
sohankshirsagar marked this conversation as resolved.
Comment thread
sohankshirsagar marked this conversation as resolved.
10 changes: 10 additions & 0 deletions drift/core/communication/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
"CliMessage",
"InstrumentationVersionMismatchAlert",
"MessageType",
"Runtime",
"SdkMessage",
"SendAlertRequest",
"SendInboundSpanForReplayRequest",
"SetTimeTravelRequest",
"SetTimeTravelResponse",
"UnpatchedDependencyAlert",
# Aliases
"SDKMessageType",
Expand All @@ -43,9 +46,12 @@
CliMessage,
InstrumentationVersionMismatchAlert,
MessageType,
Runtime,
SdkMessage,
SendAlertRequest,
SendInboundSpanForReplayRequest,
SetTimeTravelRequest,
SetTimeTravelResponse,
UnpatchedDependencyAlert,
)
from tusk.drift.core.v1 import (
Expand Down Expand Up @@ -136,6 +142,9 @@ class ConnectRequest:
metadata: dict[str, str] = field(default_factory=dict)
"""Additional metadata."""

runtime: Runtime = Runtime.PYTHON
"""SDK runtime environment (node, python)."""

def to_proto(self) -> ProtoConnectRequest:
"""Convert to protobuf message."""
from betterproto.lib.google.protobuf import Struct
Expand All @@ -150,6 +159,7 @@ def to_proto(self) -> ProtoConnectRequest:
sdk_version=self.sdk_version,
min_cli_version=self.min_cli_version,
metadata=metadata_struct,
runtime=self.runtime,
)


Expand Down
7 changes: 7 additions & 0 deletions drift/core/drift_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,13 @@ def _init_auto_instrumentations(self) -> None:
logger.debug("Socket instrumentation initialized (REPLAY mode - unpatched dependency detection)")
except Exception as e:
logger.debug(f"Socket instrumentation initialization failed: {e}")

# try:
# from ..instrumentation.kinde import KindeInstrumentation
# _ = KindeInstrumentation(enabled=True)
# logger.debug("Kinde instrumentation initialized (REPLAY mode - auth token validation)")
# except Exception as e:
# logger.debug(f"Kinde instrumentation initialization failed: {e}")

def create_env_vars_snapshot(self) -> None:
"""Create a span capturing all environment variables.
Expand Down
2 changes: 2 additions & 0 deletions drift/instrumentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from .base import InstrumentationBase
from .django import DjangoInstrumentation
from .kinde import KindeInstrumentation
from .registry import install_hooks, patch_instances_via_gc, register_patch
from .wsgi import WsgiInstrumentation

__all__ = [
"InstrumentationBase",
"DjangoInstrumentation",
"KindeInstrumentation",
"WsgiInstrumentation",
"register_patch",
"install_hooks",
Expand Down
5 changes: 5 additions & 0 deletions drift/instrumentation/kinde/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Kinde SDK instrumentation for bypassing authentication in replay mode."""

from .instrumentation import KindeInstrumentation

__all__ = ["KindeInstrumentation"]
71 changes: 71 additions & 0 deletions drift/instrumentation/kinde/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Instrumentation for Kinde SDK authentication.

Patches is_authenticated() to always return True in REPLAY mode,
allowing tests to bypass authentication checks during replay.
"""

from __future__ import annotations

import logging
from typing import Any

from ..base import InstrumentationBase

logger = logging.getLogger(__name__)


class KindeInstrumentation(InstrumentationBase):
"""Instrumentation for the Kinde SDK authentication library.

Patches OAuth.is_authenticated() to:
- Return True in REPLAY mode (bypass authentication)
- Call the original method in RECORD and DISABLED modes

Since SmartOAuth and AsyncOAuth delegate to OAuth.is_authenticated(),
patching OAuth covers all authentication entry points.
"""

def __init__(self, enabled: bool = True) -> None:
super().__init__(
name="KindeInstrumentation",
module_name="kinde_sdk.auth.oauth",
supported_versions="*",
enabled=enabled,
)

def patch(self, module: Any) -> None:
"""Patch the kinde_sdk.auth.oauth module.

Patches OAuth.is_authenticated() to return True in REPLAY mode.
"""
if not hasattr(module, "OAuth"):
logger.warning("kinde_sdk.auth.oauth.OAuth not found, skipping instrumentation")
return

original_is_authenticated = module.OAuth.is_authenticated

def patched_is_authenticated(oauth_self) -> bool:
"""Patched is_authenticated method.

Args:
oauth_self: OAuth instance

Returns:
True in REPLAY mode, otherwise delegates to original method
"""
# Lazy imports to avoid circular dependency
from ...core.drift_sdk import TuskDrift
from ...core.types import TuskDriftMode

sdk = TuskDrift.get_instance()

# In REPLAY mode, always return True to bypass authentication
if sdk.mode == TuskDriftMode.REPLAY:
logger.debug("[KindeInstrumentation] REPLAY mode: returning True for is_authenticated")
return True

# In RECORD or DISABLED mode, call the original method
return original_is_authenticated(oauth_self)

module.OAuth.is_authenticated = patched_is_authenticated
logger.info("kinde_sdk.auth.oauth.OAuth.is_authenticated instrumented")
1 change: 1 addition & 0 deletions drift/instrumentation/requests/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def _try_get_mock(
input_value=input_value,
kind=SpanKind.CLIENT,
input_schema_merges=input_schema_merges,
is_pre_app_start=not sdk.app_ready,
)

if not mock_response_output or not mock_response_output.found:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"protobuf>=6.0",
"PyYAML>=6.0",
"requests>=2.32.5",
"tusk-drift-schemas>=0.1.9.dev1",
"tusk-drift-schemas>=0.1.24",
"aiohttp>=3.9.0",
"aiofiles>=23.0.0",
"opentelemetry-api>=1.20.0",
Expand Down
Loading