Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from dataclasses import InitVar, dataclass, field
from datetime import datetime, timedelta
from typing import Any, List, Mapping, MutableMapping, Optional, Union
from typing import Any, List, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
Expand All @@ -19,6 +20,8 @@
)
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

logger = logging.getLogger("airbyte")


@dataclass
class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAuthenticator):
Expand All @@ -30,7 +33,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
Attributes:
token_refresh_endpoint (Union[InterpolatedString, str]): The endpoint to refresh the access token
client_id (Union[InterpolatedString, str]): The client id
client_secret (Union[InterpolatedString, str]): Client secret
client_secret (Union[InterpolatedString, str]): Client secret (can be empty for APIs that support this)
refresh_token (Union[InterpolatedString, str]): The token used to refresh the access token
access_token_name (Union[InterpolatedString, str]): THe field to extract access token from in the response
expires_in_name (Union[InterpolatedString, str]): The field to extract expires_in from in the response
Expand Down Expand Up @@ -201,8 +204,11 @@ def get_client_secret(self) -> str:
self._client_secret.eval(self.config) if self._client_secret else self._client_secret
)
if not client_secret:
raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter")
return client_secret # type: ignore # value will be returned as a string, or an error will be raised
# We've seen some APIs allowing empty client_secret so we will only log here
logger.warning(
"OAuthAuthenticator was unable to evaluate client_secret parameter hence it'll be empty"
)
return client_secret # type: ignore # value will be returned as a string, which might be empty

def get_refresh_token_name(self) -> str:
return self._refresh_token_name.eval(self.config) # type: ignore # eval returns a string in this context
Expand Down
15 changes: 15 additions & 0 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import base64
import json
import logging
from copy import deepcopy
from datetime import timedelta, timezone
from unittest.mock import Mock

Expand Down Expand Up @@ -128,6 +129,20 @@ def test_refresh_with_encode_config_params(self):
}
assert body == expected

def test_client_secret_empty(self):
config_without_client_secret = deepcopy(config)
del config_without_client_secret["client_secret"]
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
config=config_without_client_secret,
parameters={},
grant_type="client_credentials",
)
body = oauth.build_refresh_request_body()
assert body["client_secret"] == ""

def test_refresh_with_decode_config_params(self):
updated_config_fields = {
"client_id": base64.b64encode(config["client_id"].encode("utf-8")).decode(),
Expand Down
Loading