@@ -837,3 +837,225 @@ def backoff_time(self, response_or_exception, attempt_count):
837837 with pytest .raises (AirbyteTracedException ) as e :
838838 http_client .send_request (http_method = "get" , url = "https://airbyte.io/" , request_kwargs = {})
839839 assert e .value .failure_type == expected_failure_type
840+
841+
842+ class MockOAuthAuthenticator :
843+ def __init__ (self ):
844+ self .access_token = "old_token"
845+ self ._token_expiry_date = None
846+ self .refresh_called = False
847+
848+ def refresh_and_set_access_token (self ):
849+ self .refresh_called = True
850+ self .access_token = "new_refreshed_token"
851+ self ._token_expiry_date = "2099-01-01T00:00:00Z"
852+
853+ def __call__ (self , request ):
854+ request .headers ["Authorization" ] = f"Bearer { self .access_token } "
855+ return request
856+
857+
858+ def test_refresh_token_then_retry_action_refreshes_oauth_token (mocker ):
859+ mock_authenticator = MockOAuthAuthenticator ()
860+ mocked_session = MagicMock (spec = requests .Session )
861+ mocked_session .auth = mock_authenticator
862+
863+ http_client = HttpClient (
864+ name = "test" ,
865+ logger = MagicMock (),
866+ error_handler = HttpStatusErrorHandler (
867+ logger = MagicMock (),
868+ error_mapping = {
869+ 401 : ErrorResolution (
870+ ResponseAction .REFRESH_TOKEN_THEN_RETRY ,
871+ FailureType .transient_error ,
872+ "Token expired, refreshing" ,
873+ )
874+ },
875+ ),
876+ session = mocked_session ,
877+ )
878+
879+ prepared_request = requests .PreparedRequest ()
880+ mocked_response = MagicMock (spec = requests .Response )
881+ mocked_response .status_code = 401
882+ mocked_response .headers = {}
883+ mocked_response .ok = False
884+ mocked_session .send .return_value = mocked_response
885+
886+ with pytest .raises (DefaultBackoffException ):
887+ http_client ._send (prepared_request , {})
888+
889+ assert mock_authenticator .refresh_called
890+ assert mock_authenticator .access_token == "new_refreshed_token"
891+ assert mock_authenticator ._token_expiry_date == "2099-01-01T00:00:00Z"
892+
893+
894+ def test_refresh_token_then_retry_action_without_oauth_authenticator_proceeds_with_retry (mocker ):
895+ mocked_session = MagicMock (spec = requests .Session )
896+ mocked_session .auth = None
897+
898+ mocked_logger = MagicMock ()
899+ http_client = HttpClient (
900+ name = "test" ,
901+ logger = mocked_logger ,
902+ error_handler = HttpStatusErrorHandler (
903+ logger = MagicMock (),
904+ error_mapping = {
905+ 401 : ErrorResolution (
906+ ResponseAction .REFRESH_TOKEN_THEN_RETRY ,
907+ FailureType .transient_error ,
908+ "Token expired, refreshing" ,
909+ )
910+ },
911+ ),
912+ session = mocked_session ,
913+ )
914+
915+ prepared_request = requests .PreparedRequest ()
916+ mocked_response = MagicMock (spec = requests .Response )
917+ mocked_response .status_code = 401
918+ mocked_response .headers = {}
919+ mocked_response .ok = False
920+ mocked_session .send .return_value = mocked_response
921+
922+ with pytest .raises (DefaultBackoffException ):
923+ http_client ._send (prepared_request , {})
924+
925+ mocked_logger .warning .assert_called ()
926+
927+
928+ def test_refresh_token_then_retry_action_handles_refresh_failure_gracefully (mocker ):
929+ class FailingOAuthAuthenticator :
930+ def __init__ (self ):
931+ self .access_token = "old_token"
932+
933+ def refresh_and_set_access_token (self ):
934+ raise Exception ("Token refresh failed" )
935+
936+ def __call__ (self , request ):
937+ return request
938+
939+ mock_authenticator = FailingOAuthAuthenticator ()
940+ mocked_session = MagicMock (spec = requests .Session )
941+ mocked_session .auth = mock_authenticator
942+
943+ mocked_logger = MagicMock ()
944+ http_client = HttpClient (
945+ name = "test" ,
946+ logger = mocked_logger ,
947+ error_handler = HttpStatusErrorHandler (
948+ logger = MagicMock (),
949+ error_mapping = {
950+ 401 : ErrorResolution (
951+ ResponseAction .REFRESH_TOKEN_THEN_RETRY ,
952+ FailureType .transient_error ,
953+ "Token expired, refreshing" ,
954+ )
955+ },
956+ ),
957+ session = mocked_session ,
958+ )
959+
960+ prepared_request = requests .PreparedRequest ()
961+ mocked_response = MagicMock (spec = requests .Response )
962+ mocked_response .status_code = 401
963+ mocked_response .headers = {}
964+ mocked_response .ok = False
965+ mocked_session .send .return_value = mocked_response
966+
967+ with pytest .raises (DefaultBackoffException ):
968+ http_client ._send (prepared_request , {})
969+
970+ mocked_logger .warning .assert_called ()
971+
972+
973+ def test_refresh_token_then_retry_action_with_single_use_refresh_token_authenticator (mocker ):
974+ from airbyte_cdk .sources .streams .http .requests_native_auth import (
975+ SingleUseRefreshTokenOauth2Authenticator ,
976+ )
977+
978+ mock_authenticator = MagicMock (spec = SingleUseRefreshTokenOauth2Authenticator )
979+
980+ mocked_session = MagicMock (spec = requests .Session )
981+ mocked_session .auth = mock_authenticator
982+
983+ http_client = HttpClient (
984+ name = "test" ,
985+ logger = MagicMock (),
986+ error_handler = HttpStatusErrorHandler (
987+ logger = MagicMock (),
988+ error_mapping = {
989+ 401 : ErrorResolution (
990+ ResponseAction .REFRESH_TOKEN_THEN_RETRY ,
991+ FailureType .transient_error ,
992+ "Token expired, refreshing" ,
993+ )
994+ },
995+ ),
996+ session = mocked_session ,
997+ )
998+
999+ prepared_request = requests .PreparedRequest ()
1000+ mocked_response = MagicMock (spec = requests .Response )
1001+ mocked_response .status_code = 401
1002+ mocked_response .headers = {}
1003+ mocked_response .ok = False
1004+ mocked_session .send .return_value = mocked_response
1005+
1006+ with pytest .raises (DefaultBackoffException ):
1007+ http_client ._send (prepared_request , {})
1008+
1009+ mock_authenticator .refresh_and_set_access_token .assert_called_once ()
1010+
1011+
1012+ @pytest .mark .usefixtures ("mock_sleep" )
1013+ def test_refresh_token_then_retry_action_retries_and_succeeds_after_token_refresh ():
1014+ mock_authenticator = MockOAuthAuthenticator ()
1015+ mocked_session = MagicMock (spec = requests .Session )
1016+ mocked_session .auth = mock_authenticator
1017+
1018+ valid_response = MagicMock (spec = requests .Response )
1019+ valid_response .status_code = 200
1020+ valid_response .ok = True
1021+ valid_response .headers = {}
1022+
1023+ call_count = 0
1024+
1025+ def update_response (* args , ** kwargs ):
1026+ nonlocal call_count
1027+ call_count += 1
1028+ if call_count == 1 :
1029+ retry_response = MagicMock (spec = requests .Response )
1030+ retry_response .ok = False
1031+ retry_response .status_code = 401
1032+ retry_response .headers = {}
1033+ return retry_response
1034+ else :
1035+ return valid_response
1036+
1037+ mocked_session .send .side_effect = update_response
1038+
1039+ http_client = HttpClient (
1040+ name = "test" ,
1041+ logger = MagicMock (),
1042+ error_handler = HttpStatusErrorHandler (
1043+ logger = MagicMock (),
1044+ error_mapping = {
1045+ 401 : ErrorResolution (
1046+ ResponseAction .REFRESH_TOKEN_THEN_RETRY ,
1047+ FailureType .transient_error ,
1048+ "Token expired, refreshing" ,
1049+ )
1050+ },
1051+ ),
1052+ session = mocked_session ,
1053+ )
1054+
1055+ prepared_request = requests .PreparedRequest ()
1056+ returned_response = http_client ._send_with_retry (prepared_request , request_kwargs = {})
1057+
1058+ assert mock_authenticator .refresh_called
1059+ assert mock_authenticator .access_token == "new_refreshed_token"
1060+ assert returned_response == valid_response
1061+ assert call_count == 2
0 commit comments