Skip to content

Commit 15542de

Browse files
fix(oauth): add thread-safe token refresh with double-checked locking (#883)
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: gl_anatolii.yatsuk@airbyte.io <gl_anatolii.yatsuk@airbyte.io>
1 parent d3c9419 commit 15542de

File tree

3 files changed

+160
-3
lines changed

3 files changed

+160
-3
lines changed

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44

55
import logging
6+
import threading
67
from abc import abstractmethod
78
from datetime import timedelta
89
from json import JSONDecodeError
@@ -41,6 +42,13 @@ class AbstractOauth2Authenticator(AuthBase):
4142

4243
_NO_STREAM_NAME = None
4344

45+
# Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
46+
# This is necessary because multiple streams may share the same OAuth credentials (refresh token)
47+
# through the connector config. Without this lock, concurrent refresh attempts can cause race
48+
# conditions where one stream successfully refreshes the token while others fail because the
49+
# refresh token has been invalidated (especially for single-use refresh tokens).
50+
_token_refresh_lock: threading.Lock = threading.Lock()
51+
4452
def __init__(
4553
self,
4654
refresh_token_error_status_codes: Tuple[int, ...] = (),
@@ -86,9 +94,19 @@ def get_auth_header(self) -> Mapping[str, Any]:
8694
return {"Authorization": f"Bearer {token}"}
8795

8896
def get_access_token(self) -> str:
89-
"""Returns the access token"""
97+
"""
98+
Returns the access token.
99+
100+
This method uses double-checked locking to ensure thread-safe token refresh.
101+
When multiple threads (streams) detect an expired token simultaneously, only one
102+
will perform the refresh while others wait. After acquiring the lock, the token
103+
expiry is re-checked to avoid redundant refresh attempts.
104+
"""
90105
if self.token_has_expired():
91-
self.refresh_and_set_access_token()
106+
with self._token_refresh_lock:
107+
# Double-check after acquiring lock - another thread may have already refreshed
108+
if self.token_has_expired():
109+
self.refresh_and_set_access_token()
92110

93111
return self.access_token
94112

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,13 +319,21 @@ def token_has_expired(self) -> bool:
319319
def get_access_token(self) -> str:
320320
"""Retrieve new access and refresh token if the access token has expired.
321321
322+
This method uses double-checked locking to ensure thread-safe token refresh.
323+
This is especially critical for single-use refresh tokens where concurrent
324+
refresh attempts would cause failures as the refresh token is invalidated
325+
after first use.
326+
322327
The new refresh token is persisted with the set_refresh_token function.
323328
324329
Returns:
325330
str: The current access_token, updated if it was previously expired.
326331
"""
327332
if self.token_has_expired():
328-
self.refresh_and_set_access_token()
333+
with self._token_refresh_lock:
334+
# Double-check after acquiring lock - another thread may have already refreshed
335+
if self.token_has_expired():
336+
self.refresh_and_set_access_token()
329337
return self.access_token
330338

331339
def refresh_and_set_access_token(self) -> None:

unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import logging
7+
import threading
8+
import time
79
from datetime import timedelta
810
from typing import Optional, Union
911
from unittest.mock import Mock
@@ -785,3 +787,132 @@ def mock_request(method, url, data, headers):
785787
raise Exception(
786788
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
787789
)
790+
791+
792+
class TestConcurrentTokenRefresh:
793+
"""
794+
Test class for verifying thread-safe token refresh behavior.
795+
796+
These tests ensure that when multiple threads (streams) attempt to refresh
797+
an expired token simultaneously, only one refresh actually occurs and
798+
others wait and use the refreshed token.
799+
"""
800+
801+
def test_concurrent_token_refresh_only_refreshes_once(self, mocker):
802+
"""
803+
When multiple threads detect an expired token and try to refresh simultaneously,
804+
only one thread should actually perform the refresh. Others should wait and
805+
use the newly refreshed token.
806+
"""
807+
refresh_call_count = 0
808+
refresh_call_lock = threading.Lock()
809+
810+
def mock_refresh_access_token(self):
811+
nonlocal refresh_call_count
812+
with refresh_call_lock:
813+
refresh_call_count += 1
814+
time.sleep(0.1)
815+
return ("new_access_token", ab_datetime_now() + timedelta(hours=1))
816+
817+
mocker.patch.object(
818+
Oauth2Authenticator,
819+
"refresh_access_token",
820+
mock_refresh_access_token,
821+
)
822+
823+
oauth = Oauth2Authenticator(
824+
token_refresh_endpoint="https://refresh_endpoint.com",
825+
client_id="client_id",
826+
client_secret="client_secret",
827+
refresh_token="refresh_token",
828+
token_expiry_date=ab_datetime_now() - timedelta(hours=1),
829+
)
830+
831+
results = []
832+
errors = []
833+
834+
def get_token():
835+
try:
836+
token = oauth.get_access_token()
837+
results.append(token)
838+
except Exception as e:
839+
errors.append(e)
840+
841+
threads = [threading.Thread(target=get_token) for _ in range(5)]
842+
for t in threads:
843+
t.start()
844+
for t in threads:
845+
t.join()
846+
847+
assert len(errors) == 0, f"Unexpected errors: {errors}"
848+
assert len(results) == 5
849+
assert all(token == "new_access_token" for token in results)
850+
assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}"
851+
852+
def test_single_use_refresh_token_concurrent_refresh_only_refreshes_once(self, mocker):
853+
"""
854+
For SingleUseRefreshTokenOauth2Authenticator, concurrent refresh attempts
855+
should also only result in one actual refresh to prevent invalidating
856+
the single-use refresh token.
857+
"""
858+
refresh_call_count = 0
859+
refresh_call_lock = threading.Lock()
860+
861+
connector_config = {
862+
"credentials": {
863+
"client_id": "client_id",
864+
"client_secret": "client_secret",
865+
"refresh_token": "refresh_token",
866+
"access_token": "old_access_token",
867+
"token_expiry_date": str(ab_datetime_now() - timedelta(hours=1)),
868+
}
869+
}
870+
871+
def mock_refresh_access_token(self):
872+
nonlocal refresh_call_count
873+
with refresh_call_lock:
874+
refresh_call_count += 1
875+
time.sleep(0.1)
876+
return (
877+
"new_access_token",
878+
ab_datetime_now() + timedelta(hours=1),
879+
"new_refresh_token",
880+
)
881+
882+
mocker.patch.object(
883+
SingleUseRefreshTokenOauth2Authenticator,
884+
"refresh_access_token",
885+
mock_refresh_access_token,
886+
)
887+
888+
mocker.patch.object(
889+
SingleUseRefreshTokenOauth2Authenticator,
890+
"_emit_control_message",
891+
lambda self: None,
892+
)
893+
894+
oauth = SingleUseRefreshTokenOauth2Authenticator(
895+
connector_config=connector_config,
896+
token_refresh_endpoint="https://refresh_endpoint.com",
897+
)
898+
899+
results = []
900+
errors = []
901+
902+
def get_token():
903+
try:
904+
token = oauth.get_access_token()
905+
results.append(token)
906+
except Exception as e:
907+
errors.append(e)
908+
909+
threads = [threading.Thread(target=get_token) for _ in range(5)]
910+
for t in threads:
911+
t.start()
912+
for t in threads:
913+
t.join()
914+
915+
assert len(errors) == 0, f"Unexpected errors: {errors}"
916+
assert len(results) == 5
917+
assert all(token == "new_access_token" for token in results)
918+
assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}"

0 commit comments

Comments
 (0)