33from __future__ import annotations
44
55import os
6+ import ssl
67import sys
78import tempfile
89from typing import Any
910
1011import requests
12+ from requests .adapters import HTTPAdapter
1113from singer_sdk .authenticators import OAuthAuthenticator
1214from singer_sdk .helpers ._util import utc_now
1315
2022AUTH_ENDPOINT = "https://accounts.adp.com/auth/oauth/v2/token"
2123
2224
25+ class _MTLSAdapter (HTTPAdapter ):
26+ """Requests adapter that injects a pre-built SSL context for mTLS.
27+
28+ Works around SSL context caching in requests >=2.32.5 (psf/requests#6767)
29+ that ignores the ``cert=`` parameter when a cached context already exists
30+ for the target host, causing mTLS authentication to silently drop the
31+ client certificate.
32+ """
33+
34+ def __init__ (self , ssl_context : ssl .SSLContext , ** kwargs : Any ) -> None :
35+ self ._ssl_context = ssl_context
36+ super ().__init__ (** kwargs )
37+
38+ def init_poolmanager (self , * args : Any , ** kwargs : Any ) -> None :
39+ kwargs ["ssl_context" ] = self ._ssl_context
40+ super ().init_poolmanager (* args , ** kwargs ) # type: ignore[no-untyped-call]
41+
42+
2343class ADPAuthenticator (OAuthAuthenticator ):
2444 """Authenticator class for ADP."""
2545
@@ -46,57 +66,64 @@ def __init__(
4666
4767 @override
4868 @property
49- def oauth_request_body (self ) -> dict :
69+ def oauth_request_body (self ) -> dict [ str , Any ] :
5070 """Define the OAuth request body for ADP."""
5171 return {
5272 "grant_type" : "client_credentials" ,
5373 "client_id" : self .client_id ,
5474 "client_secret" : self .client_secret ,
5575 }
5676
77+ def _build_ssl_context (self ) -> ssl .SSLContext :
78+ """Build an SSL context with the client certificate pre-loaded.
79+
80+ Writes the PEM strings to temporary files, loads them into an
81+ ``ssl.SSLContext``, then deletes the files before returning. The
82+ context holds the cert in memory so the files are not needed at
83+ request time.
84+ """
85+ with (
86+ tempfile .NamedTemporaryFile (mode = "wb" , delete = False , suffix = ".pem" ) as cert_file ,
87+ tempfile .NamedTemporaryFile (mode = "wb" , delete = False , suffix = ".pem" ) as key_file ,
88+ ):
89+ cert_path = cert_file .name
90+ key_path = key_file .name
91+ cert_file .write (self .cert_public .encode ("utf-8" ))
92+ key_file .write (self .cert_private .encode ("utf-8" ))
93+
94+ try :
95+ os .chmod (cert_path , 0o600 ) # noqa: PTH101
96+ os .chmod (key_path , 0o600 ) # noqa: PTH101
97+ ctx = ssl .create_default_context ()
98+ ctx .load_cert_chain (certfile = cert_path , keyfile = key_path )
99+ finally :
100+ os .unlink (cert_path ) # noqa: PTH108
101+ os .unlink (key_path ) # noqa: PTH108
102+
103+ return ctx
104+
57105 @override
58106 def update_access_token (self ) -> None :
59107 """Update `access_token` along with `last_refreshed` and `expires_in`."""
60108 request_time = utc_now ()
61109
62- # Create temporary files for the cert and key
63- with (
64- tempfile .NamedTemporaryFile (mode = "wb+" , delete = False ) as cert_file ,
65- tempfile .NamedTemporaryFile (mode = "wb+" , delete = False ) as key_file ,
66- ):
67- # Write contents to the temporary files
68- cert_file .write (self .cert_public .encode ("utf-8" ))
69- cert_file .flush ()
110+ session = requests .Session ()
111+ session .mount ("https://" , _MTLSAdapter (ssl_context = self ._build_ssl_context ()))
70112
71- key_file .write (self .cert_private .encode ("utf-8" ))
72- key_file .flush ()
73-
74- # Ensure the files are readable only by the owner (optional)
75- os .chmod (cert_file .name , 0o600 ) # noqa: PTH101
76- os .chmod (key_file .name , 0o600 ) # noqa: PTH101
77-
78- # Make the OAuth request
79- try :
80- response = requests .post (
81- self .auth_endpoint ,
82- data = self .oauth_request_body ,
83- headers = self ._oauth_headers ,
84- timeout = 60 ,
85- cert = (cert_file .name , key_file .name ),
86- )
87- response .raise_for_status ()
88- except requests .HTTPError :
89- self .logger .warning (
90- "Failed OAuth login, response was '%s'" ,
91- response .text ,
92- )
93- raise
94- finally :
95- # Clean up the temporary files
96- cert_file .close ()
97- key_file .close ()
98- os .unlink (cert_file .name ) # noqa: PTH108
99- os .unlink (key_file .name ) # noqa: PTH108
113+ try :
114+ response = session .post (
115+ self .auth_endpoint ,
116+ data = self .oauth_request_body ,
117+ headers = self ._oauth_headers ,
118+ timeout = 60 ,
119+ )
120+ response .raise_for_status ()
121+ except requests .HTTPError :
122+ self .logger .warning (
123+ "Failed OAuth login, response was '%s'" ,
124+ response .text ,
125+ )
126+ raise
100127
101128 self .logger .info ("OAuth authorization attempt was successful." )
102129
0 commit comments