Skip to content

Commit 16de77d

Browse files
authored
Fix the token expiration logic in SR Oauth (#2177)
* update * fix and update tests
1 parent 0758024 commit 16de77d

4 files changed

Lines changed: 16 additions & 12 deletions

File tree

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ def token_expired(self) -> bool:
122122
if self.token is None:
123123
raise ValueError("Token is not set")
124124

125-
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
126-
127-
return self.token['expires_at'] < time.time() + expiry_window
125+
refresh_buffer = self.token['expires_in'] * (1 - self.token_expiry_threshold)
126+
return self.token['expires_at'] < time.time() + refresh_buffer
128127

129128
async def get_access_token(self) -> str:
130129
if not self.token or self.token_expired():

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ def token_expired(self) -> bool:
122122
if self.token is None:
123123
raise ValueError("Token is not set")
124124

125-
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
126-
127-
return self.token['expires_at'] < time.time() + expiry_window
125+
refresh_buffer = self.token['expires_in'] * (1 - self.token_expiry_threshold)
126+
return self.token['expires_at'] < time.time() + refresh_buffer
128127

129128
def get_access_token(self) -> str:
130129
if not self.token or self.token_expired():

tests/schema_registry/_async/test_bearer_field_provider.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ async def custom_oauth_function(config: dict) -> dict:
5252

5353
def test_expiry():
5454
oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)
55-
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1}
55+
# Use consistent test data: expires_at and expires_in should match
56+
# Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining)
57+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2}
5658
assert not oauth_client.token_expired()
57-
time.sleep(1.5)
59+
time.sleep(1.7) # After 1.7 seconds, only 0.3s remaining (< 0.4s threshold), should be expired
5860
assert oauth_client.token_expired()
5961

6062

@@ -65,7 +67,8 @@ def update_token1():
6567
oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}
6668

6769
def update_token2():
68-
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}
70+
# Use consistent test data: expires_at and expires_in should match
71+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2, 'access_token': '1234'}
6972

7073
oauth_client.generate_access_token = AsyncMock(side_effect=update_token1)
7174
await oauth_client.get_access_token()

tests/schema_registry/_sync/test_bearer_field_provider.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ def custom_oauth_function(config: dict) -> dict:
5252

5353
def test_expiry():
5454
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)
55-
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1}
55+
# Use consistent test data: expires_at and expires_in should match
56+
# Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining)
57+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2}
5658
assert not oauth_client.token_expired()
57-
time.sleep(1.5)
59+
time.sleep(1.7) # After 1.7 seconds, only 0.3s remaining (< 0.4s threshold), should be expired
5860
assert oauth_client.token_expired()
5961

6062

@@ -65,7 +67,8 @@ def update_token1():
6567
oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}
6668

6769
def update_token2():
68-
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}
70+
# Use consistent test data: expires_at and expires_in should match
71+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2, 'access_token': '1234'}
6972

7073
oauth_client.generate_access_token = Mock(side_effect=update_token1)
7174
oauth_client.get_access_token()

0 commit comments

Comments
 (0)