Skip to content

Commit e554391

Browse files
sungwy authored and gabeiglio committed
Replace Deprecated (Current) OAuth2 Handling with AuthManager Implementation LegacyOAuth2AuthManager (apache#1981)
<!-- Thanks for opening a pull request! -->
<!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. -->
<!-- Closes apache#1909 -->

# Rationale for this change

Replace existing Auth handling with `LegacyOAuth2AuthManager`.

Tracking issue: apache#1909

There will be follow-up PRs to this PR that will address the following:
- introduce a mechanism for using a custom `AuthManager` implementation, along with the ability to use a set of config parameters
- introduce an `OAuth2AuthManager` that more closely follows the OAuth2 protocol, and also uses a separate thread to proactively refresh the token, rather than reactively refreshing the token on `UnauthorizedError` or the deprecated `AuthorizationExpiredError`.

# Are these changes tested?

Yes, both through unit and integration tests

# Are there any user-facing changes?

Yes - previously, if `TOKEN` and `CREDENTIAL` were both defined, the `oauth/tokens` endpoint wouldn't be used to refresh the token with client credentials when the `RestCatalog` was initialized. However, the `oauth/tokens` endpoint would be used on retries that handled a 401 or 419 error. This erratic behavior will now be updated as follows:
- if `CREDENTIAL` is defined, the `oauth/tokens` endpoint will be used to fetch the access token using the client credentials both when the RestCatalog is initialized, and when the refresh_tokens call is made in reaction to a 401 or 419 error.
- if both `CREDENTIAL` and `TOKEN` are defined, we will follow the above behavior.
- if only `TOKEN` is defined, the initial token will be used instead

<!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 29f4ebe commit e554391

4 files changed

Lines changed: 275 additions & 144 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 49 additions & 140 deletions
Original file line number | Diff line number | Diff line change
@@ -15,21 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
from enum import Enum
18-
from json import JSONDecodeError
1918
from typing import (
2019
TYPE_CHECKING,
2120
Any,
2221
Dict,
2322
List,
24-
Literal,
2523
Optional,
2624
Set,
2725
Tuple,
28-
Type,
2926
Union,
3027
)
3128

32-
from pydantic import Field, ValidationError, field_validator
29+
from pydantic import Field, field_validator
3330
from requests import HTTPError, Session
3431
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
3532

@@ -41,22 +38,18 @@
4138
Catalog,
4239
PropertiesUpdateSummary,
4340
)
41+
from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
42+
from pyiceberg.catalog.rest.response import _handle_non_200_response
4443
from pyiceberg.exceptions import (
4544
AuthorizationExpiredError,
46-
BadRequestError,
4745
CommitFailedException,
4846
CommitStateUnknownException,
49-
ForbiddenError,
5047
NamespaceAlreadyExistsError,
5148
NamespaceNotEmptyError,
5249
NoSuchIdentifierError,
5350
NoSuchNamespaceError,
5451
NoSuchTableError,
5552
NoSuchViewError,
56-
OAuthError,
57-
RESTError,
58-
ServerError,
59-
ServiceUnavailableError,
6053
TableAlreadyExistsError,
6154
UnauthorizedError,
6255
)
@@ -182,15 +175,6 @@ class RegisterTableRequest(IcebergBaseModel):
182175
metadata_location: str = Field(..., alias="metadata-location")
183176

184177

185-
class TokenResponse(IcebergBaseModel):
186-
access_token: str = Field()
187-
token_type: str = Field()
188-
expires_in: Optional[int] = Field(default=None)
189-
issued_token_type: Optional[str] = Field(default=None)
190-
refresh_token: Optional[str] = Field(default=None)
191-
scope: Optional[str] = Field(default=None)
192-
193-
194178
class ConfigResponse(IcebergBaseModel):
195179
defaults: Properties = Field()
196180
overrides: Properties = Field()
@@ -229,24 +213,6 @@ class ListViewsResponse(IcebergBaseModel):
229213
identifiers: List[ListViewResponseEntry] = Field()
230214

231215

232-
class ErrorResponseMessage(IcebergBaseModel):
233-
message: str = Field()
234-
type: str = Field()
235-
code: int = Field()
236-
237-
238-
class ErrorResponse(IcebergBaseModel):
239-
error: ErrorResponseMessage = Field()
240-
241-
242-
class OAuthErrorResponse(IcebergBaseModel):
243-
error: Literal[
244-
"invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"
245-
]
246-
error_description: Optional[str] = None
247-
error_uri: Optional[str] = None
248-
249-
250216
class RestCatalog(Catalog):
251217
uri: str
252218
_session: Session
@@ -279,8 +245,7 @@ def _create_session(self) -> Session:
279245
elif ssl_client_cert := ssl_client.get(CERT):
280246
session.cert = ssl_client_cert
281247

282-
self._refresh_token(session, self.properties.get(TOKEN))
283-
248+
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
284249
# Set HTTP headers
285250
self._config_headers(session)
286251

@@ -298,6 +263,26 @@ def _create_session(self) -> Session:
298263

299264
return session
300265

266+
def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
267+
"""Create the LegacyOAuth2AuthManager by fetching required properties.
268+
269+
This will be removed in PyIceberg 1.0
270+
"""
271+
client_credentials = self.properties.get(CREDENTIAL)
272+
# We want to call `self.auth_url` only when we are using CREDENTIAL
273+
# with the legacy OAUTH2 flow as it will raise a DeprecationWarning
274+
auth_url = self.auth_url if client_credentials is not None else None
275+
276+
auth_config = {
277+
"session": session,
278+
"auth_url": auth_url,
279+
"credential": client_credentials,
280+
"initial_token": self.properties.get(TOKEN),
281+
"optional_oauth_params": self._extract_optional_oauth_params(),
282+
}
283+
284+
return AuthManagerFactory.create("legacyoauth2", auth_config)
285+
301286
def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
302287
"""Check if the identifier has at least one element."""
303288
identifier_tuple = Catalog.identifier_to_tuple(identifier)
@@ -360,27 +345,6 @@ def _extract_optional_oauth_params(self) -> Dict[str, str]:
360345

361346
return optional_oauth_param
362347

363-
def _fetch_access_token(self, session: Session, credential: str) -> str:
364-
if SEMICOLON in credential:
365-
client_id, client_secret = credential.split(SEMICOLON)
366-
else:
367-
client_id, client_secret = None, credential
368-
369-
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}
370-
371-
optional_oauth_params = self._extract_optional_oauth_params()
372-
data.update(optional_oauth_params)
373-
374-
response = session.post(
375-
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
376-
)
377-
try:
378-
response.raise_for_status()
379-
except HTTPError as exc:
380-
self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError})
381-
382-
return TokenResponse.model_validate_json(response.text).access_token
383-
384348
def _fetch_config(self) -> None:
385349
params = {}
386350
if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
@@ -391,7 +355,7 @@ def _fetch_config(self) -> None:
391355
try:
392356
response.raise_for_status()
393357
except HTTPError as exc:
394-
self._handle_non_200_response(exc, {})
358+
_handle_non_200_response(exc, {})
395359
config_response = ConfigResponse.model_validate_json(response.text)
396360

397361
config = config_response.defaults
@@ -421,58 +385,6 @@ def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict
421385
identifier_tuple = self._identifier_to_validated_tuple(identifier)
422386
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}
423387

424-
def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None:
425-
exception: Type[Exception]
426-
427-
if exc.response is None:
428-
raise ValueError("Did not receive a response")
429-
430-
code = exc.response.status_code
431-
if code in error_handler:
432-
exception = error_handler[code]
433-
elif code == 400:
434-
exception = BadRequestError
435-
elif code == 401:
436-
exception = UnauthorizedError
437-
elif code == 403:
438-
exception = ForbiddenError
439-
elif code == 422:
440-
exception = RESTError
441-
elif code == 419:
442-
exception = AuthorizationExpiredError
443-
elif code == 501:
444-
exception = NotImplementedError
445-
elif code == 503:
446-
exception = ServiceUnavailableError
447-
elif 500 <= code < 600:
448-
exception = ServerError
449-
else:
450-
exception = RESTError
451-
452-
try:
453-
if exception == OAuthError:
454-
# The OAuthErrorResponse has a different format
455-
error = OAuthErrorResponse.model_validate_json(exc.response.text)
456-
response = str(error.error)
457-
if description := error.error_description:
458-
response += f": {description}"
459-
if uri := error.error_uri:
460-
response += f" ({uri})"
461-
else:
462-
error = ErrorResponse.model_validate_json(exc.response.text).error
463-
response = f"{error.type}: {error.message}"
464-
except JSONDecodeError:
465-
# In the case we don't have a proper response
466-
response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}"
467-
except ValidationError as e:
468-
# In the case we don't have a proper response
469-
errs = ", ".join(err["msg"] for err in e.errors())
470-
response = (
471-
f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}"
472-
)
473-
474-
raise exception(response) from exc
475-
476388
def _init_sigv4(self, session: Session) -> None:
477389
from urllib import parse
478390

@@ -542,16 +454,13 @@ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_res
542454
catalog=self,
543455
)
544456

545-
def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None:
546-
session = session or self._session
547-
if initial_token is not None:
548-
self.properties[TOKEN] = initial_token
549-
elif CREDENTIAL in self.properties:
550-
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
551-
552-
# Set Auth token for subsequent calls in the session
553-
if token := self.properties.get(TOKEN):
554-
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
457+
def _refresh_token(self) -> None:
458+
# Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread
459+
# instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager
460+
# for backward compatibility
461+
auth_manager = self._session.auth.auth_manager # type: ignore[union-attr]
462+
if isinstance(auth_manager, LegacyOAuth2AuthManager):
463+
auth_manager._refresh_token()
555464

556465
def _config_headers(self, session: Session) -> None:
557466
header_properties = get_header_properties(self.properties)
@@ -596,7 +505,7 @@ def _create_table(
596505
try:
597506
response.raise_for_status()
598507
except HTTPError as exc:
599-
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
508+
_handle_non_200_response(exc, {409: TableAlreadyExistsError})
600509
return TableResponse.model_validate_json(response.text)
601510

602511
@retry(**_RETRY_ARGS)
@@ -669,7 +578,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
669578
try:
670579
response.raise_for_status()
671580
except HTTPError as exc:
672-
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
581+
_handle_non_200_response(exc, {409: TableAlreadyExistsError})
673582

674583
table_response = TableResponse.model_validate_json(response.text)
675584
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
@@ -682,7 +591,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
682591
try:
683592
response.raise_for_status()
684593
except HTTPError as exc:
685-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
594+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
686595
return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers]
687596

688597
@retry(**_RETRY_ARGS)
@@ -700,7 +609,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
700609
try:
701610
response.raise_for_status()
702611
except HTTPError as exc:
703-
self._handle_non_200_response(exc, {404: NoSuchTableError})
612+
_handle_non_200_response(exc, {404: NoSuchTableError})
704613

705614
table_response = TableResponse.model_validate_json(response.text)
706615
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
@@ -713,7 +622,7 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
713622
try:
714623
response.raise_for_status()
715624
except HTTPError as exc:
716-
self._handle_non_200_response(exc, {404: NoSuchTableError})
625+
_handle_non_200_response(exc, {404: NoSuchTableError})
717626

718627
@retry(**_RETRY_ARGS)
719628
def purge_table(self, identifier: Union[str, Identifier]) -> None:
@@ -729,7 +638,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
729638
try:
730639
response.raise_for_status()
731640
except HTTPError as exc:
732-
self._handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})
641+
_handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})
733642

734643
return self.load_table(to_identifier)
735644

@@ -752,7 +661,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]:
752661
try:
753662
response.raise_for_status()
754663
except HTTPError as exc:
755-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
664+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
756665
return [(*view.namespace, view.name) for view in ListViewsResponse.model_validate_json(response.text).identifiers]
757666

758667
@retry(**_RETRY_ARGS)
@@ -790,7 +699,7 @@ def commit_table(
790699
try:
791700
response.raise_for_status()
792701
except HTTPError as exc:
793-
self._handle_non_200_response(
702+
_handle_non_200_response(
794703
exc,
795704
{
796705
409: CommitFailedException,
@@ -809,7 +718,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
809718
try:
810719
response.raise_for_status()
811720
except HTTPError as exc:
812-
self._handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})
721+
_handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})
813722

814723
@retry(**_RETRY_ARGS)
815724
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
@@ -819,7 +728,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
819728
try:
820729
response.raise_for_status()
821730
except HTTPError as exc:
822-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})
731+
_handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})
823732

824733
@retry(**_RETRY_ARGS)
825734
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
@@ -834,7 +743,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
834743
try:
835744
response.raise_for_status()
836745
except HTTPError as exc:
837-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
746+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
838747

839748
return ListNamespaceResponse.model_validate_json(response.text).namespaces
840749

@@ -846,7 +755,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper
846755
try:
847756
response.raise_for_status()
848757
except HTTPError as exc:
849-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
758+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
850759

851760
return NamespaceResponse.model_validate_json(response.text).properties
852761

@@ -861,7 +770,7 @@ def update_namespace_properties(
861770
try:
862771
response.raise_for_status()
863772
except HTTPError as exc:
864-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
773+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
865774
parsed_response = UpdateNamespacePropertiesResponse.model_validate_json(response.text)
866775
return PropertiesUpdateSummary(
867776
removed=parsed_response.removed,
@@ -883,7 +792,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool:
883792
try:
884793
response.raise_for_status()
885794
except HTTPError as exc:
886-
self._handle_non_200_response(exc, {})
795+
_handle_non_200_response(exc, {})
887796

888797
return False
889798

@@ -909,7 +818,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
909818
try:
910819
response.raise_for_status()
911820
except HTTPError as exc:
912-
self._handle_non_200_response(exc, {})
821+
_handle_non_200_response(exc, {})
913822

914823
return False
915824

@@ -934,7 +843,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool:
934843
try:
935844
response.raise_for_status()
936845
except HTTPError as exc:
937-
self._handle_non_200_response(exc, {})
846+
_handle_non_200_response(exc, {})
938847

939848
return False
940849

@@ -946,4 +855,4 @@ def drop_view(self, identifier: Union[str]) -> None:
946855
try:
947856
response.raise_for_status()
948857
except HTTPError as exc:
949-
self._handle_non_200_response(exc, {404: NoSuchViewError})
858+
_handle_non_200_response(exc, {404: NoSuchViewError})

0 commit comments

Comments
 (0)