-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathaccess_token_manager.py
More file actions
90 lines (68 loc) · 3.76 KB
/
Copy pathaccess_token_manager.py
File metadata and controls
90 lines (68 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import asyncio
from datetime import datetime, timedelta, timezone
from threading import Lock

from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

import durabletask.internal.shared as shared
# Tokens are refreshed proactively: by default a new token is requested once fewer
# than 10 minutes (600 seconds) remain before the current token expires.
class AccessTokenManager:
    """Cache an access token from a synchronous ``TokenCredential`` and refresh it before it expires.

    Refreshes triggered through :meth:`get_access_token` are serialized with a
    ``threading.Lock`` so that concurrent threads do not all request a new token at once.
    """

    _token: AccessToken | None

    def __init__(self, token_credential: TokenCredential | None, refresh_interval_seconds: int = 600):
        """Create the manager and, if a credential is given, fetch the first token immediately.

        Args:
            token_credential: Credential used to obtain tokens for the
                ``https://durabletask.io/.default`` scope, or ``None``; in that case no
                token is ever fetched and :meth:`get_access_token` returns ``None``.
            refresh_interval_seconds: How many seconds before expiry a token is treated
                as expired and refreshed. Defaults to 600 (10 minutes).

        Raises:
            Whatever the credential's ``get_token`` raises during the initial fetch.
        """
        self._scope = "https://durabletask.io/.default"
        self._refresh_interval_seconds = refresh_interval_seconds
        self._logger = shared.get_logger("token_manager")
        self._credential = token_credential
        self._refresh_lock = Lock()
        self._token = None
        self.expiry_time: datetime | None = None
        # Fetch the first token eagerly (a no-op when no credential is configured),
        # reusing the single code path that sets both the token and its expiry.
        self.refresh_token()

    def get_access_token(self) -> AccessToken | None:
        """Return a valid access token, refreshing it first if it is missing or about to expire.

        Returns:
            The current ``AccessToken``, or ``None`` when no credential was configured.
        """
        if self._credential is None:
            # Nothing can ever be refreshed, so avoid taking the lock on every call.
            return None
        if self._token is None or self.is_token_expired():
            with self._refresh_lock:
                # Re-check under the lock: another thread may have refreshed the token
                # while this one was waiting.
                if self._token is None or self.is_token_expired():
                    self.refresh_token()
        return self._token

    def is_token_expired(self) -> bool:
        """Return True if no expiry is known or the token expires within the refresh window.

        For example, with a 2-hour token lifetime and a 30-minute refresh window, the
        token is reported as expired (and therefore refreshed) once 30 minutes remain.
        """
        if self.expiry_time is None:
            return True
        refresh_at = self.expiry_time - timedelta(seconds=self._refresh_interval_seconds)
        return datetime.now(timezone.utc) >= refresh_at

    def refresh_token(self):
        """Fetch a new token from the credential and record its expiry; no-op without a credential.

        Exceptions raised by the credential's ``get_token`` propagate to the caller and
        leave the previously cached token and expiry unchanged.
        """
        if self._credential is None:
            return
        self._token = self._credential.get_token(self._scope)
        # expires_on is a UNIX timestamp; convert it to a timezone-aware UTC datetime.
        self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
        self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")
class AsyncAccessTokenManager:
    """Async version of AccessTokenManager that uses AsyncTokenCredential.

    This avoids blocking the event loop when acquiring or refreshing tokens. No token is
    fetched at construction; the first one is requested on the first call to
    :meth:`get_access_token`. As in the synchronous manager, refreshes are serialized
    with a lock (here an ``asyncio.Lock``), so concurrent coroutines that observe an
    expiring token trigger a single ``get_token`` call instead of one each.
    """

    _token: AccessToken | None

    def __init__(self, token_credential: AsyncTokenCredential | None,
                 refresh_interval_seconds: int = 600):
        """Create the manager without fetching a token.

        Args:
            token_credential: Async credential used to obtain tokens for the
                ``https://durabletask.io/.default`` scope, or ``None``; in that case no
                token is ever fetched and :meth:`get_access_token` returns ``None``.
            refresh_interval_seconds: How many seconds before expiry a token is treated
                as expired and refreshed. Defaults to 600 (10 minutes).
        """
        self._scope = "https://durabletask.io/.default"
        self._refresh_interval_seconds = refresh_interval_seconds
        self._logger = shared.get_logger("async_token_manager")
        self._credential = token_credential
        # Since Python 3.10, asyncio.Lock no longer binds to an event loop when it is
        # constructed, so creating it here, outside a running loop, is fine.
        self._refresh_lock = asyncio.Lock()
        self._token = None
        self.expiry_time: datetime | None = None

    async def get_access_token(self) -> AccessToken | None:
        """Return a valid access token, refreshing it first if it is missing or about to expire.

        Returns:
            The current ``AccessToken``, or ``None`` when no credential was configured.
        """
        if self._credential is None:
            # Nothing can ever be refreshed, so skip the lock entirely.
            return None
        if self._token is None or self.is_token_expired():
            async with self._refresh_lock:
                # Re-check under the lock: another coroutine may have refreshed the
                # token while this one was waiting.
                if self._token is None or self.is_token_expired():
                    await self.refresh_token()
        return self._token

    def is_token_expired(self) -> bool:
        """Return True if no expiry is known or the token expires within the refresh window."""
        if self.expiry_time is None:
            return True
        refresh_at = self.expiry_time - timedelta(seconds=self._refresh_interval_seconds)
        return datetime.now(timezone.utc) >= refresh_at

    async def refresh_token(self):
        """Await a new token from the credential and record its expiry; no-op without a credential.

        Exceptions raised by the credential's ``get_token`` propagate to the caller and
        leave the previously cached token and expiry unchanged.
        """
        if self._credential is None:
            return
        self._token = await self._credential.get_token(self._scope)
        # expires_on is a UNIX timestamp; convert it to a timezone-aware UTC datetime.
        self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
        self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")