Skip to content

Commit 791161d

Browse files
committed
Added Cache support for fmi keys
1 parent e0580f5 commit 791161d

7 files changed

Lines changed: 397 additions & 10 deletions

File tree

msal/application.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .mex import send_request as mex_send_request
1616
from .wstrust_request import send_request as wst_send_request
1717
from .wstrust_response import *
18-
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
18+
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key
1919
import msal.telemetry
2020
from .region import _detect_region
2121
from .throttled_http_client import ThrottledHttpClient
@@ -1571,6 +1571,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15711571
key_id = kwargs.get("data", {}).get("key_id")
15721572
if key_id: # Some token types (SSH-certs, POP) are bound to a key
15731573
query["key_id"] = key_id
1574+
ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {}))
1575+
if ext_cache_key: # FMI tokens need cache isolation by path
1576+
query["ext_cache_key"] = ext_cache_key
15741577
now = time.time()
15751578
refresh_reason = msal.telemetry.AT_ABSENT
15761579
for entry in self.token_cache.search( # A generator allows us to

msal/authority.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,9 @@ def __init__(
9292
self._http_client = http_client
9393
self._oidc_authority_url = oidc_authority_url
9494
if oidc_authority_url:
95-
logger.debug("Initializing with OIDC authority: %s", oidc_authority_url)
9695
tenant_discovery_endpoint = self._initialize_oidc_authority(
9796
oidc_authority_url)
9897
else:
99-
logger.debug("Initializing with Entra authority: %s", authority_url)
10098
tenant_discovery_endpoint = self._initialize_entra_authority(
10199
authority_url, validate_authority, instance_discovery)
102100
try:
@@ -117,8 +115,6 @@ def __init__(
117115
.format(authority_url)
118116
) + " Also please double check your tenant name or GUID is correct."
119117
raise ValueError(error_message)
120-
logger.debug(
121-
'openid_config("%s") = %s', tenant_discovery_endpoint, openid_config)
122118
self._issuer = openid_config.get('issuer')
123119
self.authorization_endpoint = openid_config['authorization_endpoint']
124120
self.token_endpoint = openid_config['token_endpoint']

msal/token_cache.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import json
1+
import base64
2+
import hashlib
3+
import json
24
import threading
35
import time
46
import logging
@@ -12,6 +14,63 @@
1214
logger = logging.getLogger(__name__)
1315
_GRANT_TYPE_BROKER = "broker"
1416

17+
# Fields in the request data dict that should NOT be included in the extended
18+
# cache key hash. Everything else in data IS included, because those are extra
19+
# body parameters going on the wire and must differentiate cached tokens.
20+
#
21+
# Excluded fields and reasons:
22+
# - "key_id" : Already handled as a separate cache lookup field
23+
# - "token_type" : Used for SSH-cert/POP detection; AT entry stores it separately
24+
# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
25+
# - "claims" : Handled separately; its presence forces a token refresh
26+
# - "scope" : Already represented as "target" in the AT cache key;
27+
# also added to data only at wire-time, not at cache-lookup time
28+
# - "username" : Standard ROPC grant parameter, not an extra body parameter
29+
# - "password" : Standard ROPC grant parameter, not an extra body parameter
30+
#
31+
# Included fields (examples — anything NOT in this set is included):
32+
# - "fmi_path" : Federated Managed Identity credential path
33+
# - any future extra body parameter that should isolate cache entries
34+
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({
35+
"key_id",
36+
"token_type",
37+
"req_cnf",
38+
"claims",
39+
"scope",
40+
"username",
41+
"password",
42+
})
43+
44+
45+
def _compute_ext_cache_key(data):
46+
"""Compute an extended cache key hash from extra body parameters in *data*.
47+
48+
All fields in *data* that go on the wire are included in the hash,
49+
EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``.
50+
This ensures tokens acquired with different parameter values
51+
(e.g., different FMI paths) are cached separately.
52+
53+
Returns an empty string when *data* has no hashable fields.
54+
55+
The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
56+
sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded.
57+
"""
58+
if not data:
59+
return ""
60+
cache_components = {
61+
k: str(v) for k, v in data.items()
62+
if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v
63+
}
64+
if not cache_components:
65+
return ""
66+
# Sort keys for consistent hashing (matches Go implementation)
67+
key_str = "".join(
68+
k + cache_components[k] for k in sorted(cache_components.keys())
69+
)
70+
hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest()
71+
return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()
72+
73+
1574
def is_subdict_of(small, big):
1675
return dict(big, **small) == big
1776

@@ -59,6 +118,7 @@ def __init__(self):
59118
self.CredentialType.ACCESS_TOKEN:
60119
lambda home_account_id=None, environment=None, client_id=None,
61120
realm=None, target=None,
121+
ext_cache_key=None,
62122
# Note: New field(s) can be added here
63123
#key_id=None,
64124
**ignored_payload_from_a_real_token:
@@ -70,7 +130,8 @@ def __init__(self):
70130
realm or "",
71131
target or "",
72132
#key_id or "", # So ATs of different key_id can coexist
73-
]).lower(),
133+
] + ([ext_cache_key] if ext_cache_key else [])
134+
).lower(),
74135
self.CredentialType.ID_TOKEN:
75136
lambda home_account_id=None, environment=None, client_id=None,
76137
realm=None, **ignored_payload_from_a_real_token:
@@ -98,6 +159,7 @@ def __init__(self):
98159
def _get_access_token(
99160
self,
100161
home_account_id, environment, client_id, realm, target, # Together they form a compound key
162+
ext_cache_key=None,
101163
default=None,
102164
): # O(1)
103165
return self._get(
@@ -108,6 +170,7 @@ def _get_access_token(
108170
client_id=client_id,
109171
realm=realm,
110172
target=" ".join(target),
173+
ext_cache_key=ext_cache_key,
111174
),
112175
default=default)
113176

@@ -153,7 +216,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
153216
): # Special case for O(1) AT lookup
154217
preferred_result = self._get_access_token(
155218
query["home_account_id"], query["environment"],
156-
query["client_id"], query["realm"], target)
219+
query["client_id"], query["realm"], target,
220+
ext_cache_key=query.get("ext_cache_key"))
157221
if preferred_result and self._is_matching(
158222
preferred_result, query,
159223
# Needs no target_set here because it is satisfied by dict key
@@ -179,6 +243,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
179243
if (entry != preferred_result # Avoid yielding the same entry twice
180244
and self._is_matching(entry, query, target_set=target_set)
181245
):
246+
# Cache isolation for extended cache keys (e.g., FMI path).
247+
# Entries with ext_cache_key must not match queries without one.
248+
if (credential_type == self.CredentialType.ACCESS_TOKEN
249+
and "ext_cache_key" in entry
250+
and "ext_cache_key" not in (query or {})
251+
):
252+
continue
182253
yield entry
183254
for at in expired_access_tokens:
184255
self.remove_at(at)
@@ -278,6 +349,12 @@ def __add(self, event, now=None):
278349
# So that we won't accidentally store a user's password etc.
279350
"key_id", # It happens in SSH-cert or POP scenario
280351
}})
352+
# Compute and store extended cache key for cache isolation
353+
# (e.g., different FMI paths should have separate cache entries)
354+
ext_cache_key = _compute_ext_cache_key(data)
355+
356+
if ext_cache_key:
357+
at["ext_cache_key"] = ext_cache_key
281358
if "refresh_in" in response:
282359
refresh_in = response["refresh_in"] # It is an integer
283360
at["refresh_on"] = str(now + refresh_in) # Schema wants a string

tests/test_application.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,63 @@ def mock_post(url, headers=None, data=None, *args, **kwargs):
805805
self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE,
806806
"Second call should return token from cache")
807807

808+
def test_different_fmi_paths_are_cached_separately(self):
809+
"""Tokens acquired with different fmi_path values must NOT share cache entries."""
810+
app = ConfidentialClientApplication(
811+
"client_id", client_credential="secret",
812+
authority="https://login.microsoftonline.com/my_tenant")
813+
814+
def mock_post_factory(token_value):
815+
def mock_post(url, headers=None, data=None, *args, **kwargs):
816+
return MinimalResponse(
817+
status_code=200, text=json.dumps({
818+
"access_token": token_value,
819+
"expires_in": 3600,
820+
}))
821+
return mock_post
822+
823+
# Acquire token with path A
824+
result_a = app.acquire_token_for_client_with_fmi_path(
825+
["scope"], "PathA/credential", post=mock_post_factory("AT_for_path_A"))
826+
self.assertEqual("AT_for_path_A", result_a["access_token"])
827+
828+
# Acquire token with path B (should NOT get path A's cached token)
829+
result_b = app.acquire_token_for_client_with_fmi_path(
830+
["scope"], "PathB/credential", post=mock_post_factory("AT_for_path_B"))
831+
self.assertEqual("AT_for_path_B", result_b["access_token"])
832+
self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP,
833+
"Different FMI path should NOT return a cached token from another path")
834+
835+
# Verify path A still returns its own cached token
836+
result_a2 = app.acquire_token_for_client_with_fmi_path(
837+
["scope"], "PathA/credential", post=mock_post_factory("should_not_be_used"))
838+
self.assertEqual("AT_for_path_A", result_a2["access_token"])
839+
self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE,
840+
"Same FMI path should return cached token")
841+
842+
def test_fmi_token_does_not_interfere_with_non_fmi_token(self):
843+
"""FMI-cached tokens must not be returned for non-FMI acquire_token_for_client."""
844+
app = ConfidentialClientApplication(
845+
"client_id", client_credential="secret",
846+
authority="https://login.microsoftonline.com/my_tenant")
847+
848+
# First, cache a token via FMI path
849+
app.acquire_token_for_client_with_fmi_path(
850+
["scope"], "some/fmi/path",
851+
post=lambda url, **kwargs: MinimalResponse(
852+
status_code=200, text=json.dumps({
853+
"access_token": "FMI_AT", "expires_in": 3600})))
854+
855+
# Now call regular acquire_token_for_client — should NOT get FMI token
856+
result = app.acquire_token_for_client(
857+
["scope"],
858+
post=lambda url, **kwargs: MinimalResponse(
859+
status_code=200, text=json.dumps({
860+
"access_token": "regular_AT", "expires_in": 3600})))
861+
self.assertEqual("regular_AT", result["access_token"])
862+
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP,
863+
"Non-FMI call should not return FMI-cached token")
864+
808865

809866
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
810867
class TestRemoveTokensForClient(unittest.TestCase):

tests/test_ccs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,14 @@ def test_acquire_token_silent(self):
6161
"CSS routing info should be derived from home_account_id")
6262

6363
def test_acquire_token_by_username_password(self):
64+
import warnings
6465
app = msal.ClientApplication("client_id")
6566
username = "johndoe@contoso.com"
6667
with patch.object(app.http_client, "post", return_value=MinimalResponse(
6768
status_code=400, text='{"error": "mock"}')) as mocked_method:
68-
app.acquire_token_by_username_password(username, "password", ["scope"])
69+
with warnings.catch_warnings():
70+
warnings.simplefilter("ignore", DeprecationWarning)
71+
app.acquire_token_by_username_password(username, "password", ["scope"])
6972
self.assertEqual(
7073
"upn:" + username,
7174
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),

tests/test_fmi_e2e.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,135 @@ def test_acquire_with_assertion_callback_and_fmi_path(self):
149149
"cache might not be working correctly")
150150

151151

152+
class TestFMICacheIsolation(LabBasedTestCase):
153+
"""Test that tokens acquired with different FMI paths are cached separately.
154+
155+
This verifies the cache key extensibility: two calls with different fmi_path
156+
values should NOT return each other's cached tokens.
157+
"""
158+
159+
def test_different_fmi_paths_are_cached_separately(self):
160+
app = msal.ConfidentialClientApplication(
161+
_FMI_CLIENT_ID,
162+
client_credential=get_client_certificate(),
163+
authority=_AUTHORITY_URL,
164+
http_client=MinimalHttpClient(),
165+
)
166+
scopes = [_FMI_SCOPE]
167+
168+
# Acquire token with path A
169+
result_a = app.acquire_token_for_client_with_fmi_path(
170+
scopes, "PathA/credential")
171+
self.assertIn("access_token", result_a,
172+
"Path A acquisition failed: {}: {}".format(
173+
result_a.get("error"), result_a.get("error_description")))
174+
175+
# Acquire token with path B — should NOT get path A's cached token
176+
result_b = app.acquire_token_for_client_with_fmi_path(
177+
scopes, "PathB/credential")
178+
self.assertIn("access_token", result_b,
179+
"Path B acquisition failed: {}: {}".format(
180+
result_b.get("error"), result_b.get("error_description")))
181+
self.assertNotEqual(
182+
result_b.get("token_source"), "cache",
183+
"Different FMI path should NOT return cached token from another path")
184+
185+
# Verify path A still returns its own cached token
186+
result_a2 = app.acquire_token_for_client_with_fmi_path(
187+
scopes, "PathA/credential")
188+
self.assertIn("access_token", result_a2)
189+
self.assertEqual(
190+
result_a2.get("token_source"), "cache",
191+
"Same FMI path should return cached token")
192+
self.assertEqual(result_a["access_token"], result_a2["access_token"])
193+
194+
def test_fmi_token_does_not_interfere_with_non_fmi_token(self):
195+
app = msal.ConfidentialClientApplication(
196+
_FMI_CLIENT_ID,
197+
client_credential=get_client_certificate(),
198+
authority=_AUTHORITY_URL,
199+
http_client=MinimalHttpClient(),
200+
)
201+
scopes = [_FMI_SCOPE]
202+
203+
# Cache a token via FMI path
204+
fmi_result = app.acquire_token_for_client_with_fmi_path(scopes, _FMI_PATH)
205+
self.assertIn("access_token", fmi_result)
206+
207+
# Regular acquire_token_for_client should NOT get the FMI token
208+
regular_result = app.acquire_token_for_client(scopes)
209+
self.assertIn("access_token", regular_result,
210+
"Regular call failed: {}: {}".format(
211+
regular_result.get("error"), regular_result.get("error_description")))
212+
self.assertNotEqual(
213+
regular_result.get("token_source"), "cache",
214+
"Non-FMI call should not return FMI-cached token")
215+
216+
217+
class TestFMICacheInspection(LabBasedTestCase):
218+
"""Acquire tokens with two different FMI paths and inspect the underlying
219+
cache to verify the entries are correctly isolated."""
220+
221+
def test_two_fmi_paths_produce_separate_cache_entries(self):
222+
app = msal.ConfidentialClientApplication(
223+
_FMI_CLIENT_ID,
224+
client_credential=get_client_certificate(),
225+
authority=_AUTHORITY_URL,
226+
http_client=MinimalHttpClient(),
227+
)
228+
scopes = [_FMI_SCOPE]
229+
path_a = "PathAlpha/Credential"
230+
path_b = "PathBeta/Credential"
231+
232+
# 1. Acquire token with path A
233+
result_a = app.acquire_token_for_client_with_fmi_path(scopes, path_a)
234+
self.assertIn("access_token", result_a,
235+
"Path A acquisition failed: {}: {}".format(
236+
result_a.get("error"), result_a.get("error_description")))
237+
token_a = result_a["access_token"]
238+
239+
# 2. Acquire token with path B
240+
result_b = app.acquire_token_for_client_with_fmi_path(scopes, path_b)
241+
self.assertIn("access_token", result_b,
242+
"Path B acquisition failed: {}: {}".format(
243+
result_b.get("error"), result_b.get("error_description")))
244+
token_b = result_b["access_token"]
245+
246+
# Tokens should be different (different paths go to different resources)
247+
self.assertNotEqual(token_a, token_b,
248+
"Tokens for different FMI paths should differ")
249+
250+
# 3. Inspect cache: there should be exactly 2 AccessToken entries
251+
cache = app.token_cache._cache
252+
at_entries = cache.get("AccessToken", {})
253+
# Filter to our client_id + scope to avoid noise
254+
our_entries = {
255+
k: v for k, v in at_entries.items()
256+
if v.get("client_id") == _FMI_CLIENT_ID
257+
and _FMI_SCOPE.split("/")[0] in v.get("target", "")
258+
}
259+
self.assertEqual(2, len(our_entries),
260+
"Cache should contain exactly 2 AT entries for our client, "
261+
"got {}: {}".format(len(our_entries), list(our_entries.keys())))
262+
263+
# 4. Each entry must have a non-empty ext_cache_key, and they must differ
264+
ext_keys = [v.get("ext_cache_key") for v in our_entries.values()]
265+
for ek in ext_keys:
266+
self.assertTrue(ek, "Each FMI cache entry must have a non-empty ext_cache_key")
267+
self.assertNotEqual(ext_keys[0], ext_keys[1],
268+
"ext_cache_key values for different FMI paths must differ")
269+
270+
# 5. Verify each path still returns its own cached token
271+
cached_a = app.acquire_token_for_client_with_fmi_path(scopes, path_a)
272+
self.assertEqual("cache", cached_a.get("token_source"))
273+
self.assertEqual(token_a, cached_a["access_token"],
274+
"Path A should return its own cached token")
275+
276+
cached_b = app.acquire_token_for_client_with_fmi_path(scopes, path_b)
277+
self.assertEqual("cache", cached_b.get("token_source"))
278+
self.assertEqual(token_b, cached_b["access_token"],
279+
"Path B should return its own cached token")
280+
281+
152282
if __name__ == "__main__":
153283
unittest.main()

0 commit comments

Comments
 (0)