Skip to content

Commit 3b31489

Browse files
authored
fix: add token expiry buffer to prevent expired token usage and lock concurrent refreshes (#283)
* fix: add token expiry buffer to prevent expired token usage * feat: add lock to prevent concurrent token refreshes * fix: ruff check * fix: address comment * fix: use atomic token state to eliminate torn-read race condition Replace three separate mutable fields with a single frozen dataclass assigned atomically, ensuring concurrent threads always see a complete token state snapshot. * fix: ruff? * feat: restructure and extract common class and func
1 parent f07861e commit 3b31489

5 files changed

Lines changed: 304 additions & 97 deletions

File tree

openfga_sdk/oauth2.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,29 @@
11
import asyncio
22
import json
3-
import math
43
import random
5-
import sys
64

75
from datetime import datetime, timedelta
86

97
import urllib3
108

119
from openfga_sdk.configuration import Configuration
12-
from openfga_sdk.constants import USER_AGENT
10+
from openfga_sdk.constants import (
11+
TOKEN_EXPIRY_JITTER_IN_SEC,
12+
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
13+
USER_AGENT,
14+
)
1315
from openfga_sdk.credentials import Credentials
1416
from openfga_sdk.exceptions import AuthenticationError
17+
from openfga_sdk.oauth2_common import _TokenState, jitter
1518
from openfga_sdk.telemetry.attributes import TelemetryAttributes
1619
from openfga_sdk.telemetry.telemetry import Telemetry
1720

1821

19-
def jitter(loop_count, min_wait_in_ms):
20-
"""
21-
Generate a random jitter value for exponential backoff
22-
"""
23-
minimum = math.ceil(2**loop_count * min_wait_in_ms)
24-
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
25-
jitter = random.randrange(minimum, maximum) / 1000
26-
27-
# If running in pytest, set jitter to 0 to speed up tests
28-
if "pytest" in sys.modules:
29-
jitter = 0
30-
31-
return jitter
32-
33-
3422
class OAuth2Client:
3523
def __init__(self, credentials: Credentials, configuration=None):
3624
self._credentials = credentials
37-
self._access_token = None
38-
self._access_expiry_time = None
25+
self._token_state: _TokenState | None = None
26+
self._lock = asyncio.Lock()
3927
self._telemetry = Telemetry()
4028

4129
if configuration is None:
@@ -45,13 +33,13 @@ def __init__(self, credentials: Credentials, configuration=None):
4533

4634
def _token_valid(self):
4735
"""
48-
Return whether token is valid
36+
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
4937
"""
50-
if self._access_token is None or self._access_expiry_time is None:
51-
return False
52-
if self._access_expiry_time < datetime.now():
38+
state = self._token_state # atomic snapshot — either old or new, never torn
39+
if state is None:
5340
return False
54-
return True
41+
remaining = (state.expiry_time - datetime.now()).total_seconds()
42+
return remaining > state.expiry_buffer
5543

5644
async def _obtain_token(self, client):
5745
"""
@@ -76,7 +64,9 @@ async def _obtain_token(self, client):
7664
# Add scope parameter if scopes are configured
7765
if configuration.scopes is not None:
7866
if isinstance(configuration.scopes, list):
79-
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
67+
scope_str = " ".join(
68+
s.strip() for s in configuration.scopes if s and s.strip()
69+
)
8070
else:
8171
scope_str = (
8272
configuration.scopes.strip()
@@ -136,10 +126,15 @@ async def _obtain_token(self, client):
136126
raise AuthenticationError(http_resp=raw_response)
137127

138128
if api_response.get("expires_in") and api_response.get("access_token"):
139-
self._access_expiry_time = datetime.now() + timedelta(
140-
seconds=int(api_response.get("expires_in"))
129+
self._token_state = _TokenState(
130+
access_token=api_response.get("access_token"),
131+
expiry_time=datetime.now()
132+
+ timedelta(seconds=int(api_response.get("expires_in"))),
133+
expiry_buffer=(
134+
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
135+
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
136+
),
141137
)
142-
self._access_token = api_response.get("access_token")
143138
self._telemetry.metrics.credentialsRequest(
144139
attributes={
145140
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
@@ -154,8 +149,8 @@ async def get_authentication_header(self, client):
154149
"""
155150
If configured, return the header for authentication
156151
"""
157-
# check to see token is valid
158152
if not self._token_valid():
159-
# In this case, the token is not valid, we need to get the refresh the token
160-
await self._obtain_token(client)
161-
return {"Authorization": f"Bearer {self._access_token}"}
153+
async with self._lock:
154+
if not self._token_valid():
155+
await self._obtain_token(client)
156+
return {"Authorization": f"Bearer {self._token_state.access_token}"}

openfga_sdk/oauth2_common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import math
2+
import random
3+
import sys
4+
5+
from dataclasses import dataclass
6+
from datetime import datetime
7+
8+
9+
@dataclass(frozen=True)
10+
class _TokenState:
11+
access_token: str
12+
expiry_time: datetime
13+
expiry_buffer: float
14+
15+
16+
def jitter(loop_count, min_wait_in_ms):
17+
"""
18+
Generate a random jitter value for exponential backoff
19+
"""
20+
minimum = math.ceil(2**loop_count * min_wait_in_ms)
21+
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
22+
jitter = random.randrange(minimum, maximum) / 1000
23+
24+
# If running in pytest, set jitter to 0 to speed up tests
25+
if "pytest" in sys.modules:
26+
jitter = 0
27+
28+
return jitter

openfga_sdk/sync/oauth2.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,30 @@
11
import json
2-
import math
32
import random
4-
import sys
3+
import threading
54
import time
65

76
from datetime import datetime, timedelta
87

98
import urllib3
109

1110
from openfga_sdk.configuration import Configuration
12-
from openfga_sdk.constants import USER_AGENT
11+
from openfga_sdk.constants import (
12+
TOKEN_EXPIRY_JITTER_IN_SEC,
13+
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
14+
USER_AGENT,
15+
)
1316
from openfga_sdk.credentials import Credentials
1417
from openfga_sdk.exceptions import AuthenticationError
18+
from openfga_sdk.oauth2_common import _TokenState, jitter
1519
from openfga_sdk.telemetry.attributes import TelemetryAttributes
1620
from openfga_sdk.telemetry.telemetry import Telemetry
1721

1822

19-
def jitter(loop_count, min_wait_in_ms):
20-
"""
21-
Generate a random jitter value for exponential backoff
22-
"""
23-
minimum = math.ceil(2**loop_count * min_wait_in_ms)
24-
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
25-
jitter = random.randrange(minimum, maximum) / 1000
26-
27-
# If running in pytest, set jitter to 0 to speed up tests
28-
if "pytest" in sys.modules:
29-
jitter = 0
30-
31-
return jitter
32-
33-
3423
class OAuth2Client:
3524
def __init__(self, credentials: Credentials, configuration=None):
3625
self._credentials = credentials
37-
self._access_token = None
38-
self._access_expiry_time = None
26+
self._token_state: _TokenState | None = None
27+
self._lock = threading.Lock()
3928
self._telemetry = Telemetry()
4029

4130
if configuration is None:
@@ -45,13 +34,13 @@ def __init__(self, credentials: Credentials, configuration=None):
4534

4635
def _token_valid(self):
4736
"""
48-
Return whether token is valid
37+
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
4938
"""
50-
if self._access_token is None or self._access_expiry_time is None:
51-
return False
52-
if self._access_expiry_time < datetime.now():
39+
state = self._token_state # atomic snapshot — either old or new, never torn
40+
if state is None:
5341
return False
54-
return True
42+
remaining = (state.expiry_time - datetime.now()).total_seconds()
43+
return remaining > state.expiry_buffer
5544

5645
def _obtain_token(self, client):
5746
"""
@@ -76,7 +65,9 @@ def _obtain_token(self, client):
7665
# Add scope parameter if scopes are configured
7766
if configuration.scopes is not None:
7867
if isinstance(configuration.scopes, list):
79-
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
68+
scope_str = " ".join(
69+
s.strip() for s in configuration.scopes if s and s.strip()
70+
)
8071
else:
8172
scope_str = (
8273
configuration.scopes.strip()
@@ -136,10 +127,15 @@ def _obtain_token(self, client):
136127
raise AuthenticationError(http_resp=raw_response)
137128

138129
if api_response.get("expires_in") and api_response.get("access_token"):
139-
self._access_expiry_time = datetime.now() + timedelta(
140-
seconds=int(api_response.get("expires_in"))
130+
self._token_state = _TokenState(
131+
access_token=api_response.get("access_token"),
132+
expiry_time=datetime.now()
133+
+ timedelta(seconds=int(api_response.get("expires_in"))),
134+
expiry_buffer=(
135+
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
136+
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
137+
),
141138
)
142-
self._access_token = api_response.get("access_token")
143139
self._telemetry.metrics.credentialsRequest(
144140
attributes={
145141
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
@@ -154,8 +150,8 @@ def get_authentication_header(self, client):
154150
"""
155151
If configured, return the header for authentication
156152
"""
157-
# check to see token is valid
158153
if not self._token_valid():
159-
# In this case, the token is not valid, we need to get the refresh the token
160-
self._obtain_token(client)
161-
return {"Authorization": f"Bearer {self._access_token}"}
154+
with self._lock:
155+
if not self._token_valid():
156+
self._obtain_token(client)
157+
return {"Authorization": f"Bearer {self._token_state.access_token}"}

0 commit comments

Comments
 (0)