1- import json
1+ import base64
2+ import hashlib
3+ import json
24import threading
35import time
46import logging
1214logger = logging .getLogger (__name__ )
1315_GRANT_TYPE_BROKER = "broker"
1416
17+ # Fields in the request data dict that should NOT be included in the extended
18+ # cache key hash. Everything else in data IS included, because those are extra
19+ # body parameters going on the wire and must differentiate cached tokens.
20+ #
21+ # Excluded fields and reasons:
22+ # - "client_id" : Standard OAuth2 client identifier, same for every request
23+ # - "grant_type" : The same token can be obtained by combining grants (e.g. obo + refresh_token, auth_code + refresh_token), so the grant must not split the cache
24+ # - "scope" : Already represented as "target" in the AT cache key
25+ # - "claims" : Handled separately; its presence forces a token refresh
26+ # - "username" : Standard ROPC grant parameter. Tokens are cached by user ID (subject or oid+tid) instead
27+ # - "password" : Standard ROPC grant parameter. Tokens are tied to credentials.
28+ # - "refresh_token" : Standard refresh grant parameter
29+ # - "code" : Standard authorization code grant parameter
30+ # - "redirect_uri" : Standard authorization code grant parameter
31+ # - "code_verifier" : Standard PKCE parameter
32+ # - "device_code" : Standard device flow parameter
33+ # - "assertion" : Standard OBO/SAML assertion (RFC 7521)
34+ # - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param
35+ # - "client_assertion" : Client authentication credential (RFC 7521 §4.2)
36+ # - "client_assertion_type" : Client authentication type (RFC 7521 §4.2)
37+ # - "client_secret" : Client authentication secret
38+ # - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately
39+ # - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
40+ # - "key_id" : Already handled as a separate cache lookup field
41+ #
42+ # Included fields (examples — anything NOT in this set is included):
43+ # - "fmi_path" : Federated Managed Identity credential path
44+ # - any future non-standard body parameter that should isolate cache entries
# Request-body fields that never take part in the extended cache key.
# Any field of the request data that is NOT named here is hashed, so that
# tokens acquired with different extra parameters (e.g. "fmi_path") end up
# in separate cache entries.
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset((
    # Core OAuth2 request parameters
    "client_id",
    "grant_type",
    "scope",
    "claims",
    # Grant-specific inputs (ROPC, refresh, auth code, PKCE, device flow, OBO)
    "username",
    "password",
    "refresh_token",
    "code",
    "redirect_uri",
    "code_verifier",
    "device_code",
    "assertion",
    "requested_token_use",
    # Client authentication material
    "client_assertion",
    "client_assertion_type",
    "client_secret",
    # SSH-cert / proof-of-possession related fields
    "token_type",
    "req_cnf",
    "key_id",
))


def _compute_ext_cache_key(data):
    """Derive the extended cache key for the extra body parameters in *data*.

    Every entry of *data* whose name is not listed in
    ``_EXT_CACHE_KEY_EXCLUDED_FIELDS`` and whose value is truthy contributes
    to the key, so requests that differ in such a parameter (for example
    different FMI paths) are cached separately.

    The key is built the same way as in the Go MSAL implementation
    (CacheExtKeyGenerator): the contributing names are sorted, each name is
    concatenated with its stringified value, the result is SHA-256 hashed,
    and the digest is base64url-encoded without padding and lower-cased.

    :param data: The request body mapping; may be ``None`` or empty.
    :return: The key as a string, or ``""`` when nothing contributes.
    """
    if not data:
        return ""

    contributing = {}
    for name, value in data.items():
        # Skip the standard parameters as well as empty/falsy values.
        if name in _EXT_CACHE_KEY_EXCLUDED_FIELDS or not value:
            continue
        contributing[name] = str(value)
    if not contributing:
        return ""

    # Sorting makes the key independent of the insertion order of *data*.
    canonical = "".join(
        name + contributing[name] for name in sorted(contributing))
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    encoded = base64.urlsafe_b64encode(digest).decode("ascii")
    return encoded.rstrip("=").lower()
98+
99+
def is_subdict_of(small, big):
    """Return True when every key/value pair of *small* also exists in *big*.

    The check overlays *small* on a copy of *big* and tests whether anything
    changed: a missing key or a differing value makes the merged copy unequal
    to *big*.

    ``dict.update`` is used instead of ``dict(big, **small)`` because keyword
    expansion only accepts string keys and would raise ``TypeError`` for any
    other hashable key.

    :param small: The candidate subset mapping.
    :param big: The mapping expected to contain all pairs of *small*.
    :return: ``True`` if *small* is a sub-dictionary of *big*, else ``False``.
    """
    merged = dict(big)
    merged.update(small)
    return merged == big
17102
@@ -30,6 +115,7 @@ class TokenCache(object):
30115
31116 class CredentialType :
32117 ACCESS_TOKEN = "AccessToken"
118+ ACCESS_TOKEN_EXTENDED = "atext" # Used when ext_cache_key is present (matches Go/dotnet)
33119 REFRESH_TOKEN = "RefreshToken"
34120 ACCOUNT = "Account" # Not exactly a credential type, but we put it here
35121 ID_TOKEN = "IdToken"
@@ -59,18 +145,22 @@ def __init__(self):
59145 self .CredentialType .ACCESS_TOKEN :
60146 lambda home_account_id = None , environment = None , client_id = None ,
61147 realm = None , target = None ,
148+ ext_cache_key = None ,
62149 # Note: New field(s) can be added here
63150 #key_id=None,
64151 ** ignored_payload_from_a_real_token :
65152 "-" .join ([ # Note: Could use a hash here to shorten key length
66153 home_account_id or "" ,
67154 environment or "" ,
68- self .CredentialType .ACCESS_TOKEN ,
155+ # Use "atext" credential type when ext_cache_key is
156+ # present, matching MSAL Go and MSAL .NET behaviour.
157+ "atext" if ext_cache_key else "AccessToken" ,
69158 client_id or "" ,
70159 realm or "" ,
71160 target or "" ,
72161 #key_id or "", # So ATs of different key_id can coexist
73- ]).lower (),
162+ ] + ([ext_cache_key ] if ext_cache_key else [])
163+ ).lower (),
74164 self .CredentialType .ID_TOKEN :
75165 lambda home_account_id = None , environment = None , client_id = None ,
76166 realm = None , ** ignored_payload_from_a_real_token :
@@ -98,6 +188,7 @@ def __init__(self):
98188 def _get_access_token (
99189 self ,
100190 home_account_id , environment , client_id , realm , target , # Together they form a compound key
191+ ext_cache_key = None ,
101192 default = None ,
102193 ): # O(1)
103194 return self ._get (
@@ -108,6 +199,7 @@ def _get_access_token(
108199 client_id = client_id ,
109200 realm = realm ,
110201 target = " " .join (target ),
202+ ext_cache_key = ext_cache_key ,
111203 ),
112204 default = default )
113205
@@ -153,7 +245,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
153245 ): # Special case for O(1) AT lookup
154246 preferred_result = self ._get_access_token (
155247 query ["home_account_id" ], query ["environment" ],
156- query ["client_id" ], query ["realm" ], target )
248+ query ["client_id" ], query ["realm" ], target ,
249+ ext_cache_key = query .get ("ext_cache_key" ))
157250 if preferred_result and self ._is_matching (
158251 preferred_result , query ,
159252 # Needs no target_set here because it is satisfied by dict key
@@ -179,6 +272,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
179272 if (entry != preferred_result # Avoid yielding the same entry twice
180273 and self ._is_matching (entry , query , target_set = target_set )
181274 ):
275+ # Cache isolation for extended cache keys (e.g., FMI path).
276+ # Entries with ext_cache_key must not match queries without one.
277+ if (credential_type == self .CredentialType .ACCESS_TOKEN
278+ and "ext_cache_key" in entry
279+ and "ext_cache_key" not in (query or {})
280+ ):
281+ continue
182282 yield entry
183283 for at in expired_access_tokens :
184284 self .remove_at (at )
@@ -278,6 +378,12 @@ def __add(self, event, now=None):
278378 # So that we won't accidentally store a user's password etc.
279379 "key_id" , # It happens in SSH-cert or POP scenario
280380 }})
381+ # Compute and store extended cache key for cache isolation
382+ # (e.g., different FMI paths should have separate cache entries)
383+ ext_cache_key = _compute_ext_cache_key (data )
384+
385+ if ext_cache_key :
386+ at ["ext_cache_key" ] = ext_cache_key
281387 if "refresh_in" in response :
282388 refresh_in = response ["refresh_in" ] # It is an integer
283389 at ["refresh_on" ] = str (now + refresh_in ) # Schema wants a string
0 commit comments