|
7 | 7 | import logging |
8 | 8 | from copy import deepcopy |
9 | 9 | from datetime import timedelta, timezone |
10 | | -from unittest.mock import Mock |
| 10 | +from unittest.mock import Mock, patch |
11 | 11 |
|
12 | 12 | import freezegun |
13 | 13 | import pytest |
14 | 14 | import requests |
15 | 15 | from requests import Response |
16 | 16 |
|
| 17 | +from airbyte_cdk.models import FailureType |
17 | 18 | from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator |
18 | 19 | from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator |
19 | 20 | from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse |
| 21 | +from airbyte_cdk.utils import AirbyteTracedException |
20 | 22 | from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets |
21 | 23 | from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse |
22 | 24 |
|
@@ -645,3 +647,75 @@ def mock_request(method, url, data, headers): |
645 | 647 | raise Exception( |
646 | 648 | f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" |
647 | 649 | ) |
| 650 | + |
| 651 | + |
| 652 | +class TestOauth2AuthenticatorTransientErrorHandling: |
| 653 | + """Tests for transient network error handling during OAuth token refresh.""" |
| 654 | + |
| 655 | + def _create_authenticator(self): |
| 656 | + return DeclarativeOauth2Authenticator( |
| 657 | + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", |
| 658 | + client_id="{{ config['client_id'] }}", |
| 659 | + client_secret="{{ config['client_secret'] }}", |
| 660 | + refresh_token="{{ parameters['refresh_token'] }}", |
| 661 | + config=config, |
| 662 | + token_expiry_date="{{ config['token_expiry_date'] }}", |
| 663 | + parameters=parameters, |
| 664 | + ) |
| 665 | + |
| 666 | + @pytest.mark.parametrize( |
| 667 | + "exception_class", |
| 668 | + [ |
| 669 | + requests.exceptions.ConnectionError, |
| 670 | + requests.exceptions.ConnectTimeout, |
| 671 | + requests.exceptions.ReadTimeout, |
| 672 | + ], |
| 673 | + ids=["ConnectionError", "ConnectTimeout", "ReadTimeout"], |
| 674 | + ) |
| 675 | + def test_transient_network_error_wrapped_as_transient_error(self, exception_class): |
| 676 | + """Transient network errors during OAuth refresh are wrapped in AirbyteTracedException with transient_error.""" |
| 677 | + oauth = self._create_authenticator() |
| 678 | + with patch.object( |
| 679 | + oauth, "_make_handled_request", side_effect=exception_class("connection reset") |
| 680 | + ): |
| 681 | + with pytest.raises(AirbyteTracedException) as exc_info: |
| 682 | + oauth.refresh_access_token() |
| 683 | + |
| 684 | + assert exc_info.value.failure_type == FailureType.transient_error |
| 685 | + assert "network error" in exc_info.value.message.lower() |
| 686 | + |
| 687 | + def test_connection_error_is_retried_before_raising(self, mocker): |
| 688 | + """ConnectionError triggers backoff retries in _make_handled_request before propagating.""" |
| 689 | + oauth = self._create_authenticator() |
| 690 | + |
| 691 | + call_count = 0 |
| 692 | + |
| 693 | + def request_side_effect(**kwargs): |
| 694 | + nonlocal call_count |
| 695 | + call_count += 1 |
| 696 | + if call_count < 3: |
| 697 | + raise requests.exceptions.ConnectionError("connection reset by peer") |
| 698 | + mock_response = Mock(spec=requests.Response) |
| 699 | + mock_response.ok = True |
| 700 | + mock_response.json.return_value = {"access_token": "token_value", "expires_in": 3600} |
| 701 | + return mock_response |
| 702 | + |
| 703 | + mocker.patch("requests.request", side_effect=request_side_effect) |
| 704 | + # Patch backoff to avoid actual delays in tests |
| 705 | + mocker.patch("time.sleep") |
| 706 | + |
| 707 | + token, _ = oauth.refresh_access_token() |
| 708 | + assert token == "token_value" |
| 709 | + assert call_count == 3 |
| 710 | + |
| 711 | + def test_generic_exception_wrapped_as_system_error(self, mocker): |
| 712 | + """Generic exceptions during OAuth refresh are wrapped in AirbyteTracedException with system_error.""" |
| 713 | + oauth = self._create_authenticator() |
| 714 | + mocker.patch("requests.request", side_effect=ValueError("unexpected parsing error")) |
| 715 | + mocker.patch("time.sleep") |
| 716 | + |
| 717 | + with pytest.raises(AirbyteTracedException) as exc_info: |
| 718 | + oauth.refresh_access_token() |
| 719 | + |
| 720 | + assert exc_info.value.failure_type == FailureType.system_error |
| 721 | + assert "OAuth access token refresh request failed" in exc_info.value.message |
0 commit comments