Skip to content

Commit 9368a57

Browse files
Refactor urllib instrumentation to reduce code duplication
Extract shared logic into two helper methods: `_build_input_value()`, which constructs the HTTP request input value dict, and `_build_input_schema_merges()`, which creates the schema merge hints for headers/body. Both RECORD (`_finalize_span`) and REPLAY (`_try_get_mock`) modes now use these helpers, eliminating ~60 lines of duplicate code and ensuring consistent behavior between modes.
1 parent 570059a commit 9368a57

1 file changed

Lines changed: 88 additions & 79 deletions

File tree

drift/instrumentation/urllib/instrumentation.py

Lines changed: 88 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,83 @@ def _get_content_type_header(self, headers: dict) -> str | None:
641641
return value
642642
return None
643643

644+
def _build_input_value(
645+
self,
646+
method: str,
647+
url: str,
648+
headers: dict[str, str],
649+
body: bytes | None,
650+
) -> tuple[dict[str, Any], str | None, int]:
651+
"""Build the input value dictionary for HTTP requests.
652+
653+
Args:
654+
method: HTTP method (GET, POST, etc.)
655+
url: Full request URL
656+
headers: Request headers dictionary
657+
body: Request body bytes (or None)
658+
659+
Returns:
660+
Tuple of (input_value dict, body_base64 string or None, body_size int)
661+
"""
662+
parsed_url = urlparse(url)
663+
664+
# Parse query params from URL
665+
params = {}
666+
if parsed_url.query:
667+
params = {k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed_url.query).items()}
668+
669+
# Encode body to base64
670+
body_base64 = None
671+
body_size = 0
672+
673+
if body is not None:
674+
body_base64, body_size = self._encode_body_to_base64(body)
675+
676+
input_value = {
677+
"method": method.upper(),
678+
"url": url,
679+
"protocol": parsed_url.scheme,
680+
"hostname": parsed_url.hostname,
681+
"port": parsed_url.port,
682+
"path": parsed_url.path or "/",
683+
"headers": dict(headers),
684+
"query": params,
685+
}
686+
687+
# Add body fields only if body exists
688+
if body_base64 is not None:
689+
input_value["body"] = body_base64
690+
input_value["bodySize"] = body_size
691+
692+
return input_value, body_base64, body_size
693+
694+
def _build_input_schema_merges(
    self,
    headers: dict[str, str],
    body_base64: str | None,
) -> dict[str, SchemaMerge]:
    """Produce the schema merge hints that accompany a request input value.

    Args:
        headers: Request headers; consulted only to find the content type
            when a body is present.
        body_base64: Base64-encoded request body, or None if there is none.

    Returns:
        A dict that always has a "headers" hint and has a "body" hint only
        when ``body_base64`` is not None.
    """
    # NOTE(review): match_importance=0.0 presumably keeps header differences
    # from affecting mock matching - confirm against SchemaMerge.
    merges: dict[str, SchemaMerge] = {"headers": SchemaMerge(match_importance=0.0)}

    if body_base64 is None:
        return merges

    # The body is stored base64-encoded; its decoded type is derived from the
    # request's content-type header.
    content_type = self._get_content_type_header(headers)
    decoded_type = self._get_decoded_type_from_content_type(content_type)
    merges["body"] = SchemaMerge(
        encoding=EncodingType.BASE64,
        decoded_type=decoded_type,
    )
    return merges
720+
644721
def _try_get_mock(
645722
self,
646723
sdk: TuskDrift,
@@ -665,46 +742,12 @@ def _try_get_mock(
665742
headers = request_info.get("headers", {})
666743
body = request_info.get("body")
667744

668-
# Parse query params from URL
669-
params = {}
670-
if parsed_url.query:
671-
params = {k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed_url.query).items()}
672-
673-
# Handle request body - encode to base64
674-
body_base64 = None
675-
body_size = 0
676-
677-
if body is not None:
678-
body_base64, body_size = self._encode_body_to_base64(body)
679-
680-
raw_input_value = {
681-
"method": method.upper(),
682-
"url": url,
683-
"protocol": parsed_url.scheme,
684-
"hostname": parsed_url.hostname,
685-
"port": parsed_url.port,
686-
"path": parsed_url.path or "/",
687-
"headers": dict(headers),
688-
"query": params,
689-
}
690-
691-
# Add body fields only if body exists
692-
if body_base64 is not None:
693-
raw_input_value["body"] = body_base64
694-
raw_input_value["bodySize"] = body_size
695-
745+
# Build input value using shared helper
746+
raw_input_value, body_base64, _ = self._build_input_value(method, url, headers, body)
696747
input_value = create_mock_input_value(raw_input_value)
697748

698-
# Create schema merge hints for input
699-
input_schema_merges = {
700-
"headers": SchemaMerge(match_importance=0.0),
701-
}
702-
if body_base64 is not None:
703-
request_content_type = self._get_content_type_header(headers)
704-
input_schema_merges["body"] = SchemaMerge(
705-
encoding=EncodingType.BASE64,
706-
decoded_type=self._get_decoded_type_from_content_type(request_content_type),
707-
)
749+
# Build schema merge hints using shared helper
750+
input_schema_merges = self._build_input_schema_merges(headers, body_base64)
708751

709752
# Use centralized mock finding utility
710753
from ...core.mock_utils import find_mock_response_sync
@@ -868,39 +911,12 @@ def _finalize_span(
868911
request_info: Original request info dict
869912
"""
870913
try:
871-
parsed_url = urlparse(url)
872-
873914
# ===== BUILD INPUT VALUE =====
874915
headers = request_info.get("headers", {})
875916
body = request_info.get("body")
876917

877-
# Parse query params from URL
878-
params = {}
879-
if parsed_url.query:
880-
params = {k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed_url.query).items()}
881-
882-
# Get request body and encode to base64
883-
body_base64 = None
884-
body_size = 0
885-
886-
if body is not None:
887-
body_base64, body_size = self._encode_body_to_base64(body)
888-
889-
input_value = {
890-
"method": method.upper(),
891-
"url": url,
892-
"protocol": parsed_url.scheme,
893-
"hostname": parsed_url.hostname,
894-
"port": parsed_url.port,
895-
"path": parsed_url.path or "/",
896-
"headers": dict(headers),
897-
"query": params,
898-
}
899-
900-
# Add body fields only if body exists
901-
if body_base64 is not None:
902-
input_value["body"] = body_base64
903-
input_value["bodySize"] = body_size
918+
# Build input value using shared helper
919+
input_value, body_base64, _ = self._build_input_value(method, url, headers, body)
904920

905921
# ===== BUILD OUTPUT VALUE =====
906922
output_value = {}
@@ -1016,7 +1032,10 @@ def _finalize_span(
10161032
transform_metadata = span_data.transform_metadata
10171033

10181034
# ===== CREATE SCHEMA MERGE HINTS =====
1019-
request_content_type = self._get_content_type_header(headers)
1035+
# Build input schema merges using shared helper
1036+
input_schema_merges = self._build_input_schema_merges(headers, body_base64)
1037+
1038+
# Get response content type for output schema merges
10201039
response_content_type = None
10211040
if response:
10221041
try:
@@ -1026,18 +1045,8 @@ def _finalize_span(
10261045
except Exception:
10271046
pass
10281047

1029-
# Create schema merge hints for input
1030-
input_schema_merges = {
1031-
"headers": SchemaMerge(match_importance=0.0),
1032-
}
1033-
if body_base64 is not None:
1034-
input_schema_merges["body"] = SchemaMerge(
1035-
encoding=EncodingType.BASE64,
1036-
decoded_type=self._get_decoded_type_from_content_type(request_content_type),
1037-
)
1038-
10391048
# Create schema merge hints for output
1040-
output_schema_merges = {
1049+
output_schema_merges: dict[str, SchemaMerge] = {
10411050
"headers": SchemaMerge(match_importance=0.0),
10421051
}
10431052
if response_body_base64 is not None:

0 commit comments

Comments
 (0)