77from collections .abc import Callable
88from typing import TYPE_CHECKING
99
10- from opentelemetry import context as otel_context
1110from opentelemetry .trace import SpanKind as OTelSpanKind
12- from opentelemetry .trace import set_span_in_context
1311
1412logger = logging .getLogger (__name__ )
1513
3533 build_output_schema_merges ,
3634 build_output_value ,
3735)
36+ from ...core .mode_utils import handle_record_mode
37+ from ...core .tracing .span_utils import CreateSpanOptions , SpanInfo , SpanUtils
3838
3939
4040class DriftMiddleware :
@@ -64,67 +64,158 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
6464
6565 sdk = TuskDrift .get_instance ()
6666
67- # Check if we're in REPLAY mode and handle trace ID extraction
68- replay_trace_id = None
69- replay_token = None
67+ # DISABLED mode - just pass through
68+ if sdk .mode == TuskDriftMode .DISABLED :
69+ return self .get_response (request )
70+
71+ # REPLAY mode - handle trace ID extraction and context setup
7072 if sdk .mode == TuskDriftMode .REPLAY :
71- # Extract trace ID from headers (case-insensitive lookup)
72- # Django stores headers in request.META
73- headers_lower = {k .lower (): v for k , v in request .META .items () if k .startswith ("HTTP_" )}
74- logger .info (f"[DJANGO_MIDDLEWARE] REPLAY mode, headers: { list (headers_lower .keys ())} " )
75- # Convert HTTP_X_TD_TRACE_ID -> x-td-trace-id
76- replay_trace_id = headers_lower .get ("http_x_td_trace_id" )
77- logger .info (f"[DJANGO_MIDDLEWARE] replay_trace_id from header: { replay_trace_id } " )
78-
79- if not replay_trace_id :
80- # No trace context in REPLAY mode; proceed without span
81- logger .warning ("[DJANGO_MIDDLEWARE] No replay_trace_id found in headers, proceeding without span" )
82- return self .get_response (request )
83-
84- # Set replay trace context
85- replay_token = replay_trace_id_context .set (replay_trace_id )
73+ return self ._handle_replay_request (request , sdk )
74+
75+ # RECORD mode - use handle_record_mode for consistent is_pre_app_start logic
76+ return handle_record_mode (
77+ original_function_call = lambda : self .get_response (request ),
78+ record_mode_handler = lambda is_pre_app_start : self ._record_request (
79+ request , sdk , is_pre_app_start
80+ ),
81+ span_kind = OTelSpanKind .SERVER ,
82+ )
8683
87- start_time_ns = time .time_ns ()
84+ def _handle_replay_request (
85+ self , request : HttpRequest , sdk
86+ ) -> HttpResponse :
87+ """Handle request in REPLAY mode.
8888
89- # Create OpenTelemetry span
90- from ... core . drift_sdk import TuskDrift
89+ Extracts trace ID from headers and sets up context for child spans.
90+ Does not record the root span in REPLAY mode.
9191
92- sdk = TuskDrift .get_instance ()
93- tracer = sdk .get_tracer ()
92+ Args:
93+ request: Django HttpRequest object
94+ sdk: TuskDrift SDK instance
95+
96+ Returns:
97+ Django HttpResponse object
98+ """
99+ # Extract trace ID from headers (case-insensitive lookup)
100+ # Django stores headers in request.META
101+ headers_lower = {
102+ k .lower (): v for k , v in request .META .items () if k .startswith ("HTTP_" )
103+ }
104+ logger .info (
105+ f"[DJANGO_MIDDLEWARE] REPLAY mode, headers: { list (headers_lower .keys ())} "
106+ )
107+ # Convert HTTP_X_TD_TRACE_ID -> x-td-trace-id
108+ replay_trace_id = headers_lower .get ("http_x_td_trace_id" )
109+ logger .info (
110+ f"[DJANGO_MIDDLEWARE] replay_trace_id from header: { replay_trace_id } "
111+ )
112+
113+ if not replay_trace_id :
114+ # No trace context in REPLAY mode; proceed without span
115+ logger .warning (
116+ "[DJANGO_MIDDLEWARE] No replay_trace_id found in headers, proceeding without span"
117+ )
118+ return self .get_response (request )
119+
120+ # Set replay trace context
121+ replay_token = replay_trace_id_context .set (replay_trace_id )
94122
95123 method = request .method
96124 path = request .path
97125 span_name = f"{ method } { path } "
98126
99- span = tracer .start_span (
100- name = span_name ,
101- kind = OTelSpanKind .SERVER ,
102- attributes = {
103- TdSpanAttributes .NAME : span_name ,
104- TdSpanAttributes .PACKAGE_NAME : "django" ,
105- TdSpanAttributes .INSTRUMENTATION_NAME : "DjangoInstrumentation" ,
106- TdSpanAttributes .SUBMODULE_NAME : method ,
107- TdSpanAttributes .PACKAGE_TYPE : PackageType .HTTP .name ,
108- TdSpanAttributes .IS_PRE_APP_START : not sdk .app_ready ,
109- TdSpanAttributes .IS_ROOT_SPAN : True ,
110- },
127+ # Create span using SpanUtils
128+ span_info = SpanUtils .create_span (
129+ CreateSpanOptions (
130+ name = span_name ,
131+ kind = OTelSpanKind .SERVER ,
132+ attributes = {
133+ TdSpanAttributes .NAME : span_name ,
134+ TdSpanAttributes .PACKAGE_NAME : "django" ,
135+ TdSpanAttributes .INSTRUMENTATION_NAME : "DjangoInstrumentation" ,
136+ TdSpanAttributes .SUBMODULE_NAME : method ,
137+ TdSpanAttributes .PACKAGE_TYPE : PackageType .HTTP .name ,
138+ TdSpanAttributes .IS_PRE_APP_START : False ,
139+ TdSpanAttributes .IS_ROOT_SPAN : True ,
140+ },
141+ is_pre_app_start = False ,
142+ )
143+ )
144+
145+ if not span_info :
146+ # Failed to create span, just process the request
147+ replay_trace_id_context .reset (replay_token )
148+ return self .get_response (request )
149+
150+ # Set span_kind_context for child spans
151+ span_kind_token = span_kind_context .set (SpanKind .SERVER )
152+
153+ # Store metadata on request for later use
154+ request ._drift_start_time_ns = time .time_ns () # type: ignore
155+ request ._drift_span = span_info .span # type: ignore
156+ request ._drift_route_template = None # type: ignore
157+
158+ try :
159+ with SpanUtils .with_span (span_info ):
160+ response = self .get_response (request )
161+ # REPLAY mode: don't capture the span (it's already recorded)
162+ return response
163+ finally :
164+ # Reset context
165+ span_kind_context .reset (span_kind_token )
166+ replay_trace_id_context .reset (replay_token )
167+ span_info .span .end ()
168+
169+ def _record_request (
170+ self , request : HttpRequest , sdk , is_pre_app_start : bool
171+ ) -> HttpResponse :
172+ """Handle request in RECORD mode.
173+
174+ Creates a span, processes the request, and captures the span.
175+
176+ Args:
177+ request: Django HttpRequest object
178+ sdk: TuskDrift SDK instance
179+ is_pre_app_start: Whether this request occurred before app was marked ready
180+
181+ Returns:
182+ Django HttpResponse object
183+ """
184+ start_time_ns = time .time_ns ()
185+
186+ method = request .method
187+ path = request .path
188+ span_name = f"{ method } { path } "
189+
190+ # Create span using SpanUtils
191+ span_info = SpanUtils .create_span (
192+ CreateSpanOptions (
193+ name = span_name ,
194+ kind = OTelSpanKind .SERVER ,
195+ attributes = {
196+ TdSpanAttributes .NAME : span_name ,
197+ TdSpanAttributes .PACKAGE_NAME : "django" ,
198+ TdSpanAttributes .INSTRUMENTATION_NAME : "DjangoInstrumentation" ,
199+ TdSpanAttributes .SUBMODULE_NAME : method ,
200+ TdSpanAttributes .PACKAGE_TYPE : PackageType .HTTP .name ,
201+ TdSpanAttributes .IS_PRE_APP_START : is_pre_app_start ,
202+ TdSpanAttributes .IS_ROOT_SPAN : True ,
203+ },
204+ is_pre_app_start = is_pre_app_start ,
205+ )
111206 )
112207
113- # Make span active
114- ctx = otel_context .get_current ()
115- ctx_with_span = set_span_in_context (span , ctx )
116- token = otel_context .attach (ctx_with_span )
208+ if not span_info :
209+ # Failed to create span, just process the request
210+ return self .get_response (request )
117211
118- # Set span_kind_context for child spans and socket instrumentation to detect SERVER context
212+ # Set span_kind_context for child spans
119213 span_kind_token = span_kind_context .set (SpanKind .SERVER )
120214
121215 # Capture request body
122- # Django provides request.body which handles reading and caching
123- # No truncation at capture time - span-level 1MB blocking at export handles oversized spans
124216 request_body = None
125217 if request .method in ("POST" , "PUT" , "PATCH" ):
126218 try :
127- # Django's request.body automatically reads and caches the body
128219 request_body = request .body
129220 except Exception :
130221 pass
@@ -133,48 +224,35 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
133224 query_string = request .META .get ("QUERY_STRING" , "" )
134225 target = f"{ path } ?{ query_string } " if query_string else path
135226
136- # Extract headers from META
137227 from ..wsgi import extract_headers
138228
139229 request_headers = extract_headers (request .META )
140230
141- if self .transform_engine and self .transform_engine .should_drop_inbound_request (method , target , request_headers ):
231+ if self .transform_engine and self .transform_engine .should_drop_inbound_request (
232+ method , target , request_headers
233+ ):
142234 # Reset context before early return
143235 span_kind_context .reset (span_kind_token )
144- if replay_token :
145- replay_trace_id_context .reset (replay_token )
146- otel_context .detach (token )
147- span .end ()
236+ span_info .span .end ()
148237 return self .get_response (request )
149238
150239 # Store metadata on request for later use
151240 request ._drift_start_time_ns = start_time_ns # type: ignore
152- request ._drift_span = span # type: ignore
153- request ._drift_token = token # type: ignore
154- request ._drift_replay_token = replay_token # type: ignore
155- request ._drift_span_kind_token = span_kind_token # type: ignore
241+ request ._drift_span_info = span_info # type: ignore
156242 request ._drift_request_body = request_body # type: ignore
157- request ._drift_route_template = None # Will be set in process_view # type: ignore
243+ request ._drift_route_template = None # type: ignore
158244
159245 try :
160- # Call next middleware or view
161- response = self .get_response (request )
162-
163- # Capture span after response is complete
164- self ._capture_span (request , response )
165-
166- return response
246+ with SpanUtils .with_span (span_info ):
247+ response = self .get_response (request )
248+ self ._capture_span (request , response , span_info )
249+ return response
167250 except Exception as e :
168- # Capture error span
169- self ._capture_error_span (request , e )
251+ self ._capture_error_span (request , e , span_info )
170252 raise
171253 finally :
172- # Reset context
173254 span_kind_context .reset (span_kind_token )
174- if replay_token :
175- replay_trace_id_context .reset (replay_token )
176- otel_context .detach (token )
177- span .end ()
255+ span_info .span .end ()
178256
179257 def process_view (
180258 self ,
@@ -199,23 +277,24 @@ def process_view(
199277 if route :
200278 request ._drift_route_template = route # type: ignore
201279
202- def _capture_span (self , request : HttpRequest , response : HttpResponse ) -> None :
280+ def _capture_span (
281+ self , request : HttpRequest , response : HttpResponse , span_info : SpanInfo
282+ ) -> None :
203283 """Create and collect a span from request/response data.
204284
205285 Args:
206286 request: Django HttpRequest object
207287 response: Django HttpResponse object
288+ span_info: SpanInfo containing trace/span IDs and span reference
208289 """
209290 start_time_ns = getattr (request , "_drift_start_time_ns" , None )
210- span = getattr (request , "_drift_span" , None )
211291
212- if not start_time_ns or not span or not span .is_recording ():
292+ if not start_time_ns or not span_info . span .is_recording ():
213293 return
214294
215- # Extract trace_id and span_id from the span context
216- span_context = span .get_span_context ()
217- trace_id = format (span_context .trace_id , "032x" )
218- span_id = format (span_context .span_id , "016x" )
295+ # Use trace_id and span_id from span_info
296+ trace_id = span_info .trace_id
297+ span_id = span_info .span_id
219298
220299 end_time_ns = time .time_ns ()
221300 duration_ns = end_time_ns - start_time_ns
@@ -338,7 +417,7 @@ def dict_to_schema_merges(merges_dict):
338417 # Only create and collect span in RECORD mode
339418 # In REPLAY mode, we only set up context for child spans but don't record the root span
340419 if sdk .mode == TuskDriftMode .RECORD :
341- span = CleanSpanData (
420+ clean_span = CleanSpanData (
342421 trace_id = trace_id ,
343422 span_id = span_id ,
344423 parent_span_id = "" ,
@@ -357,33 +436,34 @@ def dict_to_schema_merges(merges_dict):
357436 input_schema_hash = input_schema_info .decoded_schema_hash ,
358437 output_schema_hash = output_schema_info .decoded_schema_hash ,
359438 status = status ,
360- is_pre_app_start = not sdk . app_ready ,
439+ is_pre_app_start = span_info . is_pre_app_start ,
361440 is_root_span = True ,
362441 timestamp = Timestamp (seconds = timestamp_seconds , nanos = timestamp_nanos ),
363442 duration = Duration (seconds = duration_seconds , nanos = duration_nanos ),
364443 transform_metadata = transform_metadata ,
365444 metadata = None ,
366445 )
367446
368- sdk .collect_span (span )
447+ sdk .collect_span (clean_span )
369448
370- def _capture_error_span (self , request : HttpRequest , exception : Exception ) -> None :
449+ def _capture_error_span (
450+ self , request : HttpRequest , exception : Exception , span_info : SpanInfo
451+ ) -> None :
371452 """Create and collect an error span.
372453
373454 Args:
374455 request: Django HttpRequest object
375456 exception: The exception that was raised
457+ span_info: SpanInfo containing trace/span IDs and span reference
376458 """
377459 start_time_ns = getattr (request , "_drift_start_time_ns" , None )
378- span = getattr (request , "_drift_span" , None )
379460
380- if not start_time_ns or not span or not span .is_recording ():
461+ if not start_time_ns or not span_info . span .is_recording ():
381462 return
382463
383- # Extract trace_id and span_id from the span context
384- span_context = span .get_span_context ()
385- trace_id = format (span_context .trace_id , "032x" )
386- span_id = format (span_context .span_id , "016x" )
464+ # Use trace_id and span_id from span_info
465+ trace_id = span_info .trace_id
466+ span_id = span_info .span_id
387467
388468 end_time_ns = time .time_ns ()
389469 duration_ns = end_time_ns - start_time_ns
@@ -436,7 +516,7 @@ def dict_to_schema_merges(merges_dict):
436516 route_template = getattr (request , "_drift_route_template" , None )
437517 span_name = f"{ method } { route_template } " if route_template else f"{ method } { request .path } "
438518
439- span = CleanSpanData (
519+ clean_span = CleanSpanData (
440520 trace_id = trace_id ,
441521 span_id = span_id ,
442522 parent_span_id = "" ,
@@ -455,12 +535,12 @@ def dict_to_schema_merges(merges_dict):
455535 input_schema_hash = input_schema_info .decoded_schema_hash ,
456536 output_schema_hash = output_schema_info .decoded_schema_hash ,
457537 status = SpanStatus (code = StatusCode .ERROR , message = f"Exception: { type (exception ).__name__ } " ),
458- is_pre_app_start = not sdk . app_ready ,
538+ is_pre_app_start = span_info . is_pre_app_start ,
459539 is_root_span = True ,
460540 timestamp = Timestamp (seconds = timestamp_seconds , nanos = timestamp_nanos ),
461541 duration = Duration (seconds = duration_seconds , nanos = duration_nanos ),
462542 transform_metadata = None ,
463543 metadata = None ,
464544 )
465545
466- sdk .collect_span (span )
546+ sdk .collect_span (clean_span )
0 commit comments