Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion drift/core/mode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from opentelemetry.trace import SpanKind as OTelSpanKind

if TYPE_CHECKING:
pass
from ..instrumentation.http import HttpTransformEngine

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,3 +144,52 @@ def is_background_request(is_server_request: bool = False) -> bool:
current_span_info = SpanUtils.get_current_span_info()

return is_app_ready and not current_span_info and not is_server_request


def should_record_inbound_http_request(
method: str,
target: str,
headers: dict[str, str],
transform_engine: HttpTransformEngine | None,
is_pre_app_start: bool,
) -> tuple[bool, str | None]:
"""Check if an inbound HTTP request should be recorded.

This should be called BEFORE reading the request body to avoid
unnecessary I/O for requests that will be dropped or not sampled.

The check order is:
1. Drop transforms - check if request matches any drop rules
2. Sampling - check if request should be sampled (only when app is ready)

During pre-app-start phase, all requests are sampled to capture
initialization behavior.

Note: This is HTTP-specific. gRPC or other protocols would need a separate function
with different parameters.

Args:
method: HTTP method (GET, POST, etc.)
target: Request target (path + query string, e.g., "/api/users?page=1")
headers: Request headers dictionary
transform_engine: Optional HTTP transform engine for drop checks
is_pre_app_start: Whether app is in pre-start phase (always sample if True)

Returns:
Tuple of (should_record, skip_reason):
- should_record: True if request should be recorded
- skip_reason: If False, explains why ("dropped" or "not_sampled"), None otherwise
"""
if transform_engine and transform_engine.should_drop_inbound_request(method, target, headers):
return False, "dropped"

if not is_pre_app_start:
from .drift_sdk import TuskDrift
from .sampling import should_sample

sdk = TuskDrift.get_instance()
sampling_rate = sdk.get_sampling_rate()
if not should_sample(sampling_rate, is_app_ready=True):
return False, "not_sampled"

return True, None
43 changes: 17 additions & 26 deletions drift/instrumentation/django/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponse
from ...core.mode_utils import handle_record_mode
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
from ...core.types import (
Expand Down Expand Up @@ -167,20 +167,25 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) ->
Returns:
Django HttpResponse object
"""
# Inbound request sampling (only when app is ready)
# Always sample during startup to capture initialization behavior
if not is_pre_app_start:
from ...core.sampling import should_sample
# Pre-flight check: drop transforms and sampling
# NOTE: This is done before body capture to avoid unnecessary I/O
method = request.method or ""
path = request.path
query_string = request.META.get("QUERY_STRING", "")
target = f"{path}?{query_string}" if query_string else path

sampling_rate = sdk.get_sampling_rate()
if not should_sample(sampling_rate, is_app_ready=True):
logger.debug(f"[Django] Request not sampled (rate={sampling_rate}), path={request.path}")
return self.get_response(request)
from ..wsgi import extract_headers

start_time_ns = time.time_ns()
request_headers = extract_headers(request.META)

method = request.method or ""
path = request.path
should_record, skip_reason = should_record_inbound_http_request(
method, target, request_headers, self.transform_engine, is_pre_app_start
)
if not should_record:
logger.debug(f"[Django] Skipping request ({skip_reason}), path={path}")
return self.get_response(request)

start_time_ns = time.time_ns()
span_name = f"{method} {path}"

# Create span using SpanUtils
Expand Down Expand Up @@ -216,20 +221,6 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) ->
except Exception:
pass

# Check if request should be dropped
query_string = request.META.get("QUERY_STRING", "")
target = f"{path}?{query_string}" if query_string else path

from ..wsgi import extract_headers

request_headers = extract_headers(request.META)

if self.transform_engine and self.transform_engine.should_drop_inbound_request(method, target, request_headers):
# Reset context before early return
span_kind_context.reset(span_kind_token)
span_info.span.end()
return self.get_response(request)

# Store metadata on request for later use
request._drift_start_time_ns = start_time_ns # type: ignore
request._drift_span_info = span_info # type: ignore
Expand Down
55 changes: 30 additions & 25 deletions drift/instrumentation/fastapi/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ...core.drift_sdk import TuskDrift
from ...core.json_schema_helper import JsonSchemaHelper, SchemaMerge
from ...core.mode_utils import handle_record_mode
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
from ...core.types import (
Expand Down Expand Up @@ -106,7 +106,7 @@ async def _handle_replay_request(
transform_engine: HttpTransformEngine | None,
method: str,
raw_path: str,
target: str,
headers: dict[str, str],
) -> None:
"""Handle FastAPI request in REPLAY mode.

Expand All @@ -119,7 +119,7 @@ async def _handle_replay_request(
from ...core.types import replay_trace_id_context

# Extract trace ID from headers (case-insensitive lookup)
request_headers = _extract_headers(scope)
request_headers = headers
# Convert headers to lowercase for case-insensitive lookup
headers_lower = {k.lower(): v for k, v in request_headers.items()}
replay_trace_id = headers_lower.get("x-td-trace-id")
Expand Down Expand Up @@ -241,6 +241,8 @@ async def _record_request(
transform_engine: HttpTransformEngine | None,
method: str,
raw_path: str,
target: str,
headers: dict[str, str],
is_pre_app_start: bool,
) -> None:
"""Handle request in RECORD mode with span creation using SpanUtils.
Expand All @@ -254,18 +256,17 @@ async def _record_request(
transform_engine: HTTP transform engine for request/response transforms
method: HTTP method (GET, POST, etc.)
raw_path: Request path
target: Request target (path + query string)
headers: Request headers dictionary
is_pre_app_start: Whether this request occurred before app was marked ready
"""
# Inbound request sampling (only when app is ready)
# Always sample during startup to capture initialization behavior
if not is_pre_app_start:
from ...core.sampling import should_sample

sdk = TuskDrift.get_instance()
sampling_rate = sdk.get_sampling_rate()
if not should_sample(sampling_rate, is_app_ready=True):
logger.debug(f"[FastAPI] Request not sampled (rate={sampling_rate}), path={raw_path}")
return await original_call(app, scope, receive, send)
# Pre-flight check: drop transforms and sampling before body capture
should_record, skip_reason = should_record_inbound_http_request(
method, target, headers, transform_engine, is_pre_app_start
)
if not should_record:
logger.debug(f"[FastAPI] Skipping request ({skip_reason}), path={raw_path}")
return await original_call(app, scope, receive, send)

start_time_ns = time.time_ns()

Expand Down Expand Up @@ -389,16 +390,8 @@ async def _handle_request(
query_string = query_bytes.decode("utf-8", errors="replace")
else:
query_string = str(query_bytes)
target_for_drop = f"{raw_path}?{query_string}" if query_string else raw_path
headers_for_drop = _extract_headers(scope)

# Check if request should be dropped by transform engine
if transform_engine and transform_engine.should_drop_inbound_request(
method,
target_for_drop,
headers_for_drop,
):
return await original_call(app, scope, receive, send)
target = f"{raw_path}?{query_string}" if query_string else raw_path
headers = _extract_headers(scope)

# DISABLED mode - just pass through
if sdk.mode == TuskDriftMode.DISABLED:
Expand All @@ -407,14 +400,26 @@ async def _handle_request(
# REPLAY mode - handle trace ID extraction and context setup
if sdk.mode == TuskDriftMode.REPLAY:
return await _handle_replay_request(
app, scope, receive, send, original_call, transform_engine, method, raw_path, target_for_drop
app, scope, receive, send, original_call, transform_engine, method, raw_path, headers
)

# RECORD mode - use handle_record_mode for consistent is_pre_app_start logic
# NOTE: Pre-flight check (drop + sample) is done inside _record_request
# to access is_pre_app_start from handle_record_mode
result = handle_record_mode(
original_function_call=lambda: original_call(app, scope, receive, send),
record_mode_handler=lambda is_pre_app_start: _record_request(
app, scope, receive, send, original_call, transform_engine, method, raw_path, is_pre_app_start
app,
scope,
receive,
send,
original_call,
transform_engine,
method,
raw_path,
target,
headers,
is_pre_app_start,
),
span_kind=OTelSpanKind.SERVER,
)
Expand Down
32 changes: 12 additions & 20 deletions drift/instrumentation/wsgi/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
WsgiAppMethod = Callable[[WSGIApplication, WSGIEnvironment, StartResponse], "Iterable[bytes]"]


from ...core.mode_utils import handle_record_mode
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils
from ...core.types import (
Expand Down Expand Up @@ -208,34 +208,26 @@ def _create_and_handle_request(
We manually manage context because the span needs to stay open
across the WSGI response iterator.
"""
# Extract request info for span name and drop check
# Pre-flight check: drop transforms and sampling
# NOTE: This is done before body capture to avoid unnecessary I/O
method = environ.get("REQUEST_METHOD", "GET")
path = environ.get("PATH_INFO", "")
query_string = environ.get("QUERY_STRING", "")
target = f"{path}?{query_string}" if query_string else path
request_headers = extract_headers(environ)

if replay_token is None:
should_record, skip_reason = should_record_inbound_http_request(
method, target, request_headers, transform_engine, is_pre_app_start
)
if not should_record:
logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}")
return original_wsgi_app(app, environ, start_response)

# Capture request body
request_body = capture_request_body(environ)
environ["_drift_request_body"] = request_body

# Check if request should be dropped
request_headers = extract_headers(environ)
if transform_engine and transform_engine.should_drop_inbound_request(method, target, request_headers):
if replay_token:
replay_trace_id_context.reset(replay_token)
return original_wsgi_app(app, environ, start_response)

# Inbound request sampling (only RECORD mode + app ready)
# - replay_token is None means RECORD mode (REPLAY mode sets replay_token)
# - not is_pre_app_start means app is ready (always sample during startup)
if replay_token is None and not is_pre_app_start:
from ...core.sampling import should_sample

sampling_rate = sdk.get_sampling_rate()
if not should_sample(sampling_rate, is_app_ready=True):
logger.debug(f"[WSGI] Request not sampled (rate={sampling_rate}), path={path}")
return original_wsgi_app(app, environ, start_response)

span_name = f"{method} {path}"

# Build input value before starting span
Expand Down