Skip to content
2 changes: 1 addition & 1 deletion airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
def _has_access_token_been_initialized(self) -> bool:
return self._access_token is not None

def set_token_expiry_date(self, value: Union[str, int]) -> None:
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
Comment thread
dbgold17 marked this conversation as resolved.
self._token_expiry_date = self._parse_token_expiration_date(value)
Comment thread
dbgold17 marked this conversation as resolved.
Outdated

def get_assertion_name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
headers = self.get_refresh_request_headers()
return headers if headers else None

def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
def refresh_access_token(self) -> Tuple[str, Optional[str]]:
"""
Returns the refresh token and its expiration datetime

Expand All @@ -147,6 +147,15 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
# ----------------
# PRIVATE METHODS
# ----------------

def _default_token_expiry_date(self) -> str:
"""
Returns the default token expiry date
"""
if self.token_expiry_is_time_of_expiration:
return str(ab_datetime_now() + timedelta(hours=1))
else:
return "3600"

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
Expand Down Expand Up @@ -316,9 +325,15 @@ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

Returns:
str: The extracted token_expiry_date.
The extracted token_expiry_date or None if not found.
"""
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
expires_in = self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
# If the access token expires in is None, we do not know when the token will expire
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
if expires_in is None:
return self._default_token_expiry_date()
else:
return expires_in

def _find_and_get_value_from_response(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def get_access_token(self) -> str:
self._emit_control_message()
return self.access_token

def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
def refresh_access_token(self) -> Tuple[str, Optional[str], str]: # type: ignore[override]
"""
Refreshes the access token by making a handled request and extracting the necessary token information.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import json
import logging
from datetime import timedelta, timezone
Comment thread
dbgold17 marked this conversation as resolved.
from datetime import timedelta
from typing import Optional, Union
from unittest.mock import Mock
from unittest.mock import Mock, PropertyMock

import freezegun
import pytest
Expand Down Expand Up @@ -236,6 +236,7 @@ def test_refresh_request_body_with_keys_override(self):
}
assert body == expected

@freezegun.freeze_time("2022-01-01")
def test_refresh_access_token(self, mocker):
oauth = Oauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com",
Expand Down Expand Up @@ -281,6 +282,15 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, str)
assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in)

# Test with no expires_in
mocker.patch.object(
resp,
"json",
return_value={"access_token": "access_token"},
)
token, expires_in = oauth.refresh_access_token()
assert expires_in == "3600"

# Test with nested access_token and expires_in as str(int)
mocker.patch.object(
resp,
Expand Down Expand Up @@ -393,8 +403,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker):
"YYYY-MM-DDTHH:mm:ss.SSSSSSZ",
AirbyteDateTime(year=2022, month=2, day=12),
),
(None, None, AirbyteDateTime(year=2022, month=1, day=1, hour=1)),
(None, "YYYY-MM-DD", AirbyteDateTime(year=2022, month=1, day=1, hour=1)),
],
ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime"],
ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime", "default_behavior", "default_behavior_with_format"],
)
@freezegun.freeze_time("2022-01-01")
def test_parse_refresh_token_lifespan(
Expand Down
Loading