Skip to content

Commit 7c76170

Browse files
authored
fix(credentials): token cache reuse issue for OAuthDeviceCode (DM-3772) (#2605)
1 parent f95510f commit 7c76170

2 files changed

Lines changed: 98 additions & 95 deletions

File tree

cognite/client/credentials.py

Lines changed: 72 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import atexit
44
import inspect
55
import json
6+
import operator
67
import tempfile
78
import threading
89
import time
@@ -463,72 +464,78 @@ def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]:
463464
return token
464465

465466
def _refresh_access_token(self) -> tuple[str, float]:
466-
# First check if a token cache exists on disk. If yes, find and use:
467-
# - A valid access token.
468-
# - A valid refresh token, and if so, use it automatically to redeem a new access token.
467+
# Token resolution order (cheapest option first):
468+
# 1. Valid access token in cache → use directly, no network call
469+
# 2. Refresh token in cache → exchange for a new access token, one network call
470+
# 3. Device code flow → interactive, requires user action
469471
credentials = None
470-
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.REFRESH_TOKEN):
471-
if "expires_on" in token and token["expires_on"] > time.time():
472-
credentials = token
472+
473+
# 1. Check for a still-valid access token. search() does NOT filter by expiry,
474+
# so we check manually and respect the leeway to avoid handing out a near-expired token.
475+
for token in self.__app.token_cache.search(
476+
self.__app.token_cache.CredentialType.ACCESS_TOKEN,
477+
query={"client_id": self.client_id},
478+
):
479+
expiry = int(token.get("expires_on", 0)) - time.time() - self.token_expiry_leeway_seconds
480+
if expiry > 0:
481+
credentials = {"access_token": token["secret"], "expires_in": expiry}
473482
break
483+
484+
# 2. No valid access token — try to silently redeem a refresh token.
485+
if credentials is None:
486+
rt_entry = None
487+
for token in self.__app.token_cache.search(
488+
self.__app.token_cache.CredentialType.REFRESH_TOKEN,
489+
query={"client_id": self.client_id},
490+
):
491+
rt_entry = token
492+
break # MSAL RTs have no 'expires_on'; use the first found
493+
494+
if rt_entry is not None:
495+
# Pass the full RT cache entry (not just the secret string) so MSAL's
496+
# on_removing_rt callback can properly remove it on invalid_grant.
497+
# Exclude OIDC meta-scopes that are not valid in token-endpoint requests.
498+
oidc_scopes = frozenset({"openid", "profile", "email", "offline_access"})
499+
resp = self.__app.client.obtain_token_by_refresh_token(
500+
rt_entry,
501+
rt_getter=operator.itemgetter("secret"),
502+
scope=" ".join(s for s in self.__scopes if s not in oidc_scopes),
503+
)
504+
if isinstance(resp, dict) and "error" not in resp:
505+
credentials = resp
506+
# else: RT rejected by server, fall through to device code flow
507+
474508
if credentials is not None:
475-
credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", ""))
476-
else:
477-
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.ACCESS_TOKEN):
478-
if expiry := int(token.get("expires_on", 0)) - time.time() > 0:
479-
credentials = {
480-
"access_token": token.get("secret"),
481-
"expires_in": expiry,
482-
}
483-
break
484-
# If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
509+
self._verify_credentials(credentials)
510+
return credentials["access_token"], time.time() + float(credentials["expires_in"])
511+
512+
# 3. If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
485513
# The msal device_code flow does not support setting the audience, so we need to handle it manually.
486514
# We use the http client instantiated as part of the msal client, as well as the details found
487515
# in oauth discovery.
488-
if credentials is None:
489-
data = {
490-
"scope": self.scope_string(),
491-
"client_id": self.client_id,
492-
}
493-
for key, value in self.__token_custom_args.items():
494-
data[key] = value
495-
496-
device_flow_endpoint = self._get_device_authorization_endpoint()
497-
device_flow_response = self._get_device_code_response(device_flow_endpoint, data)
498-
if "verification_uri" in device_flow_response:
499-
print( # noqa: T201
500-
f"Visit {device_flow_response['verification_uri']} and enter the code: {device_flow_response.get('user_code', 'ERROR')}"
501-
)
502-
elif "message" in device_flow_response:
503-
print( # noqa: T201
504-
f"Device code: {device_flow_response.get('message', device_flow_response.get('user_code', 'ERROR'))}"
505-
)
506-
else:
507-
raise CogniteOAuthError(
508-
device_flow_response.get("error", ""), device_flow_response.get("error_description", "")
509-
)
510-
511-
if "interval" not in device_flow_response:
512-
# Set default interval according to standard
513-
device_flow_response["interval"] = 5
514-
if "expires_in" in device_flow_response:
515-
# msal library uses expires_at instead of the standard expires_in
516-
device_flow_response["expires_at"] = float(device_flow_response["expires_in"]) + time.time()
517-
# Poll for token
518-
credentials = self.__app.client.obtain_token_by_device_flow(
519-
flow=device_flow_response,
520-
data=dict(
521-
data,
522-
code=device_flow_response.get(
523-
"device_code"
524-
), # Hack from msal library to get the code from the device flow, not standard
525-
),
516+
data = {"scope": self.scope_string(), "client_id": self.client_id, **self.__token_custom_args}
517+
device_flow_endpoint = self._get_device_authorization_endpoint()
518+
response = self._get_device_code_response(device_flow_endpoint, data)
519+
if "verification_uri" in response:
520+
print(f"Visit {response['verification_uri']} and enter the code: {response.get('user_code', 'ERROR')}") # noqa: T201
521+
elif "message" in response:
522+
print(f"Device code: {response.get('message', response.get('user_code', 'ERROR'))}") # noqa: T201
523+
else:
524+
raise CogniteAuthError(
525+
f"Error initiating device flow: {response.get('error')} - {response.get('error_description')}"
526526
)
527-
528-
self._verify_credentials(credentials)
529-
self.__app.token_cache.add(
530-
dict(credentials, environment=self.__app.authority.instance),
527+
if "interval" not in response:
528+
response["interval"] = 5 # Set default interval according to standard
529+
if "expires_in" in response:
530+
# msal library uses expires_at instead of the standard expires_in
531+
response["expires_at"] = float(response["expires_in"]) + time.time()
532+
533+
credentials = self.__app.client.obtain_token_by_device_flow(
534+
flow=response,
535+
# Hack from msal library to get the code from the device flow, not standard:
536+
data=dict(data, code=response.get("device_code")),
531537
)
538+
self._verify_credentials(credentials)
532539
return credentials["access_token"], time.time() + float(credentials["expires_in"])
533540

534541
@classmethod
@@ -580,18 +587,16 @@ def default_for_entra_id(
580587
mem_cache_only: bool = False,
581588
) -> OAuthDeviceCode:
582589
"""
583-
Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite
584-
app registration for device code flow. If you need device code flow with another app registration, instantiate
585-
OAuthDeviceCode directly.
590+
Create an OAuthDeviceCode instance for Azure with default URLs and scopes.
586591
587592
The default configuration creates the URLs based on the tenant id and cluster:
588593
589594
* Authority URL: ``https://login.microsoftonline.com/{tenant_id}``
590-
* Scopes: [``https://{cdf_cluster}.cognitedata.com/.default``]
595+
* Scopes: [``https://{cdf_cluster}.cognitedata.com/IDENTITY``, ``https://{cdf_cluster}.cognitedata.com/user_impersonation``, ``profile``, ``openid``, ``offline_access``]
591596
592597
Args:
593598
tenant_id (str): The Azure tenant id
594-
client_id (str): An app registration that allows device code flow.
599+
client_id (str): Your app registration client id. Must have device code flow enabled.
595600
cdf_cluster (str): The CDF cluster where the CDF project is located.
596601
token_cache_path (Path | None): Location to store the token cache. Defaults to ``cognitetokencache.{client_id}.bin`` in the OS temp directory.
597602
token_expiry_leeway_seconds (int): The token is refreshed no earlier than when this many seconds remain before expiry. Default: 30 sec
@@ -602,18 +607,18 @@ def default_for_entra_id(
602607
"""
603608
return cls(
604609
authority_url=f"https://login.microsoftonline.com/{tenant_id}",
605-
client_id=client_id, # Default application for CDF API for device code flow
610+
client_id=client_id,
606611
scopes=[
607612
f"https://{cdf_cluster}.cognitedata.com/IDENTITY",
608613
f"https://{cdf_cluster}.cognitedata.com/user_impersonation",
609614
"profile",
610615
"openid",
616+
"offline_access", # required for Azure to issue a refresh token
611617
],
612618
token_cache_path=token_cache_path,
613619
token_expiry_leeway_seconds=token_expiry_leeway_seconds,
614620
clear_cache=clear_cache,
615621
mem_cache_only=mem_cache_only,
616-
audience=f"https://{cdf_cluster}.cognitedata.com",
617622
)
618623

619624
@classmethod
@@ -697,9 +702,8 @@ def scopes(self) -> list[str]:
697702
return self.__scopes
698703

699704
def _refresh_access_token(self) -> tuple[str, float]:
700-
# First check if a token cache exists on disk. If yes, find and use:
701-
# - A valid access token.
702-
# - A valid refresh token, and if so, use it automatically to redeem a new access token.
705+
# Try the in-memory token cache silently (MSAL checks for an access token first, then a refresh token, automatically).
706+
# Falls through to interactive flow if nothing usable is found.
703707
credentials = None
704708
if accounts := self.__app.get_accounts():
705709
credentials = self.__app.acquire_token_silent(scopes=self.__scopes, account=accounts[0])

0 commit comments

Comments
 (0)