Skip to content

Commit 3a7f613

Browse files
authored
feat: add CSRF token normalization for Django record/replay (#57)
1 parent be88193 commit 3a7f613

File tree

5 files changed

+147
-1
lines changed

5 files changed

+147
-1
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Django CSRF token utilities for consistent record/replay testing.
2+
3+
This module provides utilities to normalize CSRF tokens so that recorded
4+
and replayed responses produce identical output for comparison.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import logging
10+
import re
11+
12+
logger = logging.getLogger(__name__)
13+
14+
CSRF_PLACEHOLDER = "__DRIFT_CSRF__"
15+
16+
17+
def normalize_csrf_in_body(body: bytes | None) -> bytes | None:
18+
"""Normalize CSRF tokens in response body for consistent record/replay comparison.
19+
20+
Replaces Django CSRF tokens with a fixed placeholder so that recorded
21+
responses match replayed responses during comparison.
22+
23+
This should be called after the response is sent to the browser,
24+
but before storing in the span. The actual response to the browser
25+
is unchanged.
26+
27+
Args:
28+
body: Response body bytes (typically HTML)
29+
30+
Returns:
31+
Body with CSRF tokens normalized, or original body if not applicable
32+
"""
33+
if not body:
34+
return body
35+
36+
try:
37+
body_str = body.decode("utf-8")
38+
39+
# Pattern 1: Hidden input fields with csrfmiddlewaretoken
40+
# <input type="hidden" name="csrfmiddlewaretoken" value="ABC123...">
41+
# Handles both single and double quotes, various attribute orders
42+
csrf_input_pattern = (
43+
r'(<input[^>]*name=["\']csrfmiddlewaretoken["\'][^>]*value=["\'])'
44+
r'[^"\']+(["\'])'
45+
)
46+
body_str = re.sub(
47+
csrf_input_pattern,
48+
rf"\g<1>{CSRF_PLACEHOLDER}\2",
49+
body_str,
50+
flags=re.IGNORECASE,
51+
)
52+
53+
# Pattern 2: Also handle value before name (different attribute order)
54+
# <input type="hidden" value="ABC123" name="csrfmiddlewaretoken">
55+
csrf_input_pattern_alt = r'(<input[^>]*value=["\'])[^"\']+(["\'][^>]*name=["\']csrfmiddlewaretoken["\'])'
56+
body_str = re.sub(
57+
csrf_input_pattern_alt,
58+
rf"\g<1>{CSRF_PLACEHOLDER}\2",
59+
body_str,
60+
flags=re.IGNORECASE,
61+
)
62+
63+
return body_str.encode("utf-8")
64+
65+
except Exception as e:
66+
logger.debug(f"Error normalizing CSRF tokens: {e}")
67+
return body

drift/instrumentation/django/e2e-tests/src/test_requests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@
2222
)
2323
make_request("DELETE", "/api/post/1/delete")
2424

25+
# Test CSRF token normalization
26+
make_request("GET", "/api/csrf-form")
27+
2528
print_request_summary()

drift/instrumentation/django/e2e-tests/src/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.urls import path
44
from views import (
55
create_post,
6+
csrf_form,
67
delete_post,
78
get_activity,
89
get_post,
@@ -19,4 +20,5 @@
1920
path("api/post/<int:post_id>", get_post, name="get_post"),
2021
path("api/post/<int:post_id>/delete", delete_post, name="delete_post"),
2122
path("api/activity", get_activity, name="get_activity"),
23+
path("api/csrf-form", csrf_form, name="csrf_form"),
2224
]

drift/instrumentation/django/e2e-tests/src/views.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from concurrent.futures import ThreadPoolExecutor
55

66
import requests
7-
from django.http import JsonResponse
7+
from django.http import HttpResponse, JsonResponse
8+
from django.middleware.csrf import get_token
89
from django.views.decorators.csrf import csrf_exempt
910
from django.views.decorators.http import require_GET, require_http_methods, require_POST
1011
from opentelemetry import context as otel_context
@@ -127,3 +128,26 @@ def get_activity(request):
127128
return JsonResponse(response.json())
128129
except Exception as e:
129130
return JsonResponse({"error": f"Failed to fetch activity: {str(e)}"}, status=500)
131+
132+
133+
@require_GET
def csrf_form(request):
    """Serve a small HTML page whose form embeds the request's CSRF token.

    The token obtained from ``get_token`` is written into a hidden
    ``csrfmiddlewaretoken`` input, giving the e2e suite an HTML response
    that exercises CSRF token normalization in record/replay.
    """
    csrf_token = get_token(request)
    # The markup is emitted verbatim; only the token value varies per request.
    markup = f"""<!DOCTYPE html>
<html>
<head><title>CSRF Test Form</title></head>
<body>
<h1>CSRF Test Form</h1>
<form method="POST" action="/api/submit">
<input type="hidden" name="csrfmiddlewaretoken" value="{csrf_token}">
<input type="text" name="message" placeholder="Enter message">
<button type="submit">Submit</button>
</form>
</body>
</html>"""
    return HttpResponse(content=markup, content_type="text/html")

drift/instrumentation/django/middleware.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def _handle_replay_request(self, request: HttpRequest, sdk) -> HttpResponse:
147147
with SpanUtils.with_span(span_info):
148148
response = self.get_response(request)
149149
# REPLAY mode: don't capture the span (it's already recorded)
150+
# But do normalize CSRF tokens in the response so comparison succeeds
151+
response = self._normalize_csrf_in_response(response)
150152
return response
151153
finally:
152154
# Reset context
@@ -262,6 +264,43 @@ def process_view(
262264
if route:
263265
request._drift_route_template = route # type: ignore
264266

267+
def _normalize_csrf_in_response(self, response: HttpResponse) -> HttpResponse:
    """Rewrite CSRF tokens in an HTML response body to the fixed placeholder.

    Used on the replay path so that the body actually returned matches a
    recording whose tokens were normalized. The response is modified in
    place and also returned.

    The response is left untouched when any of the following holds:
    its Content-Type does not contain ``text/html``; it carries a
    Content-Encoding other than ``identity``; it has no ``content``
    attribute or the content is empty; normalization changes nothing.

    Args:
        response: Django HttpResponse object

    Returns:
        The same response object, with its body (and Content-Length, if
        that header is present) updated when tokens were replaced.
    """
    is_html = "text/html" in response.get("Content-Type", "").lower()
    if not is_html:
        return response

    # Compressed payloads are skipped: the body would have to be decompressed
    # before any token could be located in it.
    encoding = response.get("Content-Encoding", "").lower()
    if encoding not in ("", "identity"):
        return response

    # Read the body once; responses without a ``content`` attribute fall
    # through to the early return.
    original_body = getattr(response, "content", None)
    if not original_body:
        return response

    from .csrf_utils import normalize_csrf_in_body

    replacement = normalize_csrf_in_body(original_body)
    if replacement is None or replacement == original_body:
        return response

    response.content = replacement
    # Keep an already-present Content-Length in sync with the new body size.
    if "Content-Length" in response:
        response["Content-Length"] = len(replacement)
    return response
303+
265304
def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo) -> None:
266305
"""Create and collect a span from request/response data.
267306
@@ -301,6 +340,17 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info:
301340
if isinstance(content, bytes) and len(content) > 0:
302341
response_body = content
303342

343+
# Normalize CSRF tokens in HTML responses for consistent record/replay comparison
344+
# This only affects what is stored in the span, not what the browser receives
345+
if response_body:
346+
content_type = response_headers.get("Content-Type", "")
347+
content_encoding = response_headers.get("Content-Encoding", "").lower()
348+
# Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body
349+
if "text/html" in content_type.lower() and (not content_encoding or content_encoding == "identity"):
350+
from .csrf_utils import normalize_csrf_in_body
351+
352+
response_body = normalize_csrf_in_body(response_body)
353+
304354
output_value = build_output_value(
305355
status_code,
306356
status_message,

0 commit comments

Comments
 (0)