From ccac289be451e87ab7f2b57abe744ddc655b4fc5 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:00:40 -0700 Subject: [PATCH 01/13] minimal change to insert default expiration --- airbyte_cdk/sources/declarative/auth/oauth.py | 2 +- .../requests_native_auth/abstract_oauth.py | 21 ++++++++++++++++--- .../http/requests_native_auth/oauth.py | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index bc609e42e..1bebd41a4 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -239,7 +239,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime: def _has_access_token_been_initialized(self) -> bool: return self._access_token is not None - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) def get_assertion_name(self) -> str: diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index b0afeca6e..338b47cb2 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + def refresh_access_token(self) -> Tuple[str, Optional[str]]: """ Returns the refresh token and its expiration datetime @@ -147,6 +147,15 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: # ---------------- # PRIVATE METHODS # ---------------- + + def _default_token_expiry_date(self) -> str: + """ + Returns the default token expiry date + """ + if self.token_expiry_is_time_of_expiration: + return str(ab_datetime_now() + timedelta(hours=1)) + else: + return "3600" def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException @@ -316,9 +325,15 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. Returns: - str: The extracted token_expiry_date. + The extracted token_expiry_date or None if not found. """ - return self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + expires_in = self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + # If the access token expires in is None, we do not know when the token will expire + # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration + if expires_in is None: + return self._default_token_expiry_date() + else: + return expires_in def _find_and_get_value_from_response( self, diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 2ff2f60e9..6162fd2e2 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -355,7 +355,7 @@ def get_access_token(self) -> str: self._emit_control_message() return self.access_token - def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] + def refresh_access_token(self) -> Tuple[str, Optional[str], str]: # type: ignore[override] """ Refreshes the access token by making a handled request and extracting the necessary token information. From 78ada5872c1790ed2c61a7e3f5cfb3f89156fc93 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:04:04 -0700 Subject: [PATCH 02/13] update tests --- .../test_requests_native_auth.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index d756931c8..138a25eec 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -4,9 +4,9 @@ import json import logging -from datetime import timedelta, timezone +from datetime import timedelta from typing import Optional, Union -from unittest.mock import Mock +from unittest.mock import Mock, PropertyMock import freezegun import pytest @@ -236,6 +236,7 @@ def test_refresh_request_body_with_keys_override(self): } assert body == expected + @freezegun.freeze_time("2022-01-01") def test_refresh_access_token(self, mocker): oauth = Oauth2Authenticator( token_refresh_endpoint="https://refresh_endpoint.com", @@ -281,6 +282,15 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + # Test with no expires_in + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token"}, + ) + token, expires_in = oauth.refresh_access_token() + assert expires_in == "3600" + # Test with nested access_token and expires_in as str(int) mocker.patch.object( resp, @@ -393,8 +403,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker): "YYYY-MM-DDTHH:mm:ss.SSSSSSZ", AirbyteDateTime(year=2022, month=2, day=12), ), + (None, None, AirbyteDateTime(year=2022, month=1, day=1, hour=1)), + (None, "YYYY-MM-DD", AirbyteDateTime(year=2022, month=1, day=1, hour=1)), ], - ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime"], + ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime", "default_behavior", "default_behavior_with_format"], ) @freezegun.freeze_time("2022-01-01") def test_parse_refresh_token_lifespan( From aa1be428cf51c284c70971239c5733c1eb0a6778 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:24:15 -0700 Subject: [PATCH 03/13] remove unused import --- .../http/requests_native_auth/test_requests_native_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 138a25eec..a69db6d00 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -6,7 +6,7 @@ import logging from datetime import timedelta from typing import Optional, Union -from unittest.mock import Mock, PropertyMock +from unittest.mock import Mock import freezegun import pytest From d8c0c64bbc85e323b476f04d514e8d16bc654188 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:34:22 -0700 Subject: [PATCH 04/13] undo function signature change --- airbyte_cdk/sources/declarative/auth/oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 1bebd41a4..bc609e42e 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -239,7 +239,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime: def _has_access_token_been_initialized(self) -> bool: return self._access_token is not None - def set_token_expiry_date(self, value: AirbyteDateTime) -> None: + def set_token_expiry_date(self, value: Union[str, int]) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) def get_assertion_name(self) -> str: From 511c32f882ab4c59b70a348a9d45c96de53fafbe Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:46:03 -0700 Subject: [PATCH 05/13] undo typing change --- .../sources/streams/http/requests_native_auth/abstract_oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 338b47cb2..5acd4c14c 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None - def refresh_access_token(self) -> Tuple[str, Optional[str]]: + def refresh_access_token(self) -> Tuple[str, Union[str, int]]: """ Returns the refresh token and its expiration datetime From 19389d8345893e07c5a8f37688a7ded3f479246f Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:49:15 -0700 Subject: [PATCH 06/13] mypy --- airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 6162fd2e2..2ff2f60e9 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -355,7 +355,7 @@ def get_access_token(self) -> str: self._emit_control_message() return self.access_token - def refresh_access_token(self) -> Tuple[str, Optional[str], str]: # type: ignore[override] + def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] """ Refreshes the access token by making a handled request and extracting the necessary token information. From 2494d93b42d768f6e16160a62ac848fef15cc612 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:50:40 -0700 Subject: [PATCH 07/13] ruff formatting --- .../streams/http/requests_native_auth/abstract_oauth.py | 6 ++++-- .../requests_native_auth/test_requests_native_auth.py | 9 ++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 5acd4c14c..e22b7e83a 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -147,7 +147,7 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: # ---------------- # PRIVATE METHODS # ---------------- - + def _default_token_expiry_date(self) -> str: """ Returns the default token expiry date @@ -327,7 +327,9 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: Returns: The extracted token_expiry_date or None if not found. """ - expires_in = self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + expires_in = self._find_and_get_value_from_response( + response_data, self.get_expires_in_name() + ) # If the access token expires in is None, we do not know when the token will expire # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration if expires_in is None: diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index a69db6d00..d264f8a19 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -406,7 +406,14 @@ def test_refresh_access_token_when_headers_provided(self, mocker): (None, None, AirbyteDateTime(year=2022, month=1, day=1, hour=1)), (None, "YYYY-MM-DD", AirbyteDateTime(year=2022, month=1, day=1, hour=1)), ], - ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime", "default_behavior", "default_behavior_with_format"], + ids=[ + "seconds", + "string_of_seconds", + "simple_date", + "simple_datetime", + "default_behavior", + "default_behavior_with_format", + ], ) @freezegun.freeze_time("2022-01-01") def test_parse_refresh_token_lifespan( From c973e026fca042a3852511dba4fce2a32d4b4f65 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:12:15 -0700 Subject: [PATCH 08/13] attempt to rework classes to simplify expiration date handling and type casting --- airbyte_cdk/sources/declarative/auth/oauth.py | 4 +-- .../requests_native_auth/abstract_oauth.py | 32 +++++++++---------- .../http/requests_native_auth/oauth.py | 31 +++--------------- 3 files changed, 22 insertions(+), 45 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index bc609e42e..15b5cfbce 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime: def _has_access_token_been_initialized(self) -> bool: return self._access_token is not None - def set_token_expiry_date(self, value: Union[str, int]) -> None: - self._token_expiry_date = self._parse_token_expiration_date(value) + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: + self._token_expiry_date = value def get_assertion_name(self) -> str: return self.assertion_name diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index e22b7e83a..67023f6cf 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: """ Returns the refresh token and its expiration datetime @@ -148,14 +148,13 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: # PRIVATE METHODS # ---------------- - def _default_token_expiry_date(self) -> str: + def _default_token_expiry_date(self) -> AirbyteDateTime: """ Returns the default token expiry date """ - if self.token_expiry_is_time_of_expiration: - return str(ab_datetime_now() + timedelta(hours=1)) - else: - return "3600" + # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration + default_token_expiry_duration_hours = 1 # 1 hour + return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException @@ -266,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: """ - Return the expiration datetime of the refresh token + Parse a string or integer token expiration date into a datetime object :return: expiration datetime """ - if not value and not self.token_has_expired(): - # No expiry token was provided but the previous one is not expired so it's fine - return self.get_token_expiry_date() - if self.token_expiry_is_time_of_expiration: if not self.token_expiry_date_format: raise ValueError( @@ -317,9 +312,11 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: """ return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) - def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: + def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: """ Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. + + If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. Args: response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. @@ -331,11 +328,14 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: response_data, self.get_expires_in_name() ) # If the access token expires in is None, we do not know when the token will expire - # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration if expires_in is None: - return self._default_token_expiry_date() + # If the token expiry date is set and the token is not expired, continue using the existing token expiry date + if self.get_token_expiry_date() and not self.token_has_expired(): + return self.get_token_expiry_date() + else: + return self._default_token_expiry_date() else: - return expires_in + return self._parse_token_expiration_date(expires_in) def _find_and_get_value_from_response( self, @@ -458,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime: """Expiration date of the access token""" @abstractmethod - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: """Setter for access token expiration date""" @abstractmethod diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 2ff2f60e9..0ca6f6b3a 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -120,8 +120,8 @@ def get_grant_type(self) -> str: def get_token_expiry_date(self) -> AirbyteDateTime: return self._token_expiry_date - def set_token_expiry_date(self, value: Union[str, int]) -> None: - self._token_expiry_date = self._parse_token_expiration_date(value) + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: + self._token_expiry_date = value @property def token_expiry_is_time_of_expiration(self) -> bool: @@ -316,26 +316,6 @@ def token_has_expired(self) -> bool: """Returns True if the token is expired""" return ab_datetime_now() > self.get_token_expiry_date() - @staticmethod - def get_new_token_expiry_date( - access_token_expires_in: str, - token_expiry_date_format: str | None = None, - ) -> AirbyteDateTime: - """ - Calculate the new token expiry date based on the provided expiration duration or format. - - Args: - access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format. - token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None. - - Returns: - AirbyteDateTime: The calculated expiry date of the access token. - """ - if token_expiry_date_format: - return ab_datetime_parse(access_token_expires_in) - else: - return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in)) - def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. The new refresh token is persisted with the set_refresh_token function @@ -346,16 +326,13 @@ def get_access_token(self) -> str: new_access_token, access_token_expires_in, new_refresh_token = ( self.refresh_access_token() ) - new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date( - access_token_expires_in, self._token_expiry_date_format - ) self.access_token = new_access_token self.set_refresh_token(new_refresh_token) - self.set_token_expiry_date(new_token_expiry_date) + self.set_token_expiry_date(access_token_expires_in) self._emit_control_message() return self.access_token - def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] + def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override] """ Refreshes the access token by making a handled request and extracting the necessary token information. From 6290c358e87400c89b14d31bd366d624cd9dfc0e Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Mon, 14 Apr 2025 21:09:16 -0700 Subject: [PATCH 09/13] begin updating tests --- .../requests_native_auth/abstract_oauth.py | 2 +- .../sources/declarative/auth/test_oauth.py | 63 ++++++++- .../test_requests_native_auth.py | 128 ++++++++++++------ 3 files changed, 142 insertions(+), 51 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 67023f6cf..651fd1d10 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -361,7 +361,7 @@ def _find_and_get_value_from_response( """ if current_depth > max_depth: # this is needed to avoid an inf loop, possible with a very deep nesting observed. - message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response." + message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response." raise ResponseKeysMaxRecurtionReached( internal_message=message, message=message, failure_type=FailureType.config_error ) diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index c54b9982f..f8cc77c85 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self): grant_type="refresh_token", ) + @freezegun.freeze_time("2022-01-01") def test_refresh_access_token(self, mocker): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker): resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - token = oauth.refresh_access_token() + access_token, token_expiry_date = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert access_token == "access_token" + assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) filtered = filter_secrets("access_token") assert filtered == "****" + @freezegun.freeze_time("2022-01-01") def test_refresh_access_token_when_headers_provided(self, mocker): expected_headers = { "Authorization": "Bearer some_access_token", @@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker): mocked_request = mocker.patch.object( requests, "request", side_effect=mock_request, autospec=True ) - token = oauth.refresh_access_token() + access_token, token_expiry_date = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert access_token == "access_token" + assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) assert mocked_request.call_args.kwargs["headers"] == expected_headers @@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp( assert isinstance(oauth._token_expiry_date, AirbyteDateTime) assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date) + @freezegun.freeze_time("2022-01-01") def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token( self, ) -> None: @@ -335,12 +340,55 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_ url="https://refresh_endpoint.com/", body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", ), - HttpResponse(body=json.dumps({"access_token": "new_access_token"})), + HttpResponse(body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})), ) oauth.get_access_token() assert oauth.access_token == "new_access_token" - assert oauth._token_expiry_date == expiry_date + assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) + + @freezegun.freeze_time("2022-01-01") + @pytest.mark.parametrize( + "initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token", + [ + (timedelta(days=1), timedelta(days=1), "some_access_token"), + (timedelta(days=-1), timedelta(hours=1), "new_access_token"), + (None, timedelta(hours=1), "new_access_token"), + ], + ids=["initial_expiry_date_in_future", "initial_expiry_date_in_past", "no_initial_expiry_date"], + ) + def test_no_expiry_date_provided_by_auth_server( + self, + initial_expiry_date_delta, + expected_new_expiry_date_delta, + expected_access_token, + ) -> None: + initial_expiry_date = ab_datetime_now().add(initial_expiry_date_delta).isoformat() if initial_expiry_date_delta else None + expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta) + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com/", + client_id="some_client_id", + client_secret="some_client_secret", + token_expiry_date=initial_expiry_date, + access_token_value="some_access_token", + refresh_token="some_refresh_token", + config={}, + parameters={}, + grant_type="client", + ) + + with HttpMocker() as http_mocker: + http_mocker.post( + HttpRequest( + url="https://refresh_endpoint.com/", + body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", + ), + HttpResponse(body=json.dumps({"access_token": "new_access_token"})), + ) + oauth.get_access_token() + + assert oauth.access_token == expected_access_token + assert oauth._token_expiry_date == expected_new_expiry_date @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format", @@ -443,6 +491,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next assert "access_token" == token assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day) + @freezegun.freeze_time("2022-01-01") def test_profile_assertion(self, mocker): with HttpMocker() as http_mocker: jwt = JwtAuthenticator( @@ -477,7 +526,7 @@ def test_profile_assertion(self, mocker): token = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token filtered = filter_secrets("access_token") assert filtered == "****" diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index d264f8a19..dacb48f4b 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -83,6 +83,7 @@ def test_multiple_token_authenticator(): assert {"Authorization": "Bearer token1"} == header3 +@freezegun.freeze_time("2022-01-01") class TestOauth2Authenticator: """ Test class for OAuth2Authenticator. @@ -104,8 +105,9 @@ def test_get_auth_header_fresh(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) + expires_in = ab_datetime_now().add(timedelta(seconds=1000)) mocker.patch.object( - Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", expires_in) ) header = oauth.get_auth_header() assert {"Authorization": "Bearer access_token"} == header @@ -121,15 +123,15 @@ def test_get_auth_header_expired(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) - expire_immediately = 0 + already_expired = ab_datetime_now() - timedelta(seconds=100) mocker.patch.object( Oauth2Authenticator, "refresh_access_token", - return_value=("access_token_1", expire_immediately), + return_value=("access_token_1", already_expired), ) oauth.get_auth_header() # Set the first expired token. - valid_100_secs = 100 + valid_100_secs = ab_datetime_now() + timedelta(seconds=100) mocker.patch.object( Oauth2Authenticator, "refresh_access_token", @@ -236,7 +238,6 @@ def test_refresh_request_body_with_keys_override(self): } assert body == expected - @freezegun.freeze_time("2022-01-01") def test_refresh_access_token(self, mocker): oauth = Oauth2Authenticator( token_refresh_endpoint="https://refresh_endpoint.com", @@ -251,6 +252,21 @@ def test_refresh_access_token(self, mocker): "scopes": ["no_override"], }, ) + + oauth_with_expired_token= Oauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=ab_datetime_now() - timedelta(days=3), + refresh_request_body={ + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + "scopes": ["no_override"], + }, + ) + resp.status_code = 200 mocker.patch.object( @@ -259,8 +275,9 @@ def test_refresh_access_token(self, mocker): mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert ("access_token", 1000) == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=1000)) + assert token == "access_token" # Test with expires_in as str(int) mocker.patch.object( @@ -268,28 +285,40 @@ def test_refresh_access_token(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token", "2000") == (token, expires_in) - + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2000)) + assert token == "access_token" + # Test with expires_in as datetime(str) mocker.patch.object( resp, "json", return_value={"access_token": "access_token", "expires_in": "2022-04-24T00:00:00Z"}, ) - token, expires_in = oauth.refresh_access_token() - - assert isinstance(expires_in, str) - assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) - - # Test with no expires_in + # This should raise a ValueError because the token_expiry_is_time_of_expiration is False by default + with pytest.raises(ValueError): + token, expires_in = oauth.refresh_access_token() + + # Test with no expires_in mocker.patch.object( resp, "json", return_value={"access_token": "access_token"}, ) + + # Since the initialized token is not expired (now + 3 days), we don't expect the expiration date to be updated token, expires_in = oauth.refresh_access_token() - assert expires_in == "3600" + + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(days=3)) + assert token == "access_token" + + # Since the initialized token is expired (now - 3 days), we expect the expiration date to be updated to the default value (now + 1 hour) + token, expires_in = oauth_with_expired_token.refresh_access_token() + + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(hours=1)) + assert token == "access_token" # Test with nested access_token and expires_in as str(int) mocker.patch.object( @@ -299,8 +328,9 @@ def test_refresh_access_token(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token_nested", "2001") == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2001)) + assert token == "access_token_nested" # Test with multiple nested levels access_token and expires_in as str(int) mocker.patch.object( @@ -326,9 +356,10 @@ def test_refresh_access_token(self, mocker): }, ) token, expires_in = oauth.refresh_access_token() - - assert isinstance(expires_in, str) - assert ("access_token_deeply_nested", "2002") == (token, expires_in) + + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2002)) + assert token == "access_token_deeply_nested" # Test with max nested levels access_token and expires_in as str(int) mocker.patch.object( @@ -358,7 +389,7 @@ def test_refresh_access_token(self, mocker): ) with pytest.raises(ResponseKeysMaxRecurtionReached) as exc_info: oauth.refresh_access_token() - error_message = "The maximum level of recursion is reached. Couldn't find the speficied `access_token` in the response." + error_message = "The maximum level of recursion is reached. Couldn't find the specified `access_token` in the response." assert exc_info.value.internal_message == error_message assert exc_info.value.message == error_message assert exc_info.value.failure_type == FailureType.config_error @@ -387,8 +418,9 @@ def test_refresh_access_token_when_headers_provided(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert ("access_token", 1000) == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=1000)) + assert token == "access_token" assert mocked_request.call_args.kwargs["headers"] == expected_headers @@ -415,7 +447,6 @@ def test_refresh_access_token_when_headers_provided(self, mocker): "default_behavior_with_format", ], ) - @freezegun.freeze_time("2022-01-01") def test_parse_refresh_token_lifespan( self, mocker, @@ -446,14 +477,11 @@ def test_parse_refresh_token_lifespan( return_value={"access_token": "access_token", "expires_in": expires_in_response}, ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - token, expire_in = oauth.refresh_access_token() - expires_datetime = oauth._parse_token_expiration_date(expire_in) + token, expires_datetime = oauth.refresh_access_token() assert isinstance(expires_datetime, AirbyteDateTime) - assert ("access_token", expected_token_expiry_date) == ( - token, - expires_datetime, - ) + assert expires_datetime == expected_token_expiry_date + assert token == "access_token" @pytest.mark.usefixtures("mock_sleep") @pytest.mark.parametrize("error_code", (429, 500, 502, 504)) @@ -473,8 +501,9 @@ def test_refresh_access_token_retry(self, error_code, requests_mock): ], ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert (token, expires_in) == ("token", 10) + assert isinstance(expires_in, AirbyteDateTime) + assert token == "token" + assert expires_in == ab_datetime_now().add(timedelta(seconds=10)) assert requests_mock.call_count == 3 def test_auth_call_method(self, mocker): @@ -485,8 +514,9 @@ def test_auth_call_method(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) + expires_in = ab_datetime_now().add(timedelta(seconds=1000)) mocker.patch.object( - Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", expires_in) ) prepared_request = requests.PreparedRequest() prepared_request.headers = {} @@ -549,7 +579,7 @@ def test_refresh_access_token_wrapped( assert exc_info.value.message == error_message assert exc_info.value.failure_type == FailureType.config_error - +@freezegun.freeze_time("2022-12-31") class TestSingleUseRefreshTokenOauth2Authenticator: @pytest.fixture def connector_config(self): @@ -570,7 +600,7 @@ def invalid_connector_config(self): def test_init(self, connector_config): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], ) @@ -580,7 +610,6 @@ def test_init(self, connector_config): connector_config["credentials"]["token_expiry_date"] ) - @freezegun.freeze_time("2022-12-31") @pytest.mark.parametrize( "test_name, expires_in_value, expiry_date_format, expected_expiry_date", [ @@ -601,14 +630,20 @@ def test_given_no_message_repository_get_access_token( ): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], token_expiry_date_format=expiry_date_format, + token_expiry_is_time_of_expiration=bool(expiry_date_format), ) - authenticator.refresh_access_token = mocker.Mock( - return_value=("new_access_token", expires_in_value, "new_refresh_token") + + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "new_access_token", "expires_in": expires_in_value, "refresh_token": "new_refresh_token"} ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + authenticator.token_has_expired = mocker.Mock(return_value=True) access_token = authenticator.get_access_token() captured = capsys.readouterr() @@ -639,9 +674,16 @@ def test_given_message_repository_when_get_access_token_then_emit_message( token_expiry_date_format="YYYY-MM-DD", message_repository=message_repository, ) - authenticator.refresh_access_token = mocker.Mock( - return_value=("new_access_token", "2023-04-04", "new_refresh_token") + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "new_access_token", "expires_in": "2023-04-04", "refresh_token": "new_refresh_token"} ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + + # authenticator.refresh_access_token = mocker.Mock( + # return_value=("new_access_token", "2023-04-04", "new_refresh_token") + # ) authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() From 393dd6fdb7c57e0ecfc914c2262b9b1fbd6f4782 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Tue, 15 Apr 2025 14:09:05 -0700 Subject: [PATCH 10/13] finish fixing tests --- .../test_requests_native_auth.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index dacb48f4b..a8ac568f4 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -668,9 +668,10 @@ def test_given_message_repository_when_get_access_token_then_emit_message( message_repository = Mock() authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], + token_expiry_is_time_of_expiration=True, token_expiry_date_format="YYYY-MM-DD", message_repository=message_repository, ) @@ -681,9 +682,6 @@ def test_given_message_repository_when_get_access_token_then_emit_message( ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - # authenticator.refresh_access_token = mocker.Mock( - # return_value=("new_access_token", "2023-04-04", "new_refresh_token") - # ) authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -744,21 +742,26 @@ def test_given_message_repository_when_get_access_token_then_log_request( def test_refresh_access_token(self, mocker, connector_config): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], ) - - authenticator._make_handled_request = mocker.Mock( + + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={ authenticator.get_access_token_name(): "new_access_token", authenticator.get_expires_in_name(): "42", authenticator.get_refresh_token_name(): "new_refresh_token", } ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + assert authenticator.refresh_access_token() == ( "new_access_token", - "42", + ab_datetime_now().add(timedelta(seconds=42)), "new_refresh_token", ) From c7d7b56e4a1f5e4033b8378dff2bc6d26497ded0 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Tue, 15 Apr 2025 14:14:06 -0700 Subject: [PATCH 11/13] ruff format --- .../requests_native_auth/abstract_oauth.py | 4 +- .../sources/declarative/auth/test_oauth.py | 18 +++++-- .../test_requests_native_auth.py | 51 ++++++++++++------- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 651fd1d10..f2a698e1d 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -153,7 +153,7 @@ def _default_token_expiry_date(self) -> AirbyteDateTime: Returns the default token expiry date """ # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration - default_token_expiry_duration_hours = 1 # 1 hour + default_token_expiry_duration_hours = 1 # 1 hour return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) def _wrap_refresh_token_exception( @@ -315,7 +315,7 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: """ Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. - + If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. Args: diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index f8cc77c85..077aa4573 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -340,13 +340,15 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_ url="https://refresh_endpoint.com/", body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", ), - HttpResponse(body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})), + HttpResponse( + body=json.dumps({"access_token": "new_access_token", "expires_in": 1000}) + ), ) oauth.get_access_token() assert oauth.access_token == "new_access_token" assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) - + @freezegun.freeze_time("2022-01-01") @pytest.mark.parametrize( "initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token", @@ -355,7 +357,11 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_ (timedelta(days=-1), timedelta(hours=1), "new_access_token"), (None, timedelta(hours=1), "new_access_token"), ], - ids=["initial_expiry_date_in_future", "initial_expiry_date_in_past", "no_initial_expiry_date"], + ids=[ + "initial_expiry_date_in_future", + "initial_expiry_date_in_past", + "no_initial_expiry_date", + ], ) def test_no_expiry_date_provided_by_auth_server( self, @@ -363,7 +369,11 @@ def test_no_expiry_date_provided_by_auth_server( expected_new_expiry_date_delta, expected_access_token, ) -> None: - initial_expiry_date = ab_datetime_now().add(initial_expiry_date_delta).isoformat() if initial_expiry_date_delta else None + initial_expiry_date = ( + ab_datetime_now().add(initial_expiry_date_delta).isoformat() + if initial_expiry_date_delta + else None + ) expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta) oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="https://refresh_endpoint.com/", diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index a8ac568f4..dbfc0ac86 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -252,8 +252,8 @@ def test_refresh_access_token(self, mocker): "scopes": ["no_override"], }, ) - - oauth_with_expired_token= Oauth2Authenticator( + + oauth_with_expired_token = Oauth2Authenticator( token_refresh_endpoint="https://refresh_endpoint.com", client_id="some_client_id", client_secret="some_client_secret", @@ -266,7 +266,6 @@ def test_refresh_access_token(self, mocker): "scopes": ["no_override"], }, ) - resp.status_code = 200 mocker.patch.object( @@ -288,7 +287,7 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, AirbyteDateTime) assert expires_in == ab_datetime_now().add(timedelta(seconds=2000)) assert token == "access_token" - + # Test with expires_in as datetime(str) mocker.patch.object( resp, @@ -298,24 +297,24 @@ def test_refresh_access_token(self, mocker): # This should raise a ValueError because the token_expiry_is_time_of_expiration is False by default with pytest.raises(ValueError): token, expires_in = oauth.refresh_access_token() - - # Test with no expires_in + + # Test with no expires_in mocker.patch.object( resp, "json", return_value={"access_token": "access_token"}, ) - + # Since the initialized token is not expired (now + 3 days), we don't expect the expiration date to be updated token, expires_in = oauth.refresh_access_token() - + assert isinstance(expires_in, AirbyteDateTime) assert expires_in == ab_datetime_now().add(timedelta(days=3)) assert token == "access_token" - + # Since the initialized token is expired (now - 3 days), we expect the expiration date to be updated to the default value (now + 1 hour) token, expires_in = oauth_with_expired_token.refresh_access_token() - + assert isinstance(expires_in, AirbyteDateTime) assert expires_in == ab_datetime_now().add(timedelta(hours=1)) assert token == "access_token" @@ -356,7 +355,7 @@ def test_refresh_access_token(self, mocker): }, ) token, expires_in = oauth.refresh_access_token() - + assert isinstance(expires_in, AirbyteDateTime) assert expires_in == ab_datetime_now().add(timedelta(seconds=2002)) assert token == "access_token_deeply_nested" @@ -579,6 +578,7 @@ def test_refresh_access_token_wrapped( assert exc_info.value.message == error_message assert exc_info.value.failure_type == FailureType.config_error + @freezegun.freeze_time("2022-12-31") class TestSingleUseRefreshTokenOauth2Authenticator: @pytest.fixture @@ -636,14 +636,20 @@ def test_given_no_message_repository_get_access_token( token_expiry_date_format=expiry_date_format, token_expiry_is_time_of_expiration=bool(expiry_date_format), ) - + # Mock the response from the refresh token endpoint resp.status_code = 200 mocker.patch.object( - resp, "json", return_value={"access_token": "new_access_token", "expires_in": expires_in_value, "refresh_token": "new_refresh_token"} + resp, + "json", + return_value={ + authenticator.get_access_token_name(): "new_access_token", + authenticator.get_expires_in_name(): expires_in_value, + authenticator.get_refresh_token_name(): "new_refresh_token", + }, ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - + authenticator.token_has_expired = mocker.Mock(return_value=True) access_token = authenticator.get_access_token() captured = capsys.readouterr() @@ -678,10 +684,16 @@ def test_given_message_repository_when_get_access_token_then_emit_message( # Mock the response from the refresh token endpoint resp.status_code = 200 mocker.patch.object( - resp, "json", return_value={"access_token": "new_access_token", "expires_in": "2023-04-04", "refresh_token": "new_refresh_token"} + resp, + "json", + return_value={ + authenticator.get_access_token_name(): "new_access_token", + authenticator.get_expires_in_name(): "2023-04-04", + authenticator.get_refresh_token_name(): "new_refresh_token", + }, ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - + authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -746,16 +758,17 @@ def test_refresh_access_token(self, mocker, connector_config): client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], ) - + # Mock the response from the refresh token endpoint resp.status_code = 200 mocker.patch.object( - resp, "json", + resp, + "json", return_value={ authenticator.get_access_token_name(): "new_access_token", authenticator.get_expires_in_name(): "42", authenticator.get_refresh_token_name(): "new_refresh_token", - } + }, ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) From 0dc56621068c4256121323f8a96923f905006e93 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:03:55 -0700 Subject: [PATCH 12/13] simplify nested if statements --- .../http/requests_native_auth/abstract_oauth.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index f2a698e1d..1e85c3b5f 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -327,15 +327,15 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Airbyt expires_in = self._find_and_get_value_from_response( response_data, self.get_expires_in_name() ) - # If the access token expires in is None, we do not know when the token will expire - if expires_in is None: - # If the token expiry date is set and the token is not expired, continue using the existing token expiry date - if self.get_token_expiry_date() and not self.token_has_expired(): - return self.get_token_expiry_date() - else: - return self._default_token_expiry_date() - else: + if expires_in is not None: return self._parse_token_expiration_date(expires_in) + + # expires_in is None + existing_expiry_date = self.get_token_expiry_date() + if existing_expiry_date and not self.token_has_expired(): + return existing_expiry_date + + return self._default_token_expiry_date() def _find_and_get_value_from_response( self, From 42807bf3c616a05371faf69e1ab7874b9ec7ec23 Mon Sep 17 00:00:00 2001 From: David Gold <32782137+dbgold17@users.noreply.github.com> Date: Wed, 16 Apr 2025 16:03:04 -0700 Subject: [PATCH 13/13] ruffing --- .../streams/http/requests_native_auth/abstract_oauth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 1e85c3b5f..108055f1d 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -329,12 +329,12 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Airbyt ) if expires_in is not None: return self._parse_token_expiration_date(expires_in) - + # expires_in is None existing_expiry_date = self.get_token_expiry_date() if existing_expiry_date and not self.token_has_expired(): return existing_expiry_date - + return self._default_token_expiry_date() def _find_and_get_value_from_response(