Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
VSCODE_CREDENTIALS_SECTION = "VS Code Azure"
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30
DEFAULT_TOKEN_LOCK_TIMEOUT = 10
DEFAULT_TOKEN_LOCK_TIMEOUT_VARIANCE = 0.25

CACHE_NON_CAE_SUFFIX = ".nocae" # cspell:disable-line
CACHE_CAE_SUFFIX = ".cae"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import asyncio # pylint: disable=do-not-import-asyncio
import logging
import random
import threading
import time
from typing import Any, Optional
import weakref
from typing import Any, Dict, Optional, Tuple

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from ..._internal import within_credential_chain, get_refresh_status
from ..._enums import TokenRefreshStatus
from ..._constants import DEFAULT_TOKEN_LOCK_TIMEOUT, DEFAULT_TOKEN_LOCK_TIMEOUT_VARIANCE

_LOGGER = logging.getLogger(__name__)


class GetTokenMixin(abc.ABC):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._last_request_time = 0
# Per-loop, per-scope locks. Weak references ensure cleanup on loop/lock GC.
self._locks: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, weakref.WeakValueDictionary[Tuple, asyncio.Lock]
] = weakref.WeakKeyDictionary()

# https://github.com/python/mypy/issues/5887
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore
Expand Down Expand Up @@ -45,6 +54,29 @@ async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo:
:rtype: ~azure.core.credentials.AccessTokenInfo
"""

def _get_request_lock(
self,
scopes: Tuple[str, ...],
claims: Optional[str],
tenant_id: Optional[str],
enable_cae: bool,
) -> Optional[asyncio.Lock]:
# Only use locking in asyncio contexts. If we can't get a running loop
# (e.g., trio), fall through to existing behavior without locking.
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return None
key = (scopes, claims, tenant_id, enable_cae)
if loop not in self._locks:
self._locks[loop] = weakref.WeakValueDictionary()
loop_locks = self._locks[loop]
lock = loop_locks.get(key)
if lock is None:
lock = asyncio.Lock()
loop_locks[key] = lock
return lock

async def get_token(
self,
*scopes: str,
Expand Down Expand Up @@ -118,6 +150,38 @@ async def _get_token_base(
tenant_id = options.get("tenant_id")
enable_cae = options.get("enable_cae", False)

# First check the cache without acquiring the lock.
token = await self._acquire_token_silently(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
refresh_status = get_refresh_status(token, self._last_request_time)
if refresh_status == TokenRefreshStatus.NOT_NEEDED:
_LOGGER.log(
logging.DEBUG if within_credential_chain.get() else logging.INFO,
"%s.%s succeeded",
self.__class__.__name__,
base_method_name,
)
return token # type: ignore[return-value]

# A refresh is needed — acquire the per-scope lock to prevent duplicate network calls.
lock = self._get_request_lock(tuple(sorted(scopes)), claims, tenant_id, enable_cae)
lock_acquired = False

if lock is not None:
jitter = DEFAULT_TOKEN_LOCK_TIMEOUT * DEFAULT_TOKEN_LOCK_TIMEOUT_VARIANCE
timeout = max(0.0, random.uniform(DEFAULT_TOKEN_LOCK_TIMEOUT - jitter, DEFAULT_TOKEN_LOCK_TIMEOUT + jitter))
try:
await asyncio.wait_for(lock.acquire(), timeout=timeout)
lock_acquired = True
except asyncio.TimeoutError:
_LOGGER.warning(
"%s.%s lock acquisition timed out after %s seconds; proceeding with token request",
self.__class__.__name__,
base_method_name,
timeout,
)

try:
token = await self._acquire_token_silently(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
Expand Down Expand Up @@ -154,3 +218,17 @@ async def _get_token_base(
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
)
raise

finally:
if lock is not None and lock_acquired:
lock.release()

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# asyncio.Lock and threading.Lock are not picklable; exclude them.
state.pop("_locks", None)
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state) # type: ignore
self._locks = weakref.WeakKeyDictionary()
Loading
Loading