Skip to content

Commit ecf515a

Browse files
4gust and bgavrilMS authored
Added withFmi method for cca app (#876)
* Added withFmi method for CCA app
* Added cache support for FMI keys
* Updated the cache key ext
* Updated cache key excluded
* Update msal/token_cache.py (Co-authored-by: Bogdan Gavril <bogavril@microsoft.com>)
* Updated API for FMI

---------

Co-authored-by: Bogdan Gavril <bogavril@microsoft.com>
1 parent eb78068 commit ecf515a

File tree

7 files changed

+823
-10
lines changed

7 files changed

+823
-10
lines changed

msal/application.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .mex import send_request as mex_send_request
2121
from .wstrust_request import send_request as wst_send_request
2222
from .wstrust_response import *
23-
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
23+
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key
2424
import msal.telemetry
2525
from .region import _detect_region
2626
from .throttled_http_client import ThrottledHttpClient
@@ -1583,6 +1583,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15831583
key_id = kwargs.get("data", {}).get("key_id")
15841584
if key_id: # Some token types (SSH-certs, POP) are bound to a key
15851585
query["key_id"] = key_id
1586+
ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {}))
1587+
if ext_cache_key: # FMI tokens need cache isolation by path
1588+
query["ext_cache_key"] = ext_cache_key
15861589
now = time.time()
15871590
refresh_reason = msal.telemetry.AT_ABSENT
15881591
for entry in self.token_cache.search( # A generator allows us to
@@ -2436,7 +2439,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
24362439
except that ``allow_broker`` parameter shall remain ``None``.
24372440
"""
24382441

2439-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2442+
def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, **kwargs):
24402443
"""Acquires token for the current confidential client, not for an end user.
24412444
24422445
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2449,7 +2452,17 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
24492452
in the form of a claims_challenge directive in the www-authenticate header to be
24502453
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
24512454
It is a string of a JSON object which contains lists of claims being requested from these locations.
2452-
2455+
:param str fmi_path:
2456+
Optional. The Federated Managed Identity (FMI) credential path.
2457+
When provided, it is sent as the ``fmi_path`` parameter in the
2458+
token request body, and the resulting token is cached separately
2459+
so that different FMI paths do not share cached tokens.
2460+
Example usage::
2461+
2462+
result = cca.acquire_token_for_client(
2463+
scopes=["api://resource/.default"],
2464+
fmi_path="SomeFmiPath/FmiCredentialPath",
2465+
)
24532466
:return: A dict representing the json response from Microsoft Entra:
24542467
24552468
- A successful response would contain "access_token" key,
@@ -2459,6 +2472,12 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
24592472
raise ValueError( # We choose to disallow force_refresh
24602473
"Historically, this method does not support force_refresh behavior. "
24612474
)
2475+
if fmi_path is not None:
2476+
if not isinstance(fmi_path, str):
2477+
raise ValueError(
2478+
"fmi_path must be a string, got {}".format(type(fmi_path).__name__))
2479+
kwargs["data"] = kwargs.get("data", {})
2480+
kwargs["data"]["fmi_path"] = fmi_path
24622481
return _clean_up(self._acquire_token_silent_with_error(
24632482
scopes, None, claims_challenge=claims_challenge, **kwargs))
24642483

msal/authority.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __init__(
8989
self._http_client = http_client
9090
self._oidc_authority_url = oidc_authority_url
9191
if oidc_authority_url:
92-
logger.debug("Initializing with OIDC authority: %s", oidc_authority_url)
9392
tenant_discovery_endpoint = self._initialize_oidc_authority(
9493
oidc_authority_url)
9594
else:

msal/token_cache.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import json
1+
import base64
2+
import hashlib
3+
import json
24
import threading
35
import time
46
import logging
@@ -12,6 +14,89 @@
1214
logger = logging.getLogger(__name__)
1315
_GRANT_TYPE_BROKER = "broker"
1416

17+
# Fields in the request data dict that should NOT be included in the extended
18+
# cache key hash. Everything else in data IS included, because those are extra
19+
# body parameters going on the wire and must differentiate cached tokens.
20+
#
21+
# Excluded fields and reasons:
22+
# - "client_id" : Standard OAuth2 client identifier, same for every request
23+
# - "grant_type" : Grants can be combined to obtain the same token (e.g. OBO + refresh_token, auth_code + refresh_token), so the grant used must not split the cache
24+
# - "scope" : Already represented as "target" in the AT cache key
25+
# - "claims" : Handled separately; its presence forces a token refresh
26+
# - "username" : Standard ROPC grant parameter. Tokens are cached by user ID (subject or oid+tid) instead
27+
# - "password" : Standard ROPC grant parameter. Tokens are tied to credentials.
28+
# - "refresh_token" : Standard refresh grant parameter
29+
# - "code" : Standard authorization code grant parameter
30+
# - "redirect_uri" : Standard authorization code grant parameter
31+
# - "code_verifier" : Standard PKCE parameter
32+
# - "device_code" : Standard device flow parameter
33+
# - "assertion" : Standard OBO/SAML assertion (RFC 7521)
34+
# - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param
35+
# - "client_assertion" : Client authentication credential (RFC 7521 §4.2)
36+
# - "client_assertion_type" : Client authentication type (RFC 7521 §4.2)
37+
# - "client_secret" : Client authentication secret
38+
# - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately
39+
# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
40+
# - "key_id" : Already handled as a separate cache lookup field
41+
#
42+
# Included fields (examples — anything NOT in this set is included):
43+
# - "fmi_path" : Federated Managed Identity credential path
44+
# - any future non-standard body parameter that should isolate cache entries
45+
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({
    # Standard OAuth2 body parameters — these appear in every token request
    # and must NOT influence the extended cache key.
    # Only non-standard fields (e.g. fmi_path) should contribute to the hash.
    "client_id",
    "grant_type",
    "scope",
    "claims",
    "username",
    "password",
    "refresh_token",
    "code",
    "redirect_uri",
    "code_verifier",
    "device_code",
    "assertion",
    "requested_token_use",
    "client_assertion",
    "client_assertion_type",
    "client_secret",
    "token_type",
    "req_cnf",
    "key_id",
})


def _compute_ext_cache_key(data):
    """Compute an extended cache key hash from extra body parameters in *data*.

    Every field of *data* is included in the hash EXCEPT:

    * fields listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``, and
    * fields that carry no value: ``None``, empty strings, empty containers.

    Numeric zero and ``False`` ARE included, because ``str()`` of them is a
    non-empty value ("0", "False") and must therefore isolate cache entries
    like any other extra parameter.

    This ensures tokens acquired with different parameter values
    (e.g., different FMI paths) are cached separately.

    The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
    sorted key+value pairs are concatenated and SHA256 hashed, then
    base64url encoded (unpadded) and lowercased.

    :param dict data: The extra request-body parameters; may be None or empty.
    :return: The lowercased, unpadded base64url SHA256 digest as a ``str``,
        or an empty string when *data* contributes no fields.
    """
    if not data:
        return ""
    cache_components = {
        k: str(v) for k, v in data.items()
        if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS
        # Plain truthiness would silently drop 0 / 0.0 / False, letting a
        # request with e.g. ``x=0`` share a cache slot with one without ``x``.
        and (v or isinstance(v, (int, float)))
    }
    if not cache_components:
        return ""
    # Keys are unique, so sorting the items orders them by key; a stable
    # order keeps the hash deterministic (matches Go implementation).
    key_str = "".join(k + v for k, v in sorted(cache_components.items()))
    hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest()
    return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()
98+
99+
15100
def is_subdict_of(small, big):
    """Return True when every key/value pair of *small* is also in *big*.

    :param dict small: The candidate subset.
    :param dict big: The dict expected to contain all of *small*.
    :return: ``True`` if overlaying *small* on *big* changes nothing.
    """
    # Dict unpacking, unlike ``dict(big, **small)``, does not require the
    # keys of *small* to be strings, so non-string keys no longer raise
    # TypeError.
    return {**big, **small} == big
17102

@@ -30,6 +115,7 @@ class TokenCache(object):
30115

31116
class CredentialType:
32117
ACCESS_TOKEN = "AccessToken"
118+
ACCESS_TOKEN_EXTENDED = "atext" # Used when ext_cache_key is present (matches Go/dotnet)
33119
REFRESH_TOKEN = "RefreshToken"
34120
ACCOUNT = "Account" # Not exactly a credential type, but we put it here
35121
ID_TOKEN = "IdToken"
@@ -59,18 +145,22 @@ def __init__(self):
59145
self.CredentialType.ACCESS_TOKEN:
60146
lambda home_account_id=None, environment=None, client_id=None,
61147
realm=None, target=None,
148+
ext_cache_key=None,
62149
# Note: New field(s) can be added here
63150
#key_id=None,
64151
**ignored_payload_from_a_real_token:
65152
"-".join([ # Note: Could use a hash here to shorten key length
66153
home_account_id or "",
67154
environment or "",
68-
self.CredentialType.ACCESS_TOKEN,
155+
# Use "atext" credential type when ext_cache_key is
156+
# present, matching MSAL Go and MSAL .NET behaviour.
157+
"atext" if ext_cache_key else "AccessToken",
69158
client_id or "",
70159
realm or "",
71160
target or "",
72161
#key_id or "", # So ATs of different key_id can coexist
73-
]).lower(),
162+
] + ([ext_cache_key] if ext_cache_key else [])
163+
).lower(),
74164
self.CredentialType.ID_TOKEN:
75165
lambda home_account_id=None, environment=None, client_id=None,
76166
realm=None, **ignored_payload_from_a_real_token:
@@ -98,6 +188,7 @@ def __init__(self):
98188
def _get_access_token(
99189
self,
100190
home_account_id, environment, client_id, realm, target, # Together they form a compound key
191+
ext_cache_key=None,
101192
default=None,
102193
): # O(1)
103194
return self._get(
@@ -108,6 +199,7 @@ def _get_access_token(
108199
client_id=client_id,
109200
realm=realm,
110201
target=" ".join(target),
202+
ext_cache_key=ext_cache_key,
111203
),
112204
default=default)
113205

@@ -153,7 +245,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
153245
): # Special case for O(1) AT lookup
154246
preferred_result = self._get_access_token(
155247
query["home_account_id"], query["environment"],
156-
query["client_id"], query["realm"], target)
248+
query["client_id"], query["realm"], target,
249+
ext_cache_key=query.get("ext_cache_key"))
157250
if preferred_result and self._is_matching(
158251
preferred_result, query,
159252
# Needs no target_set here because it is satisfied by dict key
@@ -179,6 +272,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
179272
if (entry != preferred_result # Avoid yielding the same entry twice
180273
and self._is_matching(entry, query, target_set=target_set)
181274
):
275+
# Cache isolation for extended cache keys (e.g., FMI path).
276+
# Entries with ext_cache_key must not match queries without one.
277+
if (credential_type == self.CredentialType.ACCESS_TOKEN
278+
and "ext_cache_key" in entry
279+
and "ext_cache_key" not in (query or {})
280+
):
281+
continue
182282
yield entry
183283
for at in expired_access_tokens:
184284
self.remove_at(at)
@@ -278,6 +378,12 @@ def __add(self, event, now=None):
278378
# So that we won't accidentally store a user's password etc.
279379
"key_id", # It happens in SSH-cert or POP scenario
280380
}})
381+
# Compute and store extended cache key for cache isolation
382+
# (e.g., different FMI paths should have separate cache entries)
383+
ext_cache_key = _compute_ext_cache_key(data)
384+
385+
if ext_cache_key:
386+
at["ext_cache_key"] = ext_cache_key
281387
if "refresh_in" in response:
282388
refresh_in = response["refresh_in"] # It is an integer
283389
at["refresh_on"] = str(now + refresh_in) # Schema wants a string

0 commit comments

Comments
 (0)