-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathclientcredentials.py
More file actions
150 lines (109 loc) · 5.61 KB
/
clientcredentials.py
File metadata and controls
150 lines (109 loc) · 5.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
import hashlib
import requests
import time
from .types import SDKInitHook, BeforeRequestContext, BeforeRequestHook, AfterErrorContext, AfterErrorHook
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from urllib.parse import urlparse, urljoin
class Credentials:
    """OAuth2 client-credentials settings used to request an access token.

    Attributes:
        client_id: the OAuth2 client identifier.
        client_secret: the secret paired with ``client_id``.
        token_url: token endpoint; may be relative to the SDK base URL.
    """

    client_id: str
    client_secret: str
    token_url: str

    def __init__(self, client_id: str, client_secret: str, token_url: str):
        # Targets are assigned left to right, matching the original order.
        self.client_id, self.client_secret, self.token_url = (
            client_id,
            client_secret,
            token_url,
        )
class Session:
    """A cached OAuth2 access token together with how it was obtained.

    Attributes:
        credentials: the client credentials the token was issued for.
        token: the access token, sent as a Bearer credential.
        scopes: the scopes requested when the token was issued.
        expires_at: expiry as a Unix timestamp in seconds, or None if unknown.
    """

    credentials: Credentials
    token: str
    scopes: List[str]
    expires_at: Optional[int] = None

    def __init__(self, credentials: Credentials, token: str, scopes: List[str], expires_at: Optional[int] = None):
        # Targets are assigned left to right, matching the original order.
        self.credentials, self.token, self.scopes, self.expires_at = (
            credentials,
            token,
            scopes,
            expires_at,
        )
class ClientCredentialsHook(SDKInitHook, BeforeRequestHook, AfterErrorHook):
    """Authenticate requests using the OAuth2 client-credentials grant.

    ``sdk_init`` records the SDK base URL and HTTP client. ``before_request``
    obtains an access token (or reuses a cached one) for requests that
    declare OAuth2 scopes and sets the ``Authorization: Bearer`` header.
    ``after_error`` drops the cached token on a 401 response, so the next
    request fetches a fresh one.
    """

    base_url: str
    client: requests.Session
    # Token cache keyed by a hash of the client credentials. It is created per
    # instance in __init__: a class-level ``{}`` default would be one dict
    # shared by every hook instance (and therefore every SDK instance).
    sessions: Dict[str, Session]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pass-through keeps construction compatible with the base classes.
        super().__init__(*args, **kwargs)
        self.sessions = {}

    def sdk_init(self, base_url: str, client: requests.Session) -> Tuple[str, requests.Session]:
        """Remember the base URL and HTTP client; return both unchanged."""
        self.base_url = base_url
        self.client = client
        return base_url, client

    def before_request(self, hook_ctx: BeforeRequestContext, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach a Bearer token to ``request`` when the operation uses OAuth2.

        A new token is requested if nothing is cached for these credentials,
        the cached token lacks a required scope, or it is (about to be)
        expired.
        """
        if hook_ctx.oauth2_scopes is None:
            # OAuth2 not in use
            return request

        credentials = self.get_credentials(hook_ctx.security_source)
        if credentials is None:
            return request

        session_key = self.get_session_key(credentials.client_id, credentials.client_secret)
        sess = self.sessions.get(session_key)
        if (
            sess is None
            or not self.has_required_scopes(sess.scopes, hook_ctx.oauth2_scopes)
            or self.has_token_expired(sess.expires_at)
        ):
            # Request the union of the required scopes and any scopes already
            # held, so one cached token can keep serving earlier operations.
            sess = self.do_token_request(credentials, self.get_scopes(hook_ctx.oauth2_scopes, sess))
            self.sessions[session_key] = sess

        request.headers["Authorization"] = f"Bearer {sess.token}"
        return request

    def after_error(self, hook_ctx: AfterErrorContext, response: Optional[requests.Response], error: Optional[Exception]) -> Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]:
        """Invalidate the cached token on a 401; always return the inputs unchanged."""
        if hook_ctx.oauth2_scopes is None:
            # OAuth2 not in use
            return (response, error)

        # We don't want to refresh the token if the error is not related to the token
        if error is not None:
            return (response, error)

        credentials = self.get_credentials(hook_ctx.security_source)
        if credentials is None:
            return (response, error)

        if response is not None and response.status_code == 401:
            session_key = self.get_session_key(credentials.client_id, credentials.client_secret)
            self.sessions.pop(session_key, None)

        return (response, error)

    def get_credentials(self, source: Optional[Union[Any, Callable[[], Any]]]) -> Optional[Credentials]:
        """Resolve client credentials from a security object or a zero-arg factory.

        Returns None when there is no source, it yields None, or it has no
        ``client_credentials``.
        """
        if source is None:
            return None

        security = source() if callable(source) else source
        if security is None or security.client_credentials is None:
            return None

        return Credentials(
            client_id=security.client_credentials.client_id,
            client_secret=security.client_credentials.client_secret,
            token_url=security.client_credentials.token_url,
        )

    def do_token_request(self, credentials: Credentials, scopes: Optional[List[str]]) -> Session:
        """Request a new access token from the token endpoint.

        Args:
            credentials: client id/secret and token URL to use.
            scopes: scopes to request; omitted from the payload when empty.

        Returns:
            A new Session holding the access token, the requested scopes and
            the absolute expiry (Unix seconds) if the server reported one.

        Raises:
            Exception: on a non-2xx status, a token type other than Bearer,
                or a response with no access token.
        """
        payload = {
            "grant_type": "client_credentials",
            "client_id": credentials.client_id,
            "client_secret": credentials.client_secret,
        }
        if scopes:
            payload["scope"] = " ".join(scopes)

        token_url = credentials.token_url
        # A relative token URL is resolved against the SDK base URL.
        if not urlparse(token_url).netloc:
            token_url = urljoin(self.base_url, token_url)

        # NOTE(review): RFC 6749 section 4.4.2 specifies an
        # application/x-www-form-urlencoded body (``data=payload``). JSON is
        # kept because the target API may rely on it; confirm before changing.
        response = self.client.post(token_url, json=payload)
        if not 200 <= response.status_code < 300:
            raise Exception(
                f"Unexpected status code {response.status_code} from token endpoint")

        response_data = response.json()
        # token_type is case-insensitive (RFC 6749 section 5.1); many servers send "bearer".
        if str(response_data.get("token_type") or "").lower() != "bearer":
            raise Exception("Unexpected token type from token endpoint")

        access_token = response_data.get("access_token")
        if not access_token:
            # Without this check the header would become "Bearer None".
            raise Exception("Missing access token from token endpoint")

        expires_at = None
        expires_in = response_data.get("expires_in")
        if expires_in is not None:
            # int() also accepts numeric strings, which some servers return.
            expires_at = int(time.time()) + int(expires_in)

        return Session(
            credentials=credentials,
            token=access_token,
            scopes=list(scopes or []),
            expires_at=expires_at,
        )

    def get_session_key(self, client_id: str, client_secret: str) -> str:
        """Return an opaque cache key derived from the client credentials.

        The secret is hashed rather than stored as a plain dict key. SHA-256
        is used because MD5 is rejected by some FIPS-enabled Python builds.
        """
        return hashlib.sha256(f"{client_id}:{client_secret}".encode()).hexdigest()

    def has_required_scopes(self, scopes: List[str], required_scopes: List[str]) -> bool:
        """Return True if every scope in ``required_scopes`` is in ``scopes``."""
        return all(scope in scopes for scope in required_scopes)

    def get_scopes(self, required_scopes: List[str], sess: Optional[Session]) -> List[str]:
        """Merge the required scopes with those of an existing session.

        Required scopes come first and duplicates are removed. Order is
        preserved, so the requested scope string is deterministic.
        """
        held = sess.scopes if sess is not None and sess.scopes is not None else []
        return list(dict.fromkeys([*required_scopes, *held]))

    def has_token_expired(self, expires_at: Optional[int]) -> bool:
        """Return True if the token has expired or will within 60 seconds.

        An unknown expiry (None) counts as expired, so tokens issued without
        ``expires_in`` are re-requested before every call.
        """
        return expires_at is None or time.time() + 60 >= expires_at