Skip to content

Commit 8a56bc1

Browse files
Gladwin JohnsonCopilot
authored andcommitted
Fix exception-contract and cache-key regressions vs #882
Blocker 1 — Exception contract: RuntimeError from msal-key-attestation (DLL load, attestation call) now gets caught and wrapped as MsiV2Error at both the provider call site in msi_v2.py and the outer boundary in managed_identity.py. Only MsiV2Error (or its subclasses) can escape to the caller. Blocker 2 — Stable attestation cache key: The provider callback signature is expanded from (endpoint, key_handle, client_id) to (endpoint, key_handle, client_id, cache_key). MSAL now passes the stable per-boot key name as cache_key, which get_attestation_jwt() uses for its MAA token cache instead of falling back to the less cache-friendly numeric handle. Tests: 59 passed (44 core + 15 attestation), including new tests for RuntimeError wrapping and cache_key forwarding. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 639917f commit 8a56bc1

File tree

5 files changed

+77
-20
lines changed

5 files changed

+77
-20
lines changed

msal-key-attestation/msal_key_attestation/attestation.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,17 +338,21 @@ def get_attestation_jwt(
338338

339339
# ---------------------------------------------------------------------------
340340
# Public factory — matches the callback signature MSAL expects:
341-
# (endpoint: str, key_handle: int, client_id: str) -> str
341+
# (endpoint: str, key_handle: int, client_id: str, cache_key: str) -> str
342342
# ---------------------------------------------------------------------------
343343

344-
def create_attestation_provider() -> Callable[[str, int, str], str]:
344+
def create_attestation_provider() -> Callable[[str, int, str, str], str]:
345345
"""
346346
Create an attestation token provider callable for MSAL MSI v2.
347347
348348
The returned callable has signature::
349349
350350
provider(attestation_endpoint: str, key_handle: int,
351-
client_id: str) -> str
351+
client_id: str, cache_key: str) -> str
352+
353+
``cache_key`` should be the stable per-boot key name. Using the key
354+
name (rather than the numeric handle) maximizes MAA-token cache hits
355+
across key re-opens.
352356
353357
It wraps :func:`get_attestation_jwt` with caching support.
354358
@@ -372,10 +376,12 @@ def _provider(
372376
attestation_endpoint: str,
373377
key_handle: int,
374378
client_id: str,
379+
cache_key: str = "",
375380
) -> str:
376381
return get_attestation_jwt(
377382
attestation_endpoint=attestation_endpoint,
378383
client_id=client_id,
379384
key_handle=key_handle,
385+
cache_key=cache_key or None,
380386
)
381387
return _provider

msal-key-attestation/tests/test_attestation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,26 @@ def test_returns_callable(self):
109109
def test_provider_calls_get_attestation_jwt(self, mock_get):
110110
mock_get.return_value = "fake.attestation.jwt"
111111
provider = create_attestation_provider()
112-
result = provider("https://attest.example.com", 12345, "client-id")
112+
result = provider(
113+
"https://attest.example.com", 12345, "client-id", "my-key-name")
113114
self.assertEqual(result, "fake.attestation.jwt")
114115
mock_get.assert_called_once_with(
115116
attestation_endpoint="https://attest.example.com",
116117
client_id="client-id",
117118
key_handle=12345,
119+
cache_key="my-key-name",
120+
)
121+
122+
@patch("msal_key_attestation.attestation.get_attestation_jwt")
123+
def test_provider_forwards_empty_cache_key_as_none(self, mock_get):
124+
mock_get.return_value = "fake.jwt"
125+
provider = create_attestation_provider()
126+
provider("https://ep", 1, "cid", "")
127+
mock_get.assert_called_once_with(
128+
attestation_endpoint="https://ep",
129+
client_id="cid",
130+
key_handle=1,
131+
cache_key=None,
118132
)
119133

120134

msal/managed_identity.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -343,11 +343,17 @@ def acquire_token_for_client(
343343
"Install it with: pip install msal-key-attestation")
344344

345345
from .msi_v2 import obtain_token as _obtain_token_v2
346-
result = _obtain_token_v2(
347-
self._http_client, self._managed_identity, resource,
348-
attestation_enabled=True,
349-
attestation_token_provider=attestation_token_provider,
350-
)
346+
try:
347+
result = _obtain_token_v2(
348+
self._http_client, self._managed_identity, resource,
349+
attestation_enabled=True,
350+
attestation_token_provider=attestation_token_provider,
351+
)
352+
except MsiV2Error:
353+
raise
354+
except Exception as exc:
355+
raise MsiV2Error(
356+
f"[msi_v2] Unexpected failure: {exc}") from exc
351357
if "access_token" in result and "error" not in result:
352358
result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
353359
return result

msal/msi_v2.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,9 +1232,10 @@ def _acquire_token_mtls_schannel(
12321232
# Public API
12331233
# ---------------------------------------------------------------------------
12341234

1235-
# Type alias for attestation provider callback matching .NET delegate:
1236-
# (endpoint: str, key_handle: int, client_id: str) -> str (JWT)
1237-
AttestationTokenProvider = Callable[[str, int, str], str]
1235+
# Type alias for attestation provider callback.
1236+
# Signature: (endpoint, key_handle, client_id, cache_key) -> JWT string.
1237+
# cache_key is the stable per-boot key name for optimal caching.
1238+
AttestationTokenProvider = Callable[[str, int, str, str], str]
12381239

12391240

12401241
def obtain_token(
@@ -1263,8 +1264,10 @@ def obtain_token(
12631264
resource: Resource URI for token acquisition
12641265
attestation_enabled: Whether attestation is enabled
12651266
attestation_token_provider: Callback (endpoint, key_handle,
1266-
client_id) -> JWT string. Provided by msal-key-attestation
1267-
package. None means non-attested flow.
1267+
client_id, cache_key) -> JWT string. Provided by
1268+
msal-key-attestation package. cache_key is the stable
1269+
per-boot key name for optimal caching. None means
1270+
non-attested flow.
12681271
12691272
Returns:
12701273
Token response dict with access_token, expires_in, token_type,
@@ -1335,10 +1338,18 @@ def obtain_token(
13351338
if not attestation_endpoint:
13361339
raise MsiV2Error(
13371340
"[msi_v2] attestationEndpoint missing from metadata.")
1338-
att_jwt = attestation_token_provider(
1339-
str(attestation_endpoint),
1340-
int(key.value),
1341-
str(client_id))
1341+
try:
1342+
att_jwt = attestation_token_provider(
1343+
str(attestation_endpoint),
1344+
int(key.value),
1345+
str(client_id),
1346+
str(key_name))
1347+
except MsiV2Error:
1348+
raise
1349+
except Exception as exc:
1350+
raise MsiV2Error(
1351+
f"[msi_v2] Attestation provider failed: {exc}"
1352+
) from exc
13421353
if not att_jwt or not str(att_jwt).strip():
13431354
raise MsiV2Error(
13441355
"[msi_v2] Attestation provider returned empty JWT.")

tests/test_msi_v2.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def test_v2_called_when_both_flags_true(self, mock_v2, _):
403403
with patch.dict("sys.modules", {
404404
"msal_key_attestation": MagicMock(
405405
create_attestation_provider=MagicMock(
406-
return_value=lambda ep, kh, ci: "fake.jwt"))
406+
return_value=lambda ep, kh, ci, ck="": "fake.jwt"))
407407
}):
408408
res = client.acquire_token_for_client(
409409
resource="https://mtlstb.graph.microsoft.com",
@@ -425,7 +425,7 @@ def test_strict_v2_failure_raises_no_v1_fallback(
425425
with patch.dict("sys.modules", {
426426
"msal_key_attestation": MagicMock(
427427
create_attestation_provider=MagicMock(
428-
return_value=lambda ep, kh, ci: "fake.jwt"))
428+
return_value=lambda ep, kh, ci, ck="": "fake.jwt"))
429429
}):
430430
with self.assertRaises(MsiV2Error):
431431
client.acquire_token_for_client(
@@ -434,6 +434,26 @@ def test_strict_v2_failure_raises_no_v1_fallback(
434434
with_attestation_support=True)
435435
mock_v1.assert_not_called()
436436

437+
@patch("msal.msi_v2.obtain_token",
438+
side_effect=RuntimeError("DLL load failed"))
439+
@patch("msal.managed_identity._obtain_token")
440+
def test_runtime_error_wrapped_as_msi_v2_error(
441+
self, mock_v1, mock_v2):
442+
"""RuntimeError from provider/DLL must surface as MsiV2Error."""
443+
client = self._make_client()
444+
with patch.dict("sys.modules", {
445+
"msal_key_attestation": MagicMock(
446+
create_attestation_provider=MagicMock(
447+
return_value=lambda ep, kh, ci, ck="": "fake.jwt"))
448+
}):
449+
with self.assertRaises(MsiV2Error) as ctx:
450+
client.acquire_token_for_client(
451+
resource="R",
452+
mtls_proof_of_possession=True,
453+
with_attestation_support=True)
454+
self.assertIn("DLL load failed", str(ctx.exception))
455+
mock_v1.assert_not_called()
456+
437457
def test_missing_attestation_package_raises_clear_error(self):
438458
client = self._make_client()
439459
with patch.dict("sys.modules", {"msal_key_attestation": None}):

0 commit comments

Comments
 (0)