Skip to content

Commit 5699bdd

Browse files
refactor: move token refresh logic to refresh_and_set_access_token method in authenticator
Co-Authored-By: Daryna Ishchenko <darina.ishchenko17@gmail.com>
1 parent e6c03a3 commit 5699bdd

4 files changed

Lines changed: 36 additions & 47 deletions

File tree

airbyte_cdk/sources/streams/http/http_client.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@
5050
rate_limit_default_backoff_handler,
5151
user_defined_backoff_handler,
5252
)
53-
from airbyte_cdk.sources.streams.http.requests_native_auth import (
54-
SingleUseRefreshTokenOauth2Authenticator,
55-
)
5653
from airbyte_cdk.sources.utils.types import JsonType
5754
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
5855
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
@@ -461,31 +458,16 @@ def _handle_error_resolution(
461458

462459
# Handle REFRESH_TOKEN_THEN_RETRY: Force refresh the OAuth token before retry
463460
# This is useful when the API returns 401 but the stored token expiry hasn't been reached yet
464-
# Only OAuth authenticators have refresh_access_token method
461+
# Only OAuth authenticators have refresh_and_set_access_token method
465462
# Non-OAuth auth types (e.g., BearerAuthenticator) will fall through to normal retry
466463
if error_resolution.response_action == ResponseAction.REFRESH_TOKEN_THEN_RETRY:
467464
if (
468465
hasattr(self._session, "auth")
469466
and self._session.auth is not None
470-
and hasattr(self._session.auth, "refresh_access_token")
467+
and hasattr(self._session.auth, "refresh_and_set_access_token")
471468
):
472469
try:
473-
if isinstance(self._session.auth, SingleUseRefreshTokenOauth2Authenticator):
474-
# For single-use refresh tokens, we must persist the new refresh token
475-
# and emit a control message to update the connector config
476-
token, expires_in, new_refresh_token = (
477-
self._session.auth.refresh_access_token()
478-
)
479-
self._session.auth.access_token = token
480-
self._session.auth.set_refresh_token(new_refresh_token)
481-
self._session.auth.set_token_expiry_date(expires_in)
482-
self._session.auth._emit_control_message()
483-
else:
484-
# Use extended unpacking to handle both 2-tuple (AbstractOauth2Authenticator)
485-
# and 3-tuple (Oauth2Authenticator which also returns refresh_token) returns
486-
token, expires_in, *_ = self._session.auth.refresh_access_token() # type: ignore[union-attr]
487-
self._session.auth.access_token = token # type: ignore[union-attr]
488-
self._session.auth.set_token_expiry_date(expires_in) # type: ignore[union-attr]
470+
self._session.auth.refresh_and_set_access_token() # type: ignore[union-attr]
489471
self._logger.info(
490472
"Refreshed OAuth token due to REFRESH_TOKEN_THEN_RETRY response action"
491473
)

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,21 @@ def get_auth_header(self) -> Mapping[str, Any]:
8888
def get_access_token(self) -> str:
8989
"""Returns the access token"""
9090
if self.token_has_expired():
91-
token, expires_in = self.refresh_access_token()
92-
self.access_token = token
93-
self.set_token_expiry_date(expires_in)
91+
self.refresh_and_set_access_token()
9492

9593
return self.access_token
9694

95+
def refresh_and_set_access_token(self) -> None:
96+
"""Force refresh the access token and update internal state.
97+
98+
This method refreshes the access token regardless of whether it has expired,
99+
and updates the internal token and expiry date. Subclasses may override this
100+
to handle additional state updates (e.g., persisting new refresh tokens).
101+
"""
102+
token, expires_in = self.refresh_access_token()
103+
self.access_token = token
104+
self.set_token_expiry_date(expires_in)
105+
97106
def token_has_expired(self) -> bool:
98107
"""Returns True if the token is expired"""
99108
return ab_datetime_now() > self.get_token_expiry_date()

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -318,20 +318,28 @@ def token_has_expired(self) -> bool:
318318

319319
def get_access_token(self) -> str:
320320
"""Retrieve new access and refresh token if the access token has expired.
321-
The new refresh token is persisted with the set_refresh_token function
321+
322+
The new refresh token is persisted with the set_refresh_token function.
323+
322324
Returns:
323325
str: The current access_token, updated if it was previously expired.
324326
"""
325327
if self.token_has_expired():
326-
new_access_token, access_token_expires_in, new_refresh_token = (
327-
self.refresh_access_token()
328-
)
329-
self.access_token = new_access_token
330-
self.set_refresh_token(new_refresh_token)
331-
self.set_token_expiry_date(access_token_expires_in)
332-
self._emit_control_message()
328+
self.refresh_and_set_access_token()
333329
return self.access_token
334330

331+
def refresh_and_set_access_token(self) -> None:
332+
"""Force refresh the access token and update internal state.
333+
334+
For single-use refresh tokens, this also persists the new refresh token
335+
and emits a control message to update the connector config.
336+
"""
337+
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
338+
self.access_token = new_access_token
339+
self.set_refresh_token(new_refresh_token)
340+
self.set_token_expiry_date(access_token_expires_in)
341+
self._emit_control_message()
342+
335343
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
336344
"""
337345
Refreshes the access token by making a handled request and extracting the necessary token information.

unit_tests/sources/streams/http/test_http_client.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -845,12 +845,10 @@ def __init__(self):
845845
self._token_expiry_date = None
846846
self.refresh_called = False
847847

848-
def refresh_access_token(self):
848+
def refresh_and_set_access_token(self):
849849
self.refresh_called = True
850-
return ("new_refreshed_token", "2099-01-01T00:00:00Z")
851-
852-
def set_token_expiry_date(self, value):
853-
self._token_expiry_date = value
850+
self.access_token = "new_refreshed_token"
851+
self._token_expiry_date = "2099-01-01T00:00:00Z"
854852

855853
def __call__(self, request):
856854
request.headers["Authorization"] = f"Bearer {self.access_token}"
@@ -932,7 +930,7 @@ class FailingOAuthAuthenticator:
932930
def __init__(self):
933931
self.access_token = "old_token"
934932

935-
def refresh_access_token(self):
933+
def refresh_and_set_access_token(self):
936934
raise Exception("Token refresh failed")
937935

938936
def __call__(self, request):
@@ -978,11 +976,6 @@ def test_refresh_token_then_retry_action_with_single_use_refresh_token_authentic
978976
)
979977

980978
mock_authenticator = MagicMock(spec=SingleUseRefreshTokenOauth2Authenticator)
981-
mock_authenticator.refresh_access_token.return_value = (
982-
"new_access_token",
983-
"2099-01-01T00:00:00Z",
984-
"new_refresh_token",
985-
)
986979

987980
mocked_session = MagicMock(spec=requests.Session)
988981
mocked_session.auth = mock_authenticator
@@ -1013,10 +1006,7 @@ def test_refresh_token_then_retry_action_with_single_use_refresh_token_authentic
10131006
with pytest.raises(DefaultBackoffException):
10141007
http_client._send(prepared_request, {})
10151008

1016-
mock_authenticator.refresh_access_token.assert_called_once()
1017-
mock_authenticator.set_refresh_token.assert_called_once_with("new_refresh_token")
1018-
mock_authenticator.set_token_expiry_date.assert_called_once_with("2099-01-01T00:00:00Z")
1019-
mock_authenticator._emit_control_message.assert_called_once()
1009+
mock_authenticator.refresh_and_set_access_token.assert_called_once()
10201010

10211011

10221012
@pytest.mark.usefixtures("mock_sleep")

0 commit comments

Comments
 (0)