-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathoauth.py
More file actions
299 lines (261 loc) · 14 KB
/
oauth.py
File metadata and controls
299 lines (261 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from dataclasses import InitVar, dataclass, field
from datetime import datetime, timedelta
from typing import Any, List, Mapping, Optional, Union
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import (
AbstractOauth2Authenticator,
)
from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import (
SingleUseRefreshTokenOauth2Authenticator,
)
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
# Logger named "airbyte" (not this module's __name__); get_client_secret() uses it to
# warn instead of failing when the client secret evaluates to an empty value.
logger: logging.Logger = logging.getLogger(name="airbyte")
@dataclass
class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAuthenticator):
    """
    Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on
    a declarative connector configuration file. Credentials can be defined explicitly or via interpolation
    at runtime. The generated access token is attached to each request via the Authorization header.

    String attributes may be interpolation templates: __post_init__ wraps them in InterpolatedString
    and the get_*() accessors evaluate them against `config` on each call.

    Attributes:
        config (Mapping[str, Any]): The user-provided configuration as specified by the source's spec
        client_id (Optional[Union[InterpolatedString, str]]): The client id
        client_secret (Optional[Union[InterpolatedString, str]]): The client secret
        token_refresh_endpoint (Optional[Union[InterpolatedString, str]]): The endpoint to refresh the access token
        refresh_token (Optional[Union[InterpolatedString, str]]): The token used to refresh the access token
        scopes (Optional[List[str]]): The scopes to request
        token_expiry_date (Optional[Union[InterpolatedString, str]]): The access token expiration date. A purely
            numeric value is parsed directly; anything else is interpolated first. Defaults to one day ago.
        token_expiry_date_format (Optional[str]): Format of the datetime; provide it if expires_in is returned
            as a datetime instead of a number of seconds
        token_expiry_is_time_of_expiration (bool): Set to True if expires_in is returned as the time of
            expiration instead of the number of seconds until expiration
        access_token_name (Union[InterpolatedString, str]): The field to extract the access token from in the response
        access_token_value (Optional[Union[InterpolatedString, str]]): An initial access token, evaluated once
            against `config`; an empty result is treated as "no access token yet"
        client_id_name, client_secret_name, refresh_token_name, grant_type_name (Union[InterpolatedString, str]):
            Key names returned by the corresponding get_*_name() accessors
        expires_in_name (Union[InterpolatedString, str]): The field to extract expires_in from in the response
        refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request
        refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
        grant_type (Union[InterpolatedString, str]): The grant_type to request for access_token. If set to
            refresh_token, the refresh_token parameter has to be provided. It is replaced by the JWT bearer grant
            type when the profile assertion flow is enabled.
        message_repository (MessageRepository): The message repository used to emit logs on HTTP requests
        profile_assertion (Optional[DeclarativeAuthenticator]): Authenticator whose `token` is sent as the
            assertion in the profile assertion flow
        use_profile_assertion (Optional[Union[InterpolatedBoolean, str, bool]]): Enables the profile assertion
            flow; string values are interpolated against `config` before being interpreted
    """

    config: Mapping[str, Any]
    parameters: InitVar[Mapping[str, Any]]
    client_id: Optional[Union[InterpolatedString, str]] = None
    client_secret: Optional[Union[InterpolatedString, str]] = None
    token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
    refresh_token: Optional[Union[InterpolatedString, str]] = None
    scopes: Optional[List[str]] = None
    token_expiry_date: Optional[Union[InterpolatedString, str]] = None
    # Parsed form of token_expiry_date; populated in __post_init__ and updated via set_token_expiry_date().
    _token_expiry_date: Optional[AirbyteDateTime] = field(init=False, repr=False, default=None)
    token_expiry_date_format: Optional[str] = None
    token_expiry_is_time_of_expiration: bool = False
    access_token_name: Union[InterpolatedString, str] = "access_token"
    access_token_value: Optional[Union[InterpolatedString, str]] = None
    client_id_name: Union[InterpolatedString, str] = "client_id"
    client_secret_name: Union[InterpolatedString, str] = "client_secret"
    expires_in_name: Union[InterpolatedString, str] = "expires_in"
    refresh_token_name: Union[InterpolatedString, str] = "refresh_token"
    refresh_request_body: Optional[Mapping[str, Any]] = None
    refresh_request_headers: Optional[Mapping[str, Any]] = None
    grant_type_name: Union[InterpolatedString, str] = "grant_type"
    grant_type: Union[InterpolatedString, str] = "refresh_token"
    message_repository: MessageRepository = NoopMessageRepository()
    profile_assertion: Optional[DeclarativeAuthenticator] = None
    use_profile_assertion: Optional[Union[InterpolatedBoolean, str, bool]] = False

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """
        Wrap the configurable values in interpolation objects, resolve the initial token state,
        and validate that the configuration is consistent with the selected OAuth flow.

        Raises:
            ValueError: if token_expiry_date cannot be parsed, or if required credentials are
                missing for the selected flow.
        """
        super().__init__()
        if self.token_refresh_endpoint is not None:
            self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
                self.token_refresh_endpoint, parameters=parameters
            )
        else:
            self._token_refresh_endpoint = None
        self._client_id_name = InterpolatedString.create(self.client_id_name, parameters=parameters)
        self._client_id = (
            InterpolatedString.create(self.client_id, parameters=parameters)
            if self.client_id
            else self.client_id
        )
        self._client_secret_name = InterpolatedString.create(
            self.client_secret_name, parameters=parameters
        )
        self._client_secret = (
            InterpolatedString.create(self.client_secret, parameters=parameters)
            if self.client_secret
            else self.client_secret
        )
        self._refresh_token_name = InterpolatedString.create(
            self.refresh_token_name, parameters=parameters
        )
        if self.refresh_token is not None:
            self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create(
                self.refresh_token, parameters=parameters
            )
        else:
            self._refresh_token = None
        # These public fields are replaced in place by their interpolated forms;
        # the get_*() accessors evaluate them against config.
        self.access_token_name = InterpolatedString.create(
            self.access_token_name, parameters=parameters
        )
        self.expires_in_name = InterpolatedString.create(
            self.expires_in_name, parameters=parameters
        )
        self.grant_type_name = InterpolatedString.create(
            self.grant_type_name, parameters=parameters
        )
        # Convert the profile-assertion flag BEFORE it is consulted. A string value must be
        # evaluated against config; testing the raw string (or the InterpolatedBoolean object)
        # for truthiness would enable the flow for any non-empty template, e.g. "false".
        self.use_profile_assertion = (
            InterpolatedBoolean(self.use_profile_assertion, parameters=parameters)
            if isinstance(self.use_profile_assertion, str)
            else self.use_profile_assertion
        )
        self.grant_type = InterpolatedString.create(
            "urn:ietf:params:oauth:grant-type:jwt-bearer"
            if self._is_profile_assertion_enabled()
            else self.grant_type,
            parameters=parameters,
        )
        self._refresh_request_body = InterpolatedMapping(
            self.refresh_request_body or {}, parameters=parameters
        )
        self._refresh_request_headers = InterpolatedMapping(
            self.refresh_request_headers or {}, parameters=parameters
        )
        try:
            if (
                isinstance(self.token_expiry_date, (int, str))
                and str(self.token_expiry_date).isdigit()
            ):
                # Purely numeric values are handed to the parser as-is, without interpolation.
                self._token_expiry_date = ab_datetime_parse(self.token_expiry_date)
            else:
                # Without a configured expiry, start one day in the past.
                self._token_expiry_date = (
                    ab_datetime_parse(
                        InterpolatedString.create(
                            self.token_expiry_date, parameters=parameters
                        ).eval(self.config)
                    )
                    if self.token_expiry_date
                    else ab_datetime_now() - timedelta(days=1)
                )
        except ValueError as e:
            raise ValueError(f"Invalid token expiry date format: {e}") from e
        self.assertion_name = "assertion"
        if self.access_token_value is not None:
            self._access_token_value = InterpolatedString.create(
                self.access_token_value, parameters=parameters
            ).eval(self.config)
        else:
            self._access_token_value = None
        # An access token that evaluates to an empty value counts as "not initialized", so
        # get_token_expiry_date() returns datetime.min for it instead of the configured expiry.
        self._access_token: Optional[str] = self._access_token_value or None

        profile_assertion_enabled = self._is_profile_assertion_enabled()
        if not profile_assertion_enabled and any(
            client_creds is None for client_creds in [self.client_id, self.client_secret]
        ):
            raise ValueError(
                "OAuthAuthenticator configuration error: Both 'client_id' and 'client_secret' are required for the "
                "basic OAuth flow."
            )
        if self.profile_assertion is None and profile_assertion_enabled:
            raise ValueError(
                "OAuthAuthenticator configuration error: 'profile_assertion' is required when using the profile assertion flow."
            )
        if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
            raise ValueError(
                "OAuthAuthenticator configuration error: A 'refresh_token' is required when the 'grant_type' is set to 'refresh_token'."
            )

    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Return the evaluated refresh endpoint, or None when none was configured.

        Raises:
            ValueError: if an endpoint was configured but evaluates to an empty value.
        """
        if self._token_refresh_endpoint is not None:
            refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config)
            if not refresh_token_endpoint:
                raise ValueError(
                    "OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
                )
            return refresh_token_endpoint
        return None

    def get_client_id_name(self) -> str:
        """Return the evaluated key name for the client id."""
        return self._client_id_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_client_id(self) -> str:
        """Return the evaluated client id.

        Raises:
            ValueError: if the client id is missing or evaluates to an empty value.
        """
        client_id = self._client_id.eval(self.config) if self._client_id else self._client_id
        if not client_id:
            raise ValueError("OAuthAuthenticator was unable to evaluate client_id parameter")
        return client_id  # type: ignore # value will be returned as a string, or an error will be raised

    def get_client_secret_name(self) -> str:
        """Return the evaluated key name for the client secret."""
        return self._client_secret_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_client_secret(self) -> str:
        """Return the evaluated client secret; an empty value is logged as a warning, not raised."""
        client_secret = (
            self._client_secret.eval(self.config) if self._client_secret else self._client_secret
        )
        if not client_secret:
            # We've seen some APIs allowing empty client_secret so we will only log here
            logger.warning(
                "OAuthAuthenticator was unable to evaluate client_secret parameter hence it'll be empty"
            )
        return client_secret  # type: ignore # value will be returned as a string, or an error will be raised

    def get_refresh_token_name(self) -> str:
        """Return the evaluated key name for the refresh token."""
        return self._refresh_token_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_refresh_token(self) -> Optional[str]:
        """Return the evaluated refresh token, or None if none is configured or it resolves to None."""
        if self._refresh_token is None:
            return None
        refresh_token = self._refresh_token.eval(self.config)
        # Don't let an unresolved value become the literal string "None".
        return None if refresh_token is None else str(refresh_token)

    def get_scopes(self) -> List[str]:
        """Return the configured scopes, or an empty list."""
        return self.scopes or []

    def get_access_token_name(self) -> str:
        """Return the evaluated name of the access-token field in the refresh response."""
        return self.access_token_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_expires_in_name(self) -> str:
        """Return the evaluated name of the expires_in field in the refresh response."""
        return self.expires_in_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_grant_type_name(self) -> str:
        """Return the evaluated key name for the grant type."""
        return self.grant_type_name.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_grant_type(self) -> str:
        """Return the evaluated grant type."""
        return self.grant_type.eval(self.config)  # type: ignore # eval returns a string in this context

    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Return the configured extra refresh request body, evaluated against config."""
        return self._refresh_request_body.eval(self.config)

    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Return the configured refresh request headers, evaluated against config."""
        return self._refresh_request_headers.eval(self.config)

    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Return the token expiry date, or datetime.min while no access token is held."""
        if not self._has_access_token_been_initialized():
            return AirbyteDateTime.from_datetime(datetime.min)
        return self._token_expiry_date  # type: ignore # _token_expiry_date is an AirbyteDateTime. It is never None despite what mypy thinks

    def _has_access_token_been_initialized(self) -> bool:
        """Return True once an access token has been set."""
        return self._access_token is not None

    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Store a new token expiry date."""
        self._token_expiry_date = value

    def get_assertion_name(self) -> str:
        """Return the body key used for the profile assertion."""
        return self.assertion_name

    def get_assertion(self) -> str:
        """Return the token of the profile_assertion authenticator.

        Raises:
            ValueError: if profile_assertion is not set.
        """
        if self.profile_assertion is None:
            raise ValueError("profile_assertion is not set")
        return self.profile_assertion.token

    def _is_profile_assertion_enabled(self) -> bool:
        """
        Return whether the profile assertion (JWT bearer) flow is enabled.

        InterpolatedBoolean flags are evaluated against config; other values (bool or None)
        are interpreted by their truthiness.
        """
        flag = self.use_profile_assertion
        if isinstance(flag, InterpolatedBoolean):
            return bool(flag.eval(self.config))
        return bool(flag)

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request
        Override to define additional parameters

        In the profile assertion flow the body contains only the grant type and the assertion;
        otherwise the parent implementation is used.
        """
        if self._is_profile_assertion_enabled():
            return {
                self.get_grant_type_name(): self.get_grant_type(),
                self.get_assertion_name(): self.get_assertion(),
            }
        return super().build_refresh_request_body()

    @property
    def access_token(self) -> str:
        """The current access token.

        Raises:
            ValueError: if no access token has been set yet.
        """
        if self._access_token is None:
            raise ValueError("access_token is not set")
        return self._access_token

    @access_token.setter
    def access_token(self, value: str) -> None:
        self._access_token = value

    @property
    def _message_repository(self) -> MessageRepository:
        """
        Overriding AbstractOauth2Authenticator._message_repository to allow for HTTP request logs
        """
        return self.message_repository
@dataclass
class DeclarativeSingleUseRefreshTokenOauth2Authenticator(
    SingleUseRefreshTokenOauth2Authenticator, DeclarativeAuthenticator
):
    """
    SingleUseRefreshTokenOauth2Authenticator exposed for declarative connectors.

    Adds no fields or behaviour of its own: every constructor argument is forwarded unchanged
    to the next class in the MRO, and the class additionally derives from DeclarativeAuthenticator.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # An explicitly defined __init__ is kept by @dataclass (it does not generate one over it),
        # so this pass-through is what actually runs on construction.
        parent = super(DeclarativeSingleUseRefreshTokenOauth2Authenticator, self)
        parent.__init__(*args, **kwargs)