diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 977fff2d5..e68318cd4 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -2331,6 +2331,7 @@ definitions: - IGNORE - RESET_PAGINATION - RATE_LIMITED + - REFRESH_TOKEN_THEN_RETRY examples: - SUCCESS - FAIL @@ -2338,6 +2339,7 @@ definitions: - IGNORE - RESET_PAGINATION - RATE_LIMITED + - REFRESH_TOKEN_THEN_RETRY failure_type: title: Failure Type description: Failure type of traced exception if a response matches the filter. diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 3fccc600f..5d2f0521f 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -543,6 +543,7 @@ class Action(Enum): IGNORE = "IGNORE" RESET_PAGINATION = "RESET_PAGINATION" RATE_LIMITED = "RATE_LIMITED" + REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY" class FailureType(Enum): @@ -563,6 +564,7 @@ class HttpResponseFilter(BaseModel): "IGNORE", "RESET_PAGINATION", "RATE_LIMITED", + "REFRESH_TOKEN_THEN_RETRY", ], title="Action", ) diff --git a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py index 7199d1982..082d580d5 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py @@ -18,6 +18,7 @@ class ResponseAction(Enum): IGNORE = "IGNORE" RESET_PAGINATION = "RESET_PAGINATION" RATE_LIMITED = "RATE_LIMITED" + REFRESH_TOKEN_THEN_RETRY = "REFRESH_TOKEN_THEN_RETRY" @dataclass diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index e9fc5add2..e7a5715ac 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -102,7 +102,11 @@ def __str__(self) -> str: class HttpClient: _DEFAULT_MAX_RETRY: int = 5 _DEFAULT_MAX_TIME: int = 60 * 10 - _ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED} + _ACTIONS_TO_RETRY_ON = { + ResponseAction.RETRY, + ResponseAction.RATE_LIMITED, + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + } def __init__( self, @@ -452,6 +456,31 @@ def _handle_error_resolution( # backoff retry loop. Adding `\n` to the message and ignore 'end' ensure that few messages are printed at the same time. print(f"{message}\n", end="", flush=True) + # Handle REFRESH_TOKEN_THEN_RETRY: Force refresh the OAuth token before retry + # This is useful when the API returns 401 but the stored token expiry hasn't been reached yet + # Only OAuth authenticators have refresh_and_set_access_token method + # Non-OAuth auth types (e.g., BearerAuthenticator) will fall through to normal retry + if error_resolution.response_action == ResponseAction.REFRESH_TOKEN_THEN_RETRY: + if ( + hasattr(self._session, "auth") + and self._session.auth is not None + and hasattr(self._session.auth, "refresh_and_set_access_token") + ): + try: + self._session.auth.refresh_and_set_access_token() # type: ignore[union-attr] + self._logger.info( + "Refreshed OAuth token due to REFRESH_TOKEN_THEN_RETRY response action" + ) + except Exception as refresh_error: + self._logger.warning( + f"Failed to refresh OAuth token: {refresh_error}. Proceeding with retry using existing token." + ) + else: + self._logger.warning( + "REFRESH_TOKEN_THEN_RETRY action received but authenticator does not support token refresh. " + "Proceeding with normal retry." + ) + if error_resolution.response_action == ResponseAction.FAIL: if response is not None: filtered_response_message = filter_secrets( @@ -481,9 +510,10 @@ def _handle_error_resolution( self._logger.info(error_resolution.error_message or log_message) # TODO: Consider dynamic retry count depending on subsequent error codes - elif ( - error_resolution.response_action == ResponseAction.RETRY - or error_resolution.response_action == ResponseAction.RATE_LIMITED + elif error_resolution.response_action in ( + ResponseAction.RETRY, + ResponseAction.RATE_LIMITED, + ResponseAction.REFRESH_TOKEN_THEN_RETRY, ): user_defined_backoff_time = None for backoff_strategy in self._backoff_strategies: diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 3b4aa9844..48c5bba73 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -88,12 +88,21 @@ def get_auth_header(self) -> Mapping[str, Any]: def get_access_token(self) -> str: """Returns the access token""" if self.token_has_expired(): - token, expires_in = self.refresh_access_token() - self.access_token = token - self.set_token_expiry_date(expires_in) + self.refresh_and_set_access_token() return self.access_token + def refresh_and_set_access_token(self) -> None: + """Force refresh the access token and update internal state. + + This method refreshes the access token regardless of whether it has expired, + and updates the internal token and expiry date. Subclasses may override this + to handle additional state updates (e.g., persisting new refresh tokens). + """ + token, expires_in = self.refresh_access_token() + self.access_token = token + self.set_token_expiry_date(expires_in) + def token_has_expired(self) -> bool: """Returns True if the token is expired""" return ab_datetime_now() > self.get_token_expiry_date() diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index a2932c294..cb64eb3e3 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -318,20 +318,28 @@ def token_has_expired(self) -> bool: def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. - The new refresh token is persisted with the set_refresh_token function + + The new refresh token is persisted with the set_refresh_token function. + Returns: str: The current access_token, updated if it was previously expired. """ if self.token_has_expired(): - new_access_token, access_token_expires_in, new_refresh_token = ( - self.refresh_access_token() - ) - self.access_token = new_access_token - self.set_refresh_token(new_refresh_token) - self.set_token_expiry_date(access_token_expires_in) - self._emit_control_message() + self.refresh_and_set_access_token() return self.access_token + def refresh_and_set_access_token(self) -> None: + """Force refresh the access token and update internal state. + + For single-use refresh tokens, this also persists the new refresh token + and emits a control message to update the connector config. + """ + new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token() + self.access_token = new_access_token + self.set_refresh_token(new_refresh_token) + self.set_token_expiry_date(access_token_expires_in) + self._emit_control_message() + def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override] """ Refreshes the access token by making a handled request and extracting the necessary token information. diff --git a/unit_tests/sources/streams/http/test_http_client.py b/unit_tests/sources/streams/http/test_http_client.py index 3840c70e3..ea245c2fb 100644 --- a/unit_tests/sources/streams/http/test_http_client.py +++ b/unit_tests/sources/streams/http/test_http_client.py @@ -837,3 +837,225 @@ def backoff_time(self, response_or_exception, attempt_count): with pytest.raises(AirbyteTracedException) as e: http_client.send_request(http_method="get", url="https://airbyte.io/", request_kwargs={}) assert e.value.failure_type == expected_failure_type + + +class MockOAuthAuthenticator: + def __init__(self): + self.access_token = "old_token" + self._token_expiry_date = None + self.refresh_called = False + + def refresh_and_set_access_token(self): + self.refresh_called = True + self.access_token = "new_refreshed_token" + self._token_expiry_date = "2099-01-01T00:00:00Z" + + def __call__(self, request): + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + +def test_refresh_token_then_retry_action_refreshes_oauth_token(mocker): + mock_authenticator = MockOAuthAuthenticator() + mocked_session = MagicMock(spec=requests.Session) + mocked_session.auth = mock_authenticator + + http_client = HttpClient( + name="test", + logger=MagicMock(), + error_handler=HttpStatusErrorHandler( + logger=MagicMock(), + error_mapping={ + 401: ErrorResolution( + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + FailureType.transient_error, + "Token expired, refreshing", + ) + }, + ), + session=mocked_session, + ) + + prepared_request = requests.PreparedRequest() + mocked_response = MagicMock(spec=requests.Response) + mocked_response.status_code = 401 + mocked_response.headers = {} + mocked_response.ok = False + mocked_session.send.return_value = mocked_response + + with pytest.raises(DefaultBackoffException): + http_client._send(prepared_request, {}) + + assert mock_authenticator.refresh_called + assert mock_authenticator.access_token == "new_refreshed_token" + assert mock_authenticator._token_expiry_date == "2099-01-01T00:00:00Z" + + +def test_refresh_token_then_retry_action_without_oauth_authenticator_proceeds_with_retry(mocker): + mocked_session = MagicMock(spec=requests.Session) + mocked_session.auth = None + + mocked_logger = MagicMock() + http_client = HttpClient( + name="test", + logger=mocked_logger, + error_handler=HttpStatusErrorHandler( + logger=MagicMock(), + error_mapping={ + 401: ErrorResolution( + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + FailureType.transient_error, + "Token expired, refreshing", + ) + }, + ), + session=mocked_session, + ) + + prepared_request = requests.PreparedRequest() + mocked_response = MagicMock(spec=requests.Response) + mocked_response.status_code = 401 + mocked_response.headers = {} + mocked_response.ok = False + mocked_session.send.return_value = mocked_response + + with pytest.raises(DefaultBackoffException): + http_client._send(prepared_request, {}) + + mocked_logger.warning.assert_called() + + +def test_refresh_token_then_retry_action_handles_refresh_failure_gracefully(mocker): + class FailingOAuthAuthenticator: + def __init__(self): + self.access_token = "old_token" + + def refresh_and_set_access_token(self): + raise Exception("Token refresh failed") + + def __call__(self, request): + return request + + mock_authenticator = FailingOAuthAuthenticator() + mocked_session = MagicMock(spec=requests.Session) + mocked_session.auth = mock_authenticator + + mocked_logger = MagicMock() + http_client = HttpClient( + name="test", + logger=mocked_logger, + error_handler=HttpStatusErrorHandler( + logger=MagicMock(), + error_mapping={ + 401: ErrorResolution( + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + FailureType.transient_error, + "Token expired, refreshing", + ) + }, + ), + session=mocked_session, + ) + + prepared_request = requests.PreparedRequest() + mocked_response = MagicMock(spec=requests.Response) + mocked_response.status_code = 401 + mocked_response.headers = {} + mocked_response.ok = False + mocked_session.send.return_value = mocked_response + + with pytest.raises(DefaultBackoffException): + http_client._send(prepared_request, {}) + + mocked_logger.warning.assert_called() + + +def test_refresh_token_then_retry_action_with_single_use_refresh_token_authenticator(mocker): + from airbyte_cdk.sources.streams.http.requests_native_auth import ( + SingleUseRefreshTokenOauth2Authenticator, + ) + + mock_authenticator = MagicMock(spec=SingleUseRefreshTokenOauth2Authenticator) + + mocked_session = MagicMock(spec=requests.Session) + mocked_session.auth = mock_authenticator + + http_client = HttpClient( + name="test", + logger=MagicMock(), + error_handler=HttpStatusErrorHandler( + logger=MagicMock(), + error_mapping={ + 401: ErrorResolution( + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + FailureType.transient_error, + "Token expired, refreshing", + ) + }, + ), + session=mocked_session, + ) + + prepared_request = requests.PreparedRequest() + mocked_response = MagicMock(spec=requests.Response) + mocked_response.status_code = 401 + mocked_response.headers = {} + mocked_response.ok = False + mocked_session.send.return_value = mocked_response + + with pytest.raises(DefaultBackoffException): + http_client._send(prepared_request, {}) + + mock_authenticator.refresh_and_set_access_token.assert_called_once() + + +@pytest.mark.usefixtures("mock_sleep") +def test_refresh_token_then_retry_action_retries_and_succeeds_after_token_refresh(): + mock_authenticator = MockOAuthAuthenticator() + mocked_session = MagicMock(spec=requests.Session) + mocked_session.auth = mock_authenticator + + valid_response = MagicMock(spec=requests.Response) + valid_response.status_code = 200 + valid_response.ok = True + valid_response.headers = {} + + call_count = 0 + + def update_response(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + retry_response = MagicMock(spec=requests.Response) + retry_response.ok = False + retry_response.status_code = 401 + retry_response.headers = {} + return retry_response + else: + return valid_response + + mocked_session.send.side_effect = update_response + + http_client = HttpClient( + name="test", + logger=MagicMock(), + error_handler=HttpStatusErrorHandler( + logger=MagicMock(), + error_mapping={ + 401: ErrorResolution( + ResponseAction.REFRESH_TOKEN_THEN_RETRY, + FailureType.transient_error, + "Token expired, refreshing", + ) + }, + ), + session=mocked_session, + ) + + prepared_request = requests.PreparedRequest() + returned_response = http_client._send_with_retry(prepared_request, request_kwargs={}) + + assert mock_authenticator.refresh_called + assert mock_authenticator.access_token == "new_refreshed_token" + assert returned_response == valid_response + assert call_count == 2