Skip to content

Commit d3c9419

Browse files
feat(cdk): add REFRESH_TOKEN_THEN_RETRY response action for OAuth token refresh (#886)
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent efad73e commit d3c9419

File tree

7 files changed

+289
-15
lines changed

7 files changed

+289
-15
lines changed

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,13 +2331,15 @@ definitions:
23312331
- IGNORE
23322332
- RESET_PAGINATION
23332333
- RATE_LIMITED
2334+
- REFRESH_TOKEN_THEN_RETRY
23342335
examples:
23352336
- SUCCESS
23362337
- FAIL
23372338
- RETRY
23382339
- IGNORE
23392340
- RESET_PAGINATION
23402341
- RATE_LIMITED
2342+
- REFRESH_TOKEN_THEN_RETRY
23412343
failure_type:
23422344
title: Failure Type
23432345
description: Failure type of traced exception if a response matches the filter.

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ class Action(Enum):
543543
IGNORE = "IGNORE"
544544
RESET_PAGINATION = "RESET_PAGINATION"
545545
RATE_LIMITED = "RATE_LIMITED"
546+
REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY"
546547

547548

548549
class FailureType(Enum):
@@ -563,6 +564,7 @@ class HttpResponseFilter(BaseModel):
563564
"IGNORE",
564565
"RESET_PAGINATION",
565566
"RATE_LIMITED",
567+
"REFRESH_TOKEN_THEN_RETRY",
566568
],
567569
title="Action",
568570
)

airbyte_cdk/sources/streams/http/error_handlers/response_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class ResponseAction(Enum):
1818
IGNORE = "IGNORE"
1919
RESET_PAGINATION = "RESET_PAGINATION"
2020
RATE_LIMITED = "RATE_LIMITED"
21+
REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY"
2122

2223

2324
@dataclass

airbyte_cdk/sources/streams/http/http_client.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ def __str__(self) -> str:
102102
class HttpClient:
103103
_DEFAULT_MAX_RETRY: int = 5
104104
_DEFAULT_MAX_TIME: int = 60 * 10
105-
_ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED}
105+
_ACTIONS_TO_RETRY_ON = {
106+
ResponseAction.RETRY,
107+
ResponseAction.RATE_LIMITED,
108+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
109+
}
106110

107111
def __init__(
108112
self,
@@ -452,6 +456,31 @@ def _handle_error_resolution(
452456
# backoff retry loop. Adding `\n` to the message and ignore 'end' ensure that few messages are printed at the same time.
453457
print(f"{message}\n", end="", flush=True)
454458

459+
# Handle REFRESH_TOKEN_THEN_RETRY: Force refresh the OAuth token before retry
460+
# This is useful when the API returns 401 but the stored token expiry hasn't been reached yet
461+
# Only OAuth authenticators have refresh_and_set_access_token method
462+
# Non-OAuth auth types (e.g., BearerAuthenticator) will fall through to normal retry
463+
if error_resolution.response_action == ResponseAction.REFRESH_TOKEN_THEN_RETRY:
464+
if (
465+
hasattr(self._session, "auth")
466+
and self._session.auth is not None
467+
and hasattr(self._session.auth, "refresh_and_set_access_token")
468+
):
469+
try:
470+
self._session.auth.refresh_and_set_access_token() # type: ignore[union-attr]
471+
self._logger.info(
472+
"Refreshed OAuth token due to REFRESH_TOKEN_THEN_RETRY response action"
473+
)
474+
except Exception as refresh_error:
475+
self._logger.warning(
476+
f"Failed to refresh OAuth token: {refresh_error}. Proceeding with retry using existing token."
477+
)
478+
else:
479+
self._logger.warning(
480+
"REFRESH_TOKEN_THEN_RETRY action received but authenticator does not support token refresh. "
481+
"Proceeding with normal retry."
482+
)
483+
455484
if error_resolution.response_action == ResponseAction.FAIL:
456485
if response is not None:
457486
filtered_response_message = filter_secrets(
@@ -481,9 +510,10 @@ def _handle_error_resolution(
481510
self._logger.info(error_resolution.error_message or log_message)
482511

483512
# TODO: Consider dynamic retry count depending on subsequent error codes
484-
elif (
485-
error_resolution.response_action == ResponseAction.RETRY
486-
or error_resolution.response_action == ResponseAction.RATE_LIMITED
513+
elif error_resolution.response_action in (
514+
ResponseAction.RETRY,
515+
ResponseAction.RATE_LIMITED,
516+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
487517
):
488518
user_defined_backoff_time = None
489519
for backoff_strategy in self._backoff_strategies:

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,21 @@ def get_auth_header(self) -> Mapping[str, Any]:
8888
def get_access_token(self) -> str:
8989
"""Returns the access token"""
9090
if self.token_has_expired():
91-
token, expires_in = self.refresh_access_token()
92-
self.access_token = token
93-
self.set_token_expiry_date(expires_in)
91+
self.refresh_and_set_access_token()
9492

9593
return self.access_token
9694

95+
def refresh_and_set_access_token(self) -> None:
96+
"""Force refresh the access token and update internal state.
97+
98+
This method refreshes the access token regardless of whether it has expired,
99+
and updates the internal token and expiry date. Subclasses may override this
100+
to handle additional state updates (e.g., persisting new refresh tokens).
101+
"""
102+
token, expires_in = self.refresh_access_token()
103+
self.access_token = token
104+
self.set_token_expiry_date(expires_in)
105+
97106
def token_has_expired(self) -> bool:
98107
"""Returns True if the token is expired"""
99108
return ab_datetime_now() > self.get_token_expiry_date()

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -318,20 +318,28 @@ def token_has_expired(self) -> bool:
318318

319319
def get_access_token(self) -> str:
320320
"""Retrieve new access and refresh token if the access token has expired.
321-
The new refresh token is persisted with the set_refresh_token function
321+
322+
The new refresh token is persisted with the set_refresh_token function.
323+
322324
Returns:
323325
str: The current access_token, updated if it was previously expired.
324326
"""
325327
if self.token_has_expired():
326-
new_access_token, access_token_expires_in, new_refresh_token = (
327-
self.refresh_access_token()
328-
)
329-
self.access_token = new_access_token
330-
self.set_refresh_token(new_refresh_token)
331-
self.set_token_expiry_date(access_token_expires_in)
332-
self._emit_control_message()
328+
self.refresh_and_set_access_token()
333329
return self.access_token
334330

331+
def refresh_and_set_access_token(self) -> None:
332+
"""Force refresh the access token and update internal state.
333+
334+
For single-use refresh tokens, this also persists the new refresh token
335+
and emits a control message to update the connector config.
336+
"""
337+
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
338+
self.access_token = new_access_token
339+
self.set_refresh_token(new_refresh_token)
340+
self.set_token_expiry_date(access_token_expires_in)
341+
self._emit_control_message()
342+
335343
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
336344
"""
337345
Refreshes the access token by making a handled request and extracting the necessary token information.

unit_tests/sources/streams/http/test_http_client.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,225 @@ def backoff_time(self, response_or_exception, attempt_count):
837837
with pytest.raises(AirbyteTracedException) as e:
838838
http_client.send_request(http_method="get", url="https://airbyte.io/", request_kwargs={})
839839
assert e.value.failure_type == expected_failure_type
840+
841+
842+
class MockOAuthAuthenticator:
843+
def __init__(self):
844+
self.access_token = "old_token"
845+
self._token_expiry_date = None
846+
self.refresh_called = False
847+
848+
def refresh_and_set_access_token(self):
849+
self.refresh_called = True
850+
self.access_token = "new_refreshed_token"
851+
self._token_expiry_date = "2099-01-01T00:00:00Z"
852+
853+
def __call__(self, request):
854+
request.headers["Authorization"] = f"Bearer {self.access_token}"
855+
return request
856+
857+
858+
def test_refresh_token_then_retry_action_refreshes_oauth_token(mocker):
859+
mock_authenticator = MockOAuthAuthenticator()
860+
mocked_session = MagicMock(spec=requests.Session)
861+
mocked_session.auth = mock_authenticator
862+
863+
http_client = HttpClient(
864+
name="test",
865+
logger=MagicMock(),
866+
error_handler=HttpStatusErrorHandler(
867+
logger=MagicMock(),
868+
error_mapping={
869+
401: ErrorResolution(
870+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
871+
FailureType.transient_error,
872+
"Token expired, refreshing",
873+
)
874+
},
875+
),
876+
session=mocked_session,
877+
)
878+
879+
prepared_request = requests.PreparedRequest()
880+
mocked_response = MagicMock(spec=requests.Response)
881+
mocked_response.status_code = 401
882+
mocked_response.headers = {}
883+
mocked_response.ok = False
884+
mocked_session.send.return_value = mocked_response
885+
886+
with pytest.raises(DefaultBackoffException):
887+
http_client._send(prepared_request, {})
888+
889+
assert mock_authenticator.refresh_called
890+
assert mock_authenticator.access_token == "new_refreshed_token"
891+
assert mock_authenticator._token_expiry_date == "2099-01-01T00:00:00Z"
892+
893+
894+
def test_refresh_token_then_retry_action_without_oauth_authenticator_proceeds_with_retry(mocker):
895+
mocked_session = MagicMock(spec=requests.Session)
896+
mocked_session.auth = None
897+
898+
mocked_logger = MagicMock()
899+
http_client = HttpClient(
900+
name="test",
901+
logger=mocked_logger,
902+
error_handler=HttpStatusErrorHandler(
903+
logger=MagicMock(),
904+
error_mapping={
905+
401: ErrorResolution(
906+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
907+
FailureType.transient_error,
908+
"Token expired, refreshing",
909+
)
910+
},
911+
),
912+
session=mocked_session,
913+
)
914+
915+
prepared_request = requests.PreparedRequest()
916+
mocked_response = MagicMock(spec=requests.Response)
917+
mocked_response.status_code = 401
918+
mocked_response.headers = {}
919+
mocked_response.ok = False
920+
mocked_session.send.return_value = mocked_response
921+
922+
with pytest.raises(DefaultBackoffException):
923+
http_client._send(prepared_request, {})
924+
925+
mocked_logger.warning.assert_called()
926+
927+
928+
def test_refresh_token_then_retry_action_handles_refresh_failure_gracefully(mocker):
929+
class FailingOAuthAuthenticator:
930+
def __init__(self):
931+
self.access_token = "old_token"
932+
933+
def refresh_and_set_access_token(self):
934+
raise Exception("Token refresh failed")
935+
936+
def __call__(self, request):
937+
return request
938+
939+
mock_authenticator = FailingOAuthAuthenticator()
940+
mocked_session = MagicMock(spec=requests.Session)
941+
mocked_session.auth = mock_authenticator
942+
943+
mocked_logger = MagicMock()
944+
http_client = HttpClient(
945+
name="test",
946+
logger=mocked_logger,
947+
error_handler=HttpStatusErrorHandler(
948+
logger=MagicMock(),
949+
error_mapping={
950+
401: ErrorResolution(
951+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
952+
FailureType.transient_error,
953+
"Token expired, refreshing",
954+
)
955+
},
956+
),
957+
session=mocked_session,
958+
)
959+
960+
prepared_request = requests.PreparedRequest()
961+
mocked_response = MagicMock(spec=requests.Response)
962+
mocked_response.status_code = 401
963+
mocked_response.headers = {}
964+
mocked_response.ok = False
965+
mocked_session.send.return_value = mocked_response
966+
967+
with pytest.raises(DefaultBackoffException):
968+
http_client._send(prepared_request, {})
969+
970+
mocked_logger.warning.assert_called()
971+
972+
973+
def test_refresh_token_then_retry_action_with_single_use_refresh_token_authenticator(mocker):
974+
from airbyte_cdk.sources.streams.http.requests_native_auth import (
975+
SingleUseRefreshTokenOauth2Authenticator,
976+
)
977+
978+
mock_authenticator = MagicMock(spec=SingleUseRefreshTokenOauth2Authenticator)
979+
980+
mocked_session = MagicMock(spec=requests.Session)
981+
mocked_session.auth = mock_authenticator
982+
983+
http_client = HttpClient(
984+
name="test",
985+
logger=MagicMock(),
986+
error_handler=HttpStatusErrorHandler(
987+
logger=MagicMock(),
988+
error_mapping={
989+
401: ErrorResolution(
990+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
991+
FailureType.transient_error,
992+
"Token expired, refreshing",
993+
)
994+
},
995+
),
996+
session=mocked_session,
997+
)
998+
999+
prepared_request = requests.PreparedRequest()
1000+
mocked_response = MagicMock(spec=requests.Response)
1001+
mocked_response.status_code = 401
1002+
mocked_response.headers = {}
1003+
mocked_response.ok = False
1004+
mocked_session.send.return_value = mocked_response
1005+
1006+
with pytest.raises(DefaultBackoffException):
1007+
http_client._send(prepared_request, {})
1008+
1009+
mock_authenticator.refresh_and_set_access_token.assert_called_once()
1010+
1011+
1012+
@pytest.mark.usefixtures("mock_sleep")
1013+
def test_refresh_token_then_retry_action_retries_and_succeeds_after_token_refresh():
1014+
mock_authenticator = MockOAuthAuthenticator()
1015+
mocked_session = MagicMock(spec=requests.Session)
1016+
mocked_session.auth = mock_authenticator
1017+
1018+
valid_response = MagicMock(spec=requests.Response)
1019+
valid_response.status_code = 200
1020+
valid_response.ok = True
1021+
valid_response.headers = {}
1022+
1023+
call_count = 0
1024+
1025+
def update_response(*args, **kwargs):
1026+
nonlocal call_count
1027+
call_count += 1
1028+
if call_count == 1:
1029+
retry_response = MagicMock(spec=requests.Response)
1030+
retry_response.ok = False
1031+
retry_response.status_code = 401
1032+
retry_response.headers = {}
1033+
return retry_response
1034+
else:
1035+
return valid_response
1036+
1037+
mocked_session.send.side_effect = update_response
1038+
1039+
http_client = HttpClient(
1040+
name="test",
1041+
logger=MagicMock(),
1042+
error_handler=HttpStatusErrorHandler(
1043+
logger=MagicMock(),
1044+
error_mapping={
1045+
401: ErrorResolution(
1046+
ResponseAction.REFRESH_TOKEN_THEN_RETRY,
1047+
FailureType.transient_error,
1048+
"Token expired, refreshing",
1049+
)
1050+
},
1051+
),
1052+
session=mocked_session,
1053+
)
1054+
1055+
prepared_request = requests.PreparedRequest()
1056+
returned_response = http_client._send_with_retry(prepared_request, request_kwargs={})
1057+
1058+
assert mock_authenticator.refresh_called
1059+
assert mock_authenticator.access_token == "new_refreshed_token"
1060+
assert returned_response == valid_response
1061+
assert call_count == 2

0 commit comments

Comments
 (0)