33import atexit
44import inspect
55import json
6+ import operator
67import tempfile
78import threading
89import time
@@ -463,72 +464,78 @@ def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]:
463464 return token
464465
465466 def _refresh_access_token (self ) -> tuple [str , float ]:
466- # First check if a token cache exists on disk. If yes, find and use:
467- # - A valid access token.
468- # - A valid refresh token, and if so, use it automatically to redeem a new access token.
467+ # Token resolution order (cheapest option first):
468+ # 1. Valid access token in cache → use directly, no network call
469+ # 2. Refresh token in cache → exchange for new AT, one network call
470+ # 3. Device code flow → interactive, requires user action
469471 credentials = None
470- for token in self .__app .token_cache .search (self .__app .token_cache .CredentialType .REFRESH_TOKEN ):
471- if "expires_on" in token and token ["expires_on" ] > time .time ():
472- credentials = token
472+
473+ # 1. Check for a still-valid access token. search() does NOT filter by expiry,
474+ # so we check manually and respect the leeway to avoid handing out a near-expired token.
475+ for token in self .__app .token_cache .search (
476+ self .__app .token_cache .CredentialType .ACCESS_TOKEN ,
477+ query = {"client_id" : self .client_id },
478+ ):
479+ expiry = int (token .get ("expires_on" , 0 )) - time .time () - self .token_expiry_leeway_seconds
480+ if expiry > 0 :
481+ credentials = {"access_token" : token ["secret" ], "expires_in" : expiry }
473482 break
483+
484+ # 2. No valid AT — try to silently redeem a refresh token.
485+ if credentials is None :
486+ rt_entry = None
487+ for token in self .__app .token_cache .search (
488+ self .__app .token_cache .CredentialType .REFRESH_TOKEN ,
489+ query = {"client_id" : self .client_id },
490+ ):
491+ rt_entry = token
492+ break # MSAL RTs have no 'expires_on'; use the first found
493+
494+ if rt_entry is not None :
495+ # Pass the full RT cache entry (not just the secret string) so MSAL's
496+ # on_removing_rt callback can properly remove it on invalid_grant.
497+ # Exclude OIDC meta-scopes that are not valid in token-endpoint requests.
498+ oidc_scopes = frozenset ({"openid" , "profile" , "email" , "offline_access" })
499+ resp = self .__app .client .obtain_token_by_refresh_token (
500+ rt_entry ,
501+ rt_getter = operator .itemgetter ("secret" ),
502+ scope = " " .join (s for s in self .__scopes if s not in oidc_scopes ),
503+ )
504+ if isinstance (resp , dict ) and "error" not in resp :
505+ credentials = resp
506+ # else: RT rejected by server, fall through to device code flow
507+
474508 if credentials is not None :
475- credentials = self .__app .client .obtain_token_by_refresh_token (credentials .get ("secret" , "" ))
476- else :
477- for token in self .__app .token_cache .search (self .__app .token_cache .CredentialType .ACCESS_TOKEN ):
478- if expiry := int (token .get ("expires_on" , 0 )) - time .time () > 0 :
479- credentials = {
480- "access_token" : token .get ("secret" ),
481- "expires_in" : expiry ,
482- }
483- break
484- # If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
509+ self ._verify_credentials (credentials )
510+ return credentials ["access_token" ], time .time () + float (credentials ["expires_in" ])
511+
512+ # 3. If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
485513 # The msal device_code flow does not support setting the audience, so we need to handle it manually.
486514 # We use the http client instantiated as part of the msal client, as well as the details found
487515 # in oauth discovery.
488- if credentials is None :
489- data = {
490- "scope" : self .scope_string (),
491- "client_id" : self .client_id ,
492- }
493- for key , value in self .__token_custom_args .items ():
494- data [key ] = value
495-
496- device_flow_endpoint = self ._get_device_authorization_endpoint ()
497- device_flow_response = self ._get_device_code_response (device_flow_endpoint , data )
498- if "verification_uri" in device_flow_response :
499- print ( # noqa: T201
500- f"Visit { device_flow_response ['verification_uri' ]} and enter the code: { device_flow_response .get ('user_code' , 'ERROR' )} "
501- )
502- elif "message" in device_flow_response :
503- print ( # noqa: T201
504- f"Device code: { device_flow_response .get ('message' , device_flow_response .get ('user_code' , 'ERROR' ))} "
505- )
506- else :
507- raise CogniteOAuthError (
508- device_flow_response .get ("error" , "" ), device_flow_response .get ("error_description" , "" )
509- )
510-
511- if "interval" not in device_flow_response :
512- # Set default interval according to standard
513- device_flow_response ["interval" ] = 5
514- if "expires_in" in device_flow_response :
515- # msal library uses expires_at instead of the standard expires_in
516- device_flow_response ["expires_at" ] = float (device_flow_response ["expires_in" ]) + time .time ()
517- # Poll for token
518- credentials = self .__app .client .obtain_token_by_device_flow (
519- flow = device_flow_response ,
520- data = dict (
521- data ,
522- code = device_flow_response .get (
523- "device_code"
524- ), # Hack from msal library to get the code from the device flow, not standard
525- ),
516+ data = {"scope" : self .scope_string (), "client_id" : self .client_id , ** self .__token_custom_args }
517+ device_flow_endpoint = self ._get_device_authorization_endpoint ()
518+ response = self ._get_device_code_response (device_flow_endpoint , data )
519+ if "verification_uri" in response :
520+ print (f"Visit { response ['verification_uri' ]} and enter the code: { response .get ('user_code' , 'ERROR' )} " ) # noqa: T201
521+ elif "message" in response :
522+ print (f"Device code: { response .get ('message' , response .get ('user_code' , 'ERROR' ))} " ) # noqa: T201
523+ else :
524+ raise CogniteAuthError (
525+ f"Error initiating device flow: { response .get ('error' )} - { response .get ('error_description' )} "
526526 )
527-
528- self ._verify_credentials (credentials )
529- self .__app .token_cache .add (
530- dict (credentials , environment = self .__app .authority .instance ),
527+ if "interval" not in response :
528+ response ["interval" ] = 5 # Set default interval according to standard
529+ if "expires_in" in response :
530+ # msal library uses expires_at instead of the standard expires_in
531+ response ["expires_at" ] = float (response ["expires_in" ]) + time .time ()
532+
533+ credentials = self .__app .client .obtain_token_by_device_flow (
534+ flow = response ,
535+ # Hack from msal library to get the code from the device flow, not standard:
536+ data = dict (data , code = response .get ("device_code" )),
531537 )
538+ self ._verify_credentials (credentials )
532539 return credentials ["access_token" ], time .time () + float (credentials ["expires_in" ])
533540
534541 @classmethod
@@ -580,18 +587,16 @@ def default_for_entra_id(
580587 mem_cache_only : bool = False ,
581588 ) -> OAuthDeviceCode :
582589 """
583- Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite
584- app registration for device code flow. If you need device code flow with another app registration, instantiate
585- OAuthDeviceCode directly.
590+ Create an OAuthDeviceCode instance for Azure with default URLs and scopes.
586591
587592 The default configuration creates the URLs based on the tenant id and cluster:
588593
589594 * Authority URL: ``https://login.microsoftonline.com/{tenant_id}``
590- * Scopes: [``https://{cdf_cluster}.cognitedata.com/.default ``]
595+ * Scopes: [``https://{cdf_cluster}.cognitedata.com/IDENTITY``, ``https://{cdf_cluster}.cognitedata.com/user_impersonation``, ``profile``, ``openid``, ``offline_access ``]
591596
592597 Args:
593598 tenant_id (str): The Azure tenant id
594- client_id (str): An app registration that allows device code flow.
599+ client_id (str): Your app registration client id. Must have device code flow enabled .
595600 cdf_cluster (str): The CDF cluster where the CDF project is located.
596601 token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
597602 token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec
@@ -602,18 +607,18 @@ def default_for_entra_id(
602607 """
603608 return cls (
604609 authority_url = f"https://login.microsoftonline.com/{ tenant_id } " ,
605- client_id = client_id , # Default application for CDF API for device code flow
610+ client_id = client_id ,
606611 scopes = [
607612 f"https://{ cdf_cluster } .cognitedata.com/IDENTITY" ,
608613 f"https://{ cdf_cluster } .cognitedata.com/user_impersonation" ,
609614 "profile" ,
610615 "openid" ,
616+ "offline_access" , # required for Azure to issue a refresh token
611617 ],
612618 token_cache_path = token_cache_path ,
613619 token_expiry_leeway_seconds = token_expiry_leeway_seconds ,
614620 clear_cache = clear_cache ,
615621 mem_cache_only = mem_cache_only ,
616- audience = f"https://{ cdf_cluster } .cognitedata.com" ,
617622 )
618623
619624 @classmethod
@@ -697,9 +702,8 @@ def scopes(self) -> list[str]:
697702 return self .__scopes
698703
699704 def _refresh_access_token (self ) -> tuple [str , float ]:
700- # First check if a token cache exists on disk. If yes, find and use:
701- # - A valid access token.
702- # - A valid refresh token, and if so, use it automatically to redeem a new access token.
705+ # Try the in-memory token cache silently (MSAL checks AT first, then RT automatically).
706+ # Falls through to interactive flow if nothing usable is found.
703707 credentials = None
704708 if accounts := self .__app .get_accounts ():
705709 credentials = self .__app .acquire_token_silent (scopes = self .__scopes , account = accounts [0 ])
0 commit comments