Skip to content

Commit 7d13f11

Browse files
refactor django
1 parent 085c1de commit 7d13f11

1 file changed

Lines changed: 170 additions & 90 deletions

File tree

drift/instrumentation/django/middleware.py

Lines changed: 170 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from collections.abc import Callable
88
from typing import TYPE_CHECKING
99

10-
from opentelemetry import context as otel_context
1110
from opentelemetry.trace import SpanKind as OTelSpanKind
12-
from opentelemetry.trace import set_span_in_context
1311

1412
logger = logging.getLogger(__name__)
1513

@@ -35,6 +33,8 @@
3533
build_output_schema_merges,
3634
build_output_value,
3735
)
36+
from ...core.mode_utils import handle_record_mode
37+
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
3838

3939

4040
class DriftMiddleware:
@@ -64,67 +64,158 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
6464

6565
sdk = TuskDrift.get_instance()
6666

67-
# Check if we're in REPLAY mode and handle trace ID extraction
68-
replay_trace_id = None
69-
replay_token = None
67+
# DISABLED mode - just pass through
68+
if sdk.mode == TuskDriftMode.DISABLED:
69+
return self.get_response(request)
70+
71+
# REPLAY mode - handle trace ID extraction and context setup
7072
if sdk.mode == TuskDriftMode.REPLAY:
71-
# Extract trace ID from headers (case-insensitive lookup)
72-
# Django stores headers in request.META
73-
headers_lower = {k.lower(): v for k, v in request.META.items() if k.startswith("HTTP_")}
74-
logger.info(f"[DJANGO_MIDDLEWARE] REPLAY mode, headers: {list(headers_lower.keys())}")
75-
# Convert HTTP_X_TD_TRACE_ID -> x-td-trace-id
76-
replay_trace_id = headers_lower.get("http_x_td_trace_id")
77-
logger.info(f"[DJANGO_MIDDLEWARE] replay_trace_id from header: {replay_trace_id}")
78-
79-
if not replay_trace_id:
80-
# No trace context in REPLAY mode; proceed without span
81-
logger.warning("[DJANGO_MIDDLEWARE] No replay_trace_id found in headers, proceeding without span")
82-
return self.get_response(request)
83-
84-
# Set replay trace context
85-
replay_token = replay_trace_id_context.set(replay_trace_id)
73+
return self._handle_replay_request(request, sdk)
74+
75+
# RECORD mode - use handle_record_mode for consistent is_pre_app_start logic
76+
return handle_record_mode(
77+
original_function_call=lambda: self.get_response(request),
78+
record_mode_handler=lambda is_pre_app_start: self._record_request(
79+
request, sdk, is_pre_app_start
80+
),
81+
span_kind=OTelSpanKind.SERVER,
82+
)
8683

87-
start_time_ns = time.time_ns()
84+
def _handle_replay_request(
85+
self, request: HttpRequest, sdk
86+
) -> HttpResponse:
87+
"""Handle request in REPLAY mode.
8888
89-
# Create OpenTelemetry span
90-
from ...core.drift_sdk import TuskDrift
89+
Extracts trace ID from headers and sets up context for child spans.
90+
Does not record the root span in REPLAY mode.
9191
92-
sdk = TuskDrift.get_instance()
93-
tracer = sdk.get_tracer()
92+
Args:
93+
request: Django HttpRequest object
94+
sdk: TuskDrift SDK instance
95+
96+
Returns:
97+
Django HttpResponse object
98+
"""
99+
# Extract trace ID from headers (case-insensitive lookup)
100+
# Django stores headers in request.META
101+
headers_lower = {
102+
k.lower(): v for k, v in request.META.items() if k.startswith("HTTP_")
103+
}
104+
logger.info(
105+
f"[DJANGO_MIDDLEWARE] REPLAY mode, headers: {list(headers_lower.keys())}"
106+
)
107+
# Convert HTTP_X_TD_TRACE_ID -> x-td-trace-id
108+
replay_trace_id = headers_lower.get("http_x_td_trace_id")
109+
logger.info(
110+
f"[DJANGO_MIDDLEWARE] replay_trace_id from header: {replay_trace_id}"
111+
)
112+
113+
if not replay_trace_id:
114+
# No trace context in REPLAY mode; proceed without span
115+
logger.warning(
116+
"[DJANGO_MIDDLEWARE] No replay_trace_id found in headers, proceeding without span"
117+
)
118+
return self.get_response(request)
119+
120+
# Set replay trace context
121+
replay_token = replay_trace_id_context.set(replay_trace_id)
94122

95123
method = request.method
96124
path = request.path
97125
span_name = f"{method} {path}"
98126

99-
span = tracer.start_span(
100-
name=span_name,
101-
kind=OTelSpanKind.SERVER,
102-
attributes={
103-
TdSpanAttributes.NAME: span_name,
104-
TdSpanAttributes.PACKAGE_NAME: "django",
105-
TdSpanAttributes.INSTRUMENTATION_NAME: "DjangoInstrumentation",
106-
TdSpanAttributes.SUBMODULE_NAME: method,
107-
TdSpanAttributes.PACKAGE_TYPE: PackageType.HTTP.name,
108-
TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready,
109-
TdSpanAttributes.IS_ROOT_SPAN: True,
110-
},
127+
# Create span using SpanUtils
128+
span_info = SpanUtils.create_span(
129+
CreateSpanOptions(
130+
name=span_name,
131+
kind=OTelSpanKind.SERVER,
132+
attributes={
133+
TdSpanAttributes.NAME: span_name,
134+
TdSpanAttributes.PACKAGE_NAME: "django",
135+
TdSpanAttributes.INSTRUMENTATION_NAME: "DjangoInstrumentation",
136+
TdSpanAttributes.SUBMODULE_NAME: method,
137+
TdSpanAttributes.PACKAGE_TYPE: PackageType.HTTP.name,
138+
TdSpanAttributes.IS_PRE_APP_START: False,
139+
TdSpanAttributes.IS_ROOT_SPAN: True,
140+
},
141+
is_pre_app_start=False,
142+
)
143+
)
144+
145+
if not span_info:
146+
# Failed to create span, just process the request
147+
replay_trace_id_context.reset(replay_token)
148+
return self.get_response(request)
149+
150+
# Set span_kind_context for child spans
151+
span_kind_token = span_kind_context.set(SpanKind.SERVER)
152+
153+
# Store metadata on request for later use
154+
request._drift_start_time_ns = time.time_ns() # type: ignore
155+
request._drift_span = span_info.span # type: ignore
156+
request._drift_route_template = None # type: ignore
157+
158+
try:
159+
with SpanUtils.with_span(span_info):
160+
response = self.get_response(request)
161+
# REPLAY mode: don't capture the span (it's already recorded)
162+
return response
163+
finally:
164+
# Reset context
165+
span_kind_context.reset(span_kind_token)
166+
replay_trace_id_context.reset(replay_token)
167+
span_info.span.end()
168+
169+
def _record_request(
170+
self, request: HttpRequest, sdk, is_pre_app_start: bool
171+
) -> HttpResponse:
172+
"""Handle request in RECORD mode.
173+
174+
Creates a span, processes the request, and captures the span.
175+
176+
Args:
177+
request: Django HttpRequest object
178+
sdk: TuskDrift SDK instance
179+
is_pre_app_start: Whether this request occurred before app was marked ready
180+
181+
Returns:
182+
Django HttpResponse object
183+
"""
184+
start_time_ns = time.time_ns()
185+
186+
method = request.method
187+
path = request.path
188+
span_name = f"{method} {path}"
189+
190+
# Create span using SpanUtils
191+
span_info = SpanUtils.create_span(
192+
CreateSpanOptions(
193+
name=span_name,
194+
kind=OTelSpanKind.SERVER,
195+
attributes={
196+
TdSpanAttributes.NAME: span_name,
197+
TdSpanAttributes.PACKAGE_NAME: "django",
198+
TdSpanAttributes.INSTRUMENTATION_NAME: "DjangoInstrumentation",
199+
TdSpanAttributes.SUBMODULE_NAME: method,
200+
TdSpanAttributes.PACKAGE_TYPE: PackageType.HTTP.name,
201+
TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start,
202+
TdSpanAttributes.IS_ROOT_SPAN: True,
203+
},
204+
is_pre_app_start=is_pre_app_start,
205+
)
111206
)
112207

113-
# Make span active
114-
ctx = otel_context.get_current()
115-
ctx_with_span = set_span_in_context(span, ctx)
116-
token = otel_context.attach(ctx_with_span)
208+
if not span_info:
209+
# Failed to create span, just process the request
210+
return self.get_response(request)
117211

118-
# Set span_kind_context for child spans and socket instrumentation to detect SERVER context
212+
# Set span_kind_context for child spans
119213
span_kind_token = span_kind_context.set(SpanKind.SERVER)
120214

121215
# Capture request body
122-
# Django provides request.body which handles reading and caching
123-
# No truncation at capture time - span-level 1MB blocking at export handles oversized spans
124216
request_body = None
125217
if request.method in ("POST", "PUT", "PATCH"):
126218
try:
127-
# Django's request.body automatically reads and caches the body
128219
request_body = request.body
129220
except Exception:
130221
pass
@@ -133,48 +224,35 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
133224
query_string = request.META.get("QUERY_STRING", "")
134225
target = f"{path}?{query_string}" if query_string else path
135226

136-
# Extract headers from META
137227
from ..wsgi import extract_headers
138228

139229
request_headers = extract_headers(request.META)
140230

141-
if self.transform_engine and self.transform_engine.should_drop_inbound_request(method, target, request_headers):
231+
if self.transform_engine and self.transform_engine.should_drop_inbound_request(
232+
method, target, request_headers
233+
):
142234
# Reset context before early return
143235
span_kind_context.reset(span_kind_token)
144-
if replay_token:
145-
replay_trace_id_context.reset(replay_token)
146-
otel_context.detach(token)
147-
span.end()
236+
span_info.span.end()
148237
return self.get_response(request)
149238

150239
# Store metadata on request for later use
151240
request._drift_start_time_ns = start_time_ns # type: ignore
152-
request._drift_span = span # type: ignore
153-
request._drift_token = token # type: ignore
154-
request._drift_replay_token = replay_token # type: ignore
155-
request._drift_span_kind_token = span_kind_token # type: ignore
241+
request._drift_span_info = span_info # type: ignore
156242
request._drift_request_body = request_body # type: ignore
157-
request._drift_route_template = None # Will be set in process_view # type: ignore
243+
request._drift_route_template = None # type: ignore
158244

159245
try:
160-
# Call next middleware or view
161-
response = self.get_response(request)
162-
163-
# Capture span after response is complete
164-
self._capture_span(request, response)
165-
166-
return response
246+
with SpanUtils.with_span(span_info):
247+
response = self.get_response(request)
248+
self._capture_span(request, response, span_info)
249+
return response
167250
except Exception as e:
168-
# Capture error span
169-
self._capture_error_span(request, e)
251+
self._capture_error_span(request, e, span_info)
170252
raise
171253
finally:
172-
# Reset context
173254
span_kind_context.reset(span_kind_token)
174-
if replay_token:
175-
replay_trace_id_context.reset(replay_token)
176-
otel_context.detach(token)
177-
span.end()
255+
span_info.span.end()
178256

179257
def process_view(
180258
self,
@@ -199,23 +277,24 @@ def process_view(
199277
if route:
200278
request._drift_route_template = route # type: ignore
201279

202-
def _capture_span(self, request: HttpRequest, response: HttpResponse) -> None:
280+
def _capture_span(
281+
self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo
282+
) -> None:
203283
"""Create and collect a span from request/response data.
204284
205285
Args:
206286
request: Django HttpRequest object
207287
response: Django HttpResponse object
288+
span_info: SpanInfo containing trace/span IDs and span reference
208289
"""
209290
start_time_ns = getattr(request, "_drift_start_time_ns", None)
210-
span = getattr(request, "_drift_span", None)
211291

212-
if not start_time_ns or not span or not span.is_recording():
292+
if not start_time_ns or not span_info.span.is_recording():
213293
return
214294

215-
# Extract trace_id and span_id from the span context
216-
span_context = span.get_span_context()
217-
trace_id = format(span_context.trace_id, "032x")
218-
span_id = format(span_context.span_id, "016x")
295+
# Use trace_id and span_id from span_info
296+
trace_id = span_info.trace_id
297+
span_id = span_info.span_id
219298

220299
end_time_ns = time.time_ns()
221300
duration_ns = end_time_ns - start_time_ns
@@ -338,7 +417,7 @@ def dict_to_schema_merges(merges_dict):
338417
# Only create and collect span in RECORD mode
339418
# In REPLAY mode, we only set up context for child spans but don't record the root span
340419
if sdk.mode == TuskDriftMode.RECORD:
341-
span = CleanSpanData(
420+
clean_span = CleanSpanData(
342421
trace_id=trace_id,
343422
span_id=span_id,
344423
parent_span_id="",
@@ -357,33 +436,34 @@ def dict_to_schema_merges(merges_dict):
357436
input_schema_hash=input_schema_info.decoded_schema_hash,
358437
output_schema_hash=output_schema_info.decoded_schema_hash,
359438
status=status,
360-
is_pre_app_start=not sdk.app_ready,
439+
is_pre_app_start=span_info.is_pre_app_start,
361440
is_root_span=True,
362441
timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos),
363442
duration=Duration(seconds=duration_seconds, nanos=duration_nanos),
364443
transform_metadata=transform_metadata,
365444
metadata=None,
366445
)
367446

368-
sdk.collect_span(span)
447+
sdk.collect_span(clean_span)
369448

370-
def _capture_error_span(self, request: HttpRequest, exception: Exception) -> None:
449+
def _capture_error_span(
450+
self, request: HttpRequest, exception: Exception, span_info: SpanInfo
451+
) -> None:
371452
"""Create and collect an error span.
372453
373454
Args:
374455
request: Django HttpRequest object
375456
exception: The exception that was raised
457+
span_info: SpanInfo containing trace/span IDs and span reference
376458
"""
377459
start_time_ns = getattr(request, "_drift_start_time_ns", None)
378-
span = getattr(request, "_drift_span", None)
379460

380-
if not start_time_ns or not span or not span.is_recording():
461+
if not start_time_ns or not span_info.span.is_recording():
381462
return
382463

383-
# Extract trace_id and span_id from the span context
384-
span_context = span.get_span_context()
385-
trace_id = format(span_context.trace_id, "032x")
386-
span_id = format(span_context.span_id, "016x")
464+
# Use trace_id and span_id from span_info
465+
trace_id = span_info.trace_id
466+
span_id = span_info.span_id
387467

388468
end_time_ns = time.time_ns()
389469
duration_ns = end_time_ns - start_time_ns
@@ -436,7 +516,7 @@ def dict_to_schema_merges(merges_dict):
436516
route_template = getattr(request, "_drift_route_template", None)
437517
span_name = f"{method} {route_template}" if route_template else f"{method} {request.path}"
438518

439-
span = CleanSpanData(
519+
clean_span = CleanSpanData(
440520
trace_id=trace_id,
441521
span_id=span_id,
442522
parent_span_id="",
@@ -455,12 +535,12 @@ def dict_to_schema_merges(merges_dict):
455535
input_schema_hash=input_schema_info.decoded_schema_hash,
456536
output_schema_hash=output_schema_info.decoded_schema_hash,
457537
status=SpanStatus(code=StatusCode.ERROR, message=f"Exception: {type(exception).__name__}"),
458-
is_pre_app_start=not sdk.app_ready,
538+
is_pre_app_start=span_info.is_pre_app_start,
459539
is_root_span=True,
460540
timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos),
461541
duration=Duration(seconds=duration_seconds, nanos=duration_nanos),
462542
transform_metadata=None,
463543
metadata=None,
464544
)
465545

466-
sdk.collect_span(span)
546+
sdk.collect_span(clean_span)

0 commit comments

Comments
 (0)