def should_record_inbound_http_request(
    method: str,
    target: str,
    headers: dict[str, str],
    transform_engine: HttpTransformEngine | None,
    is_pre_app_start: bool,
) -> tuple[bool, str | None]:
    """Decide whether an inbound HTTP request should be recorded.

    Call this BEFORE reading the request body so that requests which end up
    dropped or unsampled never pay for body I/O.

    Checks run in this order:
        1. Drop transforms - ask ``transform_engine`` (when one is supplied)
           whether the request matches a drop rule.
        2. Sampling - only when the app is past its pre-start phase; the
           sampling rate is read from the global ``TuskDrift`` instance.

    During the pre-app-start phase the sampling step is skipped entirely, so
    every request that is not dropped is recorded.

    Note: This is HTTP-specific. gRPC or other protocols would need a separate
    function with different parameters.

    Args:
        method: HTTP method (GET, POST, etc.)
        target: Request target (path + query string, e.g., "/api/users?page=1")
        headers: Request headers dictionary
        transform_engine: Optional HTTP transform engine used for the drop
            check; ``None`` disables the drop check.
        is_pre_app_start: Whether the app is in its pre-start phase (sampling
            is bypassed if True)

    Returns:
        Tuple of (should_record, skip_reason):
        - should_record: True if the request should be recorded
        - skip_reason: "dropped" or "not_sampled" when should_record is
          False, None when should_record is True
    """
    # Compare against None explicitly: the parameter is Optional, and a plain
    # truthiness test would skip the drop rules for any engine object that
    # happens to evaluate falsy (e.g. one that defines __len__ or __bool__).
    if transform_engine is not None and transform_engine.should_drop_inbound_request(method, target, headers):
        return False, "dropped"

    # Startup traffic is always kept; no sampling decision is needed.
    if is_pre_app_start:
        return True, None

    # Function-scope imports are kept as in the original implementation
    # (presumably to avoid an import cycle with the SDK module - confirm
    # before hoisting them to module level).
    from .drift_sdk import TuskDrift
    from .sampling import should_sample

    sampling_rate = TuskDrift.get_instance().get_sampling_rate()
    if not should_sample(sampling_rate, is_app_ready=True):
        return False, "not_sampled"

    return True, None
request.META.get("QUERY_STRING", "") + target = f"{path}?{query_string}" if query_string else path - sampling_rate = sdk.get_sampling_rate() - if not should_sample(sampling_rate, is_app_ready=True): - logger.debug(f"[Django] Request not sampled (rate={sampling_rate}), path={request.path}") - return self.get_response(request) + from ..wsgi import extract_headers - start_time_ns = time.time_ns() + request_headers = extract_headers(request.META) - method = request.method or "" - path = request.path + should_record, skip_reason = should_record_inbound_http_request( + method, target, request_headers, self.transform_engine, is_pre_app_start + ) + if not should_record: + logger.debug(f"[Django] Skipping request ({skip_reason}), path={path}") + return self.get_response(request) + + start_time_ns = time.time_ns() span_name = f"{method} {path}" # Create span using SpanUtils @@ -216,20 +221,6 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) -> except Exception: pass - # Check if request should be dropped - query_string = request.META.get("QUERY_STRING", "") - target = f"{path}?{query_string}" if query_string else path - - from ..wsgi import extract_headers - - request_headers = extract_headers(request.META) - - if self.transform_engine and self.transform_engine.should_drop_inbound_request(method, target, request_headers): - # Reset context before early return - span_kind_context.reset(span_kind_token) - span_info.span.end() - return self.get_response(request) - # Store metadata on request for later use request._drift_start_time_ns = start_time_ns # type: ignore request._drift_span_info = span_info # type: ignore diff --git a/drift/instrumentation/fastapi/instrumentation.py b/drift/instrumentation/fastapi/instrumentation.py index 7a37073..57a764a 100644 --- a/drift/instrumentation/fastapi/instrumentation.py +++ b/drift/instrumentation/fastapi/instrumentation.py @@ -26,7 +26,7 @@ from ...core.drift_sdk import TuskDrift from 
...core.json_schema_helper import JsonSchemaHelper, SchemaMerge -from ...core.mode_utils import handle_record_mode +from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils from ...core.types import ( @@ -106,7 +106,7 @@ async def _handle_replay_request( transform_engine: HttpTransformEngine | None, method: str, raw_path: str, - target: str, + headers: dict[str, str], ) -> None: """Handle FastAPI request in REPLAY mode. @@ -119,7 +119,7 @@ async def _handle_replay_request( from ...core.types import replay_trace_id_context # Extract trace ID from headers (case-insensitive lookup) - request_headers = _extract_headers(scope) + request_headers = headers # Convert headers to lowercase for case-insensitive lookup headers_lower = {k.lower(): v for k, v in request_headers.items()} replay_trace_id = headers_lower.get("x-td-trace-id") @@ -241,6 +241,8 @@ async def _record_request( transform_engine: HttpTransformEngine | None, method: str, raw_path: str, + target: str, + headers: dict[str, str], is_pre_app_start: bool, ) -> None: """Handle request in RECORD mode with span creation using SpanUtils. @@ -254,18 +256,17 @@ async def _record_request( transform_engine: HTTP transform engine for request/response transforms method: HTTP method (GET, POST, etc.) 
raw_path: Request path + target: Request target (path + query string) + headers: Request headers dictionary is_pre_app_start: Whether this request occurred before app was marked ready """ - # Inbound request sampling (only when app is ready) - # Always sample during startup to capture initialization behavior - if not is_pre_app_start: - from ...core.sampling import should_sample - - sdk = TuskDrift.get_instance() - sampling_rate = sdk.get_sampling_rate() - if not should_sample(sampling_rate, is_app_ready=True): - logger.debug(f"[FastAPI] Request not sampled (rate={sampling_rate}), path={raw_path}") - return await original_call(app, scope, receive, send) + # Pre-flight check: drop transforms and sampling before body capture + should_record, skip_reason = should_record_inbound_http_request( + method, target, headers, transform_engine, is_pre_app_start + ) + if not should_record: + logger.debug(f"[FastAPI] Skipping request ({skip_reason}), path={raw_path}") + return await original_call(app, scope, receive, send) start_time_ns = time.time_ns() @@ -389,16 +390,8 @@ async def _handle_request( query_string = query_bytes.decode("utf-8", errors="replace") else: query_string = str(query_bytes) - target_for_drop = f"{raw_path}?{query_string}" if query_string else raw_path - headers_for_drop = _extract_headers(scope) - - # Check if request should be dropped by transform engine - if transform_engine and transform_engine.should_drop_inbound_request( - method, - target_for_drop, - headers_for_drop, - ): - return await original_call(app, scope, receive, send) + target = f"{raw_path}?{query_string}" if query_string else raw_path + headers = _extract_headers(scope) # DISABLED mode - just pass through if sdk.mode == TuskDriftMode.DISABLED: @@ -407,14 +400,26 @@ async def _handle_request( # REPLAY mode - handle trace ID extraction and context setup if sdk.mode == TuskDriftMode.REPLAY: return await _handle_replay_request( - app, scope, receive, send, original_call, transform_engine, 
method, raw_path, target_for_drop + app, scope, receive, send, original_call, transform_engine, method, raw_path, headers ) # RECORD mode - use handle_record_mode for consistent is_pre_app_start logic + # NOTE: Pre-flight check (drop + sample) is done inside _record_request + # to access is_pre_app_start from handle_record_mode result = handle_record_mode( original_function_call=lambda: original_call(app, scope, receive, send), record_mode_handler=lambda is_pre_app_start: _record_request( - app, scope, receive, send, original_call, transform_engine, method, raw_path, is_pre_app_start + app, + scope, + receive, + send, + original_call, + transform_engine, + method, + raw_path, + target, + headers, + is_pre_app_start, ), span_kind=OTelSpanKind.SERVER, ) diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py index 0a705dd..9b13308 100644 --- a/drift/instrumentation/wsgi/handler.py +++ b/drift/instrumentation/wsgi/handler.py @@ -30,7 +30,7 @@ WsgiAppMethod = Callable[[WSGIApplication, WSGIEnvironment, StartResponse], "Iterable[bytes]"] -from ...core.mode_utils import handle_record_mode +from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( @@ -208,34 +208,26 @@ def _create_and_handle_request( We manually manage context because the span needs to stay open across the WSGI response iterator. 
""" - # Extract request info for span name and drop check + # Pre-flight check: drop transforms and sampling + # NOTE: This is done before body capture to avoid unnecessary I/O method = environ.get("REQUEST_METHOD", "GET") path = environ.get("PATH_INFO", "") query_string = environ.get("QUERY_STRING", "") target = f"{path}?{query_string}" if query_string else path + request_headers = extract_headers(environ) + + if replay_token is None: + should_record, skip_reason = should_record_inbound_http_request( + method, target, request_headers, transform_engine, is_pre_app_start + ) + if not should_record: + logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}") + return original_wsgi_app(app, environ, start_response) # Capture request body request_body = capture_request_body(environ) environ["_drift_request_body"] = request_body - # Check if request should be dropped - request_headers = extract_headers(environ) - if transform_engine and transform_engine.should_drop_inbound_request(method, target, request_headers): - if replay_token: - replay_trace_id_context.reset(replay_token) - return original_wsgi_app(app, environ, start_response) - - # Inbound request sampling (only RECORD mode + app ready) - # - replay_token is None means RECORD mode (REPLAY mode sets replay_token) - # - not is_pre_app_start means app is ready (always sample during startup) - if replay_token is None and not is_pre_app_start: - from ...core.sampling import should_sample - - sampling_rate = sdk.get_sampling_rate() - if not should_sample(sampling_rate, is_app_ready=True): - logger.debug(f"[WSGI] Request not sampled (rate={sampling_rate}), path={path}") - return original_wsgi_app(app, environ, start_response) - span_name = f"{method} {path}" # Build input value before starting span