Skip to content

Commit dea01cc

Browse files
committed
feat: retry 5xx auth network requests
1 parent f337b94 commit dea01cc

File tree

9 files changed

+439
-47
lines changed

9 files changed

+439
-47
lines changed

openfga_sdk/api/open_fga_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT.
1111
"""
1212

13-
1413
from openfga_sdk.api_client import ApiClient
1514
from openfga_sdk.exceptions import ApiValueError, FgaValidationException
1615
from openfga_sdk.oauth2 import OAuth2Client

openfga_sdk/oauth2.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,47 @@
1010
NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT.
1111
"""
1212

13+
import asyncio
1314
import json
15+
import math
16+
import random
17+
import sys
1418
from datetime import datetime, timedelta
1519

1620
import urllib3
1721

22+
from openfga_sdk.configuration import Configuration
1823
from openfga_sdk.credentials import Credentials
1924
from openfga_sdk.exceptions import AuthenticationError
2025

2126

27+
def jitter(loop_count, min_wait_in_ms):
28+
"""
29+
Generate a random jitter value for exponential backoff
30+
"""
31+
minimum = math.ceil(2**loop_count * min_wait_in_ms)
32+
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
33+
jitter = random.randrange(minimum, maximum) / 1000
34+
35+
# If running in pytest, set jitter to 0 to speed up tests
36+
if "pytest" in sys.modules:
37+
jitter = 0
38+
39+
return jitter
40+
41+
2242
class OAuth2Client:
2343

24-
def __init__(self, credentials: Credentials):
44+
def __init__(self, credentials: Credentials, configuration=None):
2545
self._credentials = credentials
2646
self._access_token = None
2747
self._access_expiry_time = None
2848

49+
if configuration is None:
50+
configuration = Configuration.get_default_copy()
51+
52+
self.configuration = configuration
53+
2954
def _token_valid(self):
3055
"""
3156
Return whether token is valid
@@ -41,37 +66,65 @@ async def _obtain_token(self, client):
4166
Perform OAuth2 and obtain token
4267
"""
4368
configuration = self._credentials.configuration
69+
4470
token_url = f"https://{configuration.api_issuer}/oauth/token"
71+
4572
post_params = {
4673
"client_id": configuration.client_id,
4774
"client_secret": configuration.client_secret,
4875
"audience": configuration.api_audience,
4976
"grant_type": "client_credentials",
5077
}
78+
5179
headers = urllib3.response.HTTPHeaderDict(
5280
{
5381
"Accept": "application/json",
5482
"Content-Type": "application/x-www-form-urlencoded",
5583
"User-Agent": "openfga-sdk (python) 0.4.1",
5684
}
5785
)
58-
raw_response = await client.POST(
59-
token_url, headers=headers, post_params=post_params
86+
87+
max_retry = (
88+
self.configuration.retry_params.max_retry
89+
if (
90+
self.configuration.retry_params is not None
91+
and self.configuration.retry_params.max_retry is not None
92+
)
93+
else 0
6094
)
61-
if 200 <= raw_response.status <= 299:
62-
try:
63-
api_response = json.loads(raw_response.data)
64-
except:
65-
raise AuthenticationError(http_resp=raw_response)
66-
if not api_response.get("expires_in") or not api_response.get(
67-
"access_token"
68-
):
69-
raise AuthenticationError(http_resp=raw_response)
70-
self._access_expiry_time = datetime.now() + timedelta(
71-
seconds=int(api_response.get("expires_in"))
95+
96+
min_wait_in_ms = (
97+
self.configuration.retry_params.min_wait_in_ms
98+
if (
99+
self.configuration.retry_params is not None
100+
and self.configuration.retry_params.min_wait_in_ms is not None
101+
)
102+
else 0
103+
)
104+
105+
for attempt in range(max_retry + 1):
106+
raw_response = await client.POST(
107+
token_url, headers=headers, post_params=post_params
72108
)
73-
self._access_token = api_response.get("access_token")
74-
else:
109+
110+
if 500 <= raw_response.status <= 599 or raw_response.status == 429:
111+
if attempt < max_retry and raw_response.status != 501:
112+
await asyncio.sleep(jitter(attempt, min_wait_in_ms))
113+
continue
114+
115+
if 200 <= raw_response.status <= 299:
116+
try:
117+
api_response = json.loads(raw_response.data)
118+
except:
119+
raise AuthenticationError(http_resp=raw_response)
120+
121+
if api_response.get("expires_in") and api_response.get("access_token"):
122+
self._access_expiry_time = datetime.now() + timedelta(
123+
seconds=int(api_response.get("expires_in"))
124+
)
125+
self._access_token = api_response.get("access_token")
126+
break
127+
75128
raise AuthenticationError(http_resp=raw_response)
76129

77130
async def get_authentication_header(self, client):

openfga_sdk/sync/oauth2.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,46 @@
1111
"""
1212

1313
import json
14+
import math
15+
import random
16+
import sys
17+
import time
1418
from datetime import datetime, timedelta
1519

1620
import urllib3
1721

22+
from openfga_sdk.configuration import Configuration
1823
from openfga_sdk.credentials import Credentials
1924
from openfga_sdk.exceptions import AuthenticationError
2025

2126

27+
def jitter(loop_count, min_wait_in_ms):
28+
"""
29+
Generate a random jitter value for exponential backoff
30+
"""
31+
minimum = math.ceil(2**loop_count * min_wait_in_ms)
32+
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
33+
jitter = random.randrange(minimum, maximum) / 1000
34+
35+
# If running in pytest, set jitter to 0 to speed up tests
36+
if "pytest" in sys.modules:
37+
jitter = 0
38+
39+
return jitter
40+
41+
2242
class OAuth2Client:
2343

24-
def __init__(self, credentials: Credentials):
44+
def __init__(self, credentials: Credentials, configuration=None):
2545
self._credentials = credentials
2646
self._access_token = None
2747
self._access_expiry_time = None
2848

49+
if configuration is None:
50+
configuration = Configuration.get_default_copy()
51+
52+
self.configuration = configuration
53+
2954
def _token_valid(self):
3055
"""
3156
Return whether token is valid
@@ -41,35 +66,65 @@ def _obtain_token(self, client):
4166
Perform OAuth2 and obtain token
4267
"""
4368
configuration = self._credentials.configuration
69+
4470
token_url = f"https://{configuration.api_issuer}/oauth/token"
71+
4572
post_params = {
4673
"client_id": configuration.client_id,
4774
"client_secret": configuration.client_secret,
4875
"audience": configuration.api_audience,
4976
"grant_type": "client_credentials",
5077
}
78+
5179
headers = urllib3.response.HTTPHeaderDict(
5280
{
5381
"Accept": "application/json",
5482
"Content-Type": "application/x-www-form-urlencoded",
5583
"User-Agent": "openfga-sdk (python) 0.4.1",
5684
}
5785
)
58-
raw_response = client.POST(token_url, headers=headers, post_params=post_params)
59-
if 200 <= raw_response.status <= 299:
60-
try:
61-
api_response = json.loads(raw_response.data)
62-
except:
63-
raise AuthenticationError(http_resp=raw_response)
64-
if not api_response.get("expires_in") or not api_response.get(
65-
"access_token"
66-
):
67-
raise AuthenticationError(http_resp=raw_response)
68-
self._access_expiry_time = datetime.now() + timedelta(
69-
seconds=int(api_response.get("expires_in"))
86+
87+
max_retry = (
88+
self.configuration.retry_params.max_retry
89+
if (
90+
self.configuration.retry_params is not None
91+
and self.configuration.retry_params.max_retry is not None
92+
)
93+
else 0
94+
)
95+
96+
min_wait_in_ms = (
97+
self.configuration.retry_params.min_wait_in_ms
98+
if (
99+
self.configuration.retry_params is not None
100+
and self.configuration.retry_params.min_wait_in_ms is not None
101+
)
102+
else 0
103+
)
104+
105+
for attempt in range(max_retry + 1):
106+
raw_response = client.POST(
107+
token_url, headers=headers, post_params=post_params
70108
)
71-
self._access_token = api_response.get("access_token")
72-
else:
109+
110+
if 500 <= raw_response.status <= 599 or raw_response.status == 429:
111+
if attempt < max_retry and raw_response.status != 501:
112+
time.sleep(jitter(attempt, min_wait_in_ms))
113+
continue
114+
115+
if 200 <= raw_response.status <= 299:
116+
try:
117+
api_response = json.loads(raw_response.data)
118+
except:
119+
raise AuthenticationError(http_resp=raw_response)
120+
121+
if api_response.get("expires_in") and api_response.get("access_token"):
122+
self._access_expiry_time = datetime.now() + timedelta(
123+
seconds=int(api_response.get("expires_in"))
124+
)
125+
self._access_token = api_response.get("access_token")
126+
break
127+
73128
raise AuthenticationError(http_resp=raw_response)
74129

75130
def get_authentication_header(self, client):

test-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
mock >= 5.1.0, < 6
44
flake8 >= 7.0.0, < 8
5-
pytest-cov >= 4.1.0, < 5
5+
pytest-cov >= 5, < 6
66
griffe >= 0.41.2, < 1

test/test_credentials.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_configuration_client_credentials(self):
9898
configuration=CredentialConfiguration(
9999
client_id="myclientid",
100100
client_secret="mysecret",
101-
api_issuer="www.testme.com",
101+
api_issuer="issuer.fga.example",
102102
api_audience="myaudience",
103103
),
104104
)
@@ -121,7 +121,7 @@ def test_configuration_client_credentials_missing_client_id(self):
121121
method="client_credentials",
122122
configuration=CredentialConfiguration(
123123
client_secret="mysecret",
124-
api_issuer="www.testme.com",
124+
api_issuer="issuer.fga.example",
125125
api_audience="myaudience",
126126
),
127127
)
@@ -136,7 +136,7 @@ def test_configuration_client_credentials_missing_client_secret(self):
136136
method="client_credentials",
137137
configuration=CredentialConfiguration(
138138
client_id="myclientid",
139-
api_issuer="www.testme.com",
139+
api_issuer="issuer.fga.example",
140140
api_audience="myaudience",
141141
),
142142
)
@@ -167,7 +167,7 @@ def test_configuration_client_credentials_missing_api_audience(self):
167167
configuration=CredentialConfiguration(
168168
client_id="myclientid",
169169
client_secret="mysecret",
170-
api_issuer="www.testme.com",
170+
api_issuer="issuer.fga.example",
171171
),
172172
)
173173
with self.assertRaises(openfga_sdk.ApiValueError):

0 commit comments

Comments
 (0)