Skip to content

Commit d8eda9d

Browse files
fix: Build an ssl.SSLContext explicitly and inject it via a custom HTTPAdapter to bypass requests' internal cert handling and cache (#81)
- psf/requests#6767 Signed-off-by: Edgar Ramírez Mondragón <edgarrm358@gmail.com>
1 parent 13fc9a2 commit d8eda9d

8 files changed

Lines changed: 135 additions & 86 deletions

File tree

.pre-commit-config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ repos:
1818
- id: trailing-whitespace
1919

2020
- repo: https://github.com/python-jsonschema/check-jsonschema
21-
rev: 0.36.0
21+
rev: 0.37.1
2222
hooks:
2323
- id: check-dependabot
2424
- id: check-github-workflows
25+
- id: check-meltano
2526

2627
- repo: https://github.com/astral-sh/ruff-pre-commit
27-
rev: v0.14.13
28+
rev: v0.15.8
2829
hooks:
2930
- id: ruff-check
3031
args: [--fix, --exit-non-zero-on-fix, --show-fixes]

generate_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def make_nullable(schema: dict) -> dict:
3838

3939
def generate_schema(
4040
stream_instance: Stream,
41-
context: Context,
41+
context: Context | None,
4242
output_file: str,
4343
) -> None:
4444
"""Generate a schema for a given stream."""
@@ -89,7 +89,9 @@ def main() -> None:
8989

9090
for child_name, child_stream in children.items():
9191
generate_schema(
92-
child_stream, context=None, output_file=f"tap_adp/schemas/{child_name}.json"
92+
child_stream,
93+
context=None,
94+
output_file=f"tap_adp/schemas/{child_name}.json",
9395
)
9496

9597

pyproject.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
dependencies = [
2323
"singer-sdk[faker]~=0.53.2",
2424
"requests~=2.33.0",
25-
"typing-extensions>=4.15.0 ; python_full_version < '3.12'",
25+
"typing-extensions>=4.15.0 ; python_full_version < '3.13'",
2626
]
2727

2828
[project.optional-dependencies]
@@ -47,11 +47,20 @@ build-backend = "hatchling.build"
4747

4848
[tool.pytest]
4949
addopts = [
50+
"-v",
51+
"-ra",
5052
"--durations=10",
5153
]
54+
filterwarnings = [ "error" ]
55+
log_level = "INFO"
5256
minversion = "9"
57+
strict = true
58+
testpaths = [ "tests" ]
5359

5460
[tool.mypy]
61+
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
62+
strict = true
63+
warn_unreachable = true
5564
warn_unused_configs = true
5665
warn_unused_ignores = true
5766

tap_adp/authenticator.py

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from __future__ import annotations
44

55
import os
6+
import ssl
67
import sys
78
import tempfile
89
from typing import Any
910

1011
import requests
12+
from requests.adapters import HTTPAdapter
1113
from singer_sdk.authenticators import OAuthAuthenticator
1214
from singer_sdk.helpers._util import utc_now
1315

@@ -20,6 +22,24 @@
2022
AUTH_ENDPOINT = "https://accounts.adp.com/auth/oauth/v2/token"
2123

2224

25+
class _MTLSAdapter(HTTPAdapter):
26+
"""Requests adapter that injects a pre-built SSL context for mTLS.
27+
28+
Works around SSL context caching in requests >=2.32.5 (psf/requests#6767)
29+
that ignores the ``cert=`` parameter when a cached context already exists
30+
for the target host, causing mTLS authentication to silently drop the
31+
client certificate.
32+
"""
33+
34+
def __init__(self, ssl_context: ssl.SSLContext, **kwargs: Any) -> None:
35+
self._ssl_context = ssl_context
36+
super().__init__(**kwargs)
37+
38+
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
39+
kwargs["ssl_context"] = self._ssl_context
40+
super().init_poolmanager(*args, **kwargs) # type: ignore[no-untyped-call]
41+
42+
2343
class ADPAuthenticator(OAuthAuthenticator):
2444
"""Authenticator class for ADP."""
2545

@@ -46,57 +66,64 @@ def __init__(
4666

4767
@override
4868
@property
49-
def oauth_request_body(self) -> dict:
69+
def oauth_request_body(self) -> dict[str, Any]:
5070
"""Define the OAuth request body for ADP."""
5171
return {
5272
"grant_type": "client_credentials",
5373
"client_id": self.client_id,
5474
"client_secret": self.client_secret,
5575
}
5676

77+
def _build_ssl_context(self) -> ssl.SSLContext:
78+
"""Build an SSL context with the client certificate pre-loaded.
79+
80+
Writes the PEM strings to temporary files, loads them into an
81+
``ssl.SSLContext``, then deletes the files before returning. The
82+
context holds the cert in memory so the files are not needed at
83+
request time.
84+
"""
85+
with (
86+
tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pem") as cert_file,
87+
tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pem") as key_file,
88+
):
89+
cert_path = cert_file.name
90+
key_path = key_file.name
91+
cert_file.write(self.cert_public.encode("utf-8"))
92+
key_file.write(self.cert_private.encode("utf-8"))
93+
94+
try:
95+
os.chmod(cert_path, 0o600) # noqa: PTH101
96+
os.chmod(key_path, 0o600) # noqa: PTH101
97+
ctx = ssl.create_default_context()
98+
ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
99+
finally:
100+
os.unlink(cert_path) # noqa: PTH108
101+
os.unlink(key_path) # noqa: PTH108
102+
103+
return ctx
104+
57105
@override
58106
def update_access_token(self) -> None:
59107
"""Update `access_token` along with `last_refreshed` and `expires_in`."""
60108
request_time = utc_now()
61109

62-
# Create temporary files for the cert and key
63-
with (
64-
tempfile.NamedTemporaryFile(mode="wb+", delete=False) as cert_file,
65-
tempfile.NamedTemporaryFile(mode="wb+", delete=False) as key_file,
66-
):
67-
# Write contents to the temporary files
68-
cert_file.write(self.cert_public.encode("utf-8"))
69-
cert_file.flush()
110+
session = requests.Session()
111+
session.mount("https://", _MTLSAdapter(ssl_context=self._build_ssl_context()))
70112

71-
key_file.write(self.cert_private.encode("utf-8"))
72-
key_file.flush()
73-
74-
# Ensure the files are readable only by the owner (optional)
75-
os.chmod(cert_file.name, 0o600) # noqa: PTH101
76-
os.chmod(key_file.name, 0o600) # noqa: PTH101
77-
78-
# Make the OAuth request
79-
try:
80-
response = requests.post(
81-
self.auth_endpoint,
82-
data=self.oauth_request_body,
83-
headers=self._oauth_headers,
84-
timeout=60,
85-
cert=(cert_file.name, key_file.name),
86-
)
87-
response.raise_for_status()
88-
except requests.HTTPError:
89-
self.logger.warning(
90-
"Failed OAuth login, response was '%s'",
91-
response.text,
92-
)
93-
raise
94-
finally:
95-
# Clean up the temporary files
96-
cert_file.close()
97-
key_file.close()
98-
os.unlink(cert_file.name) # noqa: PTH108
99-
os.unlink(key_file.name) # noqa: PTH108
113+
try:
114+
response = session.post(
115+
self.auth_endpoint,
116+
data=self.oauth_request_body,
117+
headers=self._oauth_headers,
118+
timeout=60,
119+
)
120+
response.raise_for_status()
121+
except requests.HTTPError:
122+
self.logger.warning(
123+
"Failed OAuth login, response was '%s'",
124+
response.text,
125+
)
126+
raise
100127

101128
self.logger.info("OAuth authorization attempt was successful.")
102129

tap_adp/client.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import decimal
66
import sys
7-
import typing as t
87
from functools import cached_property
98
from http import HTTPStatus
9+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
1010

1111
from singer_sdk import SchemaDirectory, StreamSchema
1212
from singer_sdk.helpers._typing import TypeConformanceLevel
@@ -22,12 +22,21 @@
2222
else:
2323
from typing_extensions import override
2424

25-
if t.TYPE_CHECKING:
25+
if sys.version_info >= (3, 13):
26+
from typing import TypeVar
27+
else:
28+
from typing_extensions import TypeVar
29+
30+
if TYPE_CHECKING:
31+
from collections.abc import Iterable
32+
2633
import requests
2734
from singer_sdk.helpers.types import Context
2835

36+
_T = TypeVar("_T", default=Any)
2937

30-
class ADPStream(RESTStream):
38+
39+
class ADPStream(RESTStream[_T], Generic[_T]):
3140
"""ADP stream class."""
3241

3342
records_jsonpath = "$[*]"
@@ -56,7 +65,7 @@ def authenticator(self) -> ADPAuthenticator:
5665
)
5766

5867
@override
59-
def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
68+
def parse_response(self, response: requests.Response) -> Iterable[dict[str, Any]]:
6069
"""Parse the response and return an iterator of result records.
6170
6271
Args:
@@ -90,11 +99,11 @@ def response_error_message(self, response: requests.Response) -> str:
9099
)
91100

92101

93-
class PaginatedADPStream(ADPStream):
102+
class PaginatedADPStream(ADPStream[int]):
94103
"""Paginated ADP stream class."""
95104

96105
@override
97-
def get_new_paginator(self) -> BaseAPIPaginator:
106+
def get_new_paginator(self) -> ADPPaginator:
98107
"""Create a new paginator for ADP API pagination."""
99108
return ADPPaginator(start_value=0, page_size=100)
100109

@@ -103,23 +112,23 @@ def get_url_params(
103112
self,
104113
context: Context | None,
105114
next_page_token: int | None,
106-
) -> dict[str, t.Any]:
115+
) -> dict[str, Any]:
107116
return {
108117
"$top": 100, # Set the desired page size
109118
"$skip": next_page_token or 0,
110119
}
111120

112121

113122
class ADPPaginator(BaseAPIPaginator[int]):
114-
"""Paginator for ADP API that uses 'top' and 'skip' parameters and stops on 204 response.""" # noqa: E501
123+
"""Paginator for ADP API that uses 'top' and 'skip' parameters and stops on 204 response."""
115124

116125
@override
117126
def __init__(
118127
self,
119128
start_value: int,
120129
page_size: int,
121-
*args: t.Any,
122-
**kwargs: t.Any,
130+
*args: Any,
131+
**kwargs: Any,
123132
) -> None:
124133
"""Initialize the paginator with a starting value and page size.
125134
@@ -153,5 +162,5 @@ def has_more(self, response: requests.Response) -> bool:
153162
154163
Returns:
155164
`True` if pagination should continue, `False` if a 204 No Content is received.
156-
""" # noqa: E501
165+
"""
157166
return response.status_code != HTTPStatus.NO_CONTENT

0 commit comments

Comments
 (0)