Skip to content

Commit 69805ef

Browse files
Copilotgladjohn
andcommitted
Add mtls_proof_of_possession and with_attestation_support APIs to acquire_token_for_client
Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com>
1 parent 8e37366 commit 69805ef

File tree

4 files changed

+161
-20
lines changed

4 files changed

+161
-20
lines changed

msal/managed_identity.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,8 @@ def acquire_token_for_client(
281281
*,
282282
resource: str, # If/when we support scope, resource will become optional
283283
claims_challenge: Optional[str] = None,
284+
mtls_proof_of_possession: bool = False,
285+
with_attestation_support: bool = False,
284286
):
285287
"""Acquire token for the managed identity.
286288
@@ -300,6 +302,22 @@ def acquire_token_for_client(
300302
even if the app developer did not opt in for the "CP1" client capability.
301303
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
302304
305+
:param bool mtls_proof_of_possession: (optional)
306+
When True, use the MSI v2 (mTLS Proof-of-Possession) flow to acquire an
307+
``mtls_pop`` token bound to a short-lived mTLS certificate issued by the
308+
IMDS ``/issuecredential`` endpoint.
309+
Without this flag the legacy IMDS v1 flow is used.
310+
Defaults to False.
311+
312+
This takes precedence over the ``msi_v2_enabled`` constructor parameter.
313+
314+
:param bool with_attestation_support: (optional)
315+
When True (and ``mtls_proof_of_possession`` is also True), attempt
316+
KeyGuard / platform attestation before credential issuance.
317+
On Windows this leverages ``AttestationClientLib.dll`` when available;
318+
on other platforms the parameter is silently ignored.
319+
Defaults to False.
320+
303321
.. note::
304322
305323
Known issue: When an Azure VM has only one user-assigned managed identity,
@@ -349,11 +367,15 @@ def acquire_token_for_client(
349367
return access_token_from_cache # It is still good as new
350368
try:
351369
result = None
352-
if self._msi_v2_enabled:
370+
# Per-call mtls_proof_of_possession takes precedence over the constructor
371+
# default (msi_v2_enabled / MSAL_ENABLE_MSI_V2 env var).
372+
use_msi_v2 = mtls_proof_of_possession or self._msi_v2_enabled
373+
if use_msi_v2:
353374
try:
354375
from .msi_v2 import obtain_token as _obtain_token_v2
355376
result = _obtain_token_v2(
356-
self._http_client, self._managed_identity, resource)
377+
self._http_client, self._managed_identity, resource,
378+
attestation_enabled=with_attestation_support)
357379
logger.debug("MSI v2 token acquisition succeeded")
358380
except MsiV2Error as exc:
359381
logger.warning(

msal/msi_v2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,16 @@ def obtain_token(
282282
http_client,
283283
managed_identity,
284284
resource: str,
285+
attestation_enabled: bool = False,
285286
) -> Dict[str, Any]:
286287
"""Acquire a token using the MSI v2 (mTLS PoP) flow.
287288
288289
:param http_client: HTTP client for IMDS requests.
289290
:param managed_identity: ManagedIdentity configuration dict.
290291
:param resource: Resource URL for token acquisition.
292+
:param attestation_enabled: When True, attempt KeyGuard / platform attestation
293+
before issuing credentials (Windows only; silently skipped on other platforms).
294+
Defaults to False.
291295
:returns: OAuth2 token response dict with access_token on success,
292296
or error dict on failure.
293297
:raises MsiV2Error: If the flow fails at a non-recoverable step.
@@ -311,9 +315,9 @@ def obtain_token(
311315
# 3. Build PKCS#10 CSR with cuId OID extension
312316
csr_der = _build_csr(private_key, client_id, cu_id)
313317

314-
# 4. Attempt attestation (Windows only; falls back to None on other platforms)
318+
# 4. Attempt attestation only when explicitly requested by the caller
315319
attestation_jwt = None
316-
if attestation_endpoint:
320+
if attestation_enabled and attestation_endpoint:
317321
try:
318322
from .msi_v2_attestation import get_attestation_jwt
319323
attestation_jwt = get_attestation_jwt(

sample/msi_v2_sample.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313
- Set RESOURCE environment variable to the target resource URL, e.g.
1414
export RESOURCE=https://management.azure.com/
1515
16-
To enable MSI v2 (required):
17-
export MSAL_ENABLE_MSI_V2=true
18-
or pass msi_v2_enabled=True to ManagedIdentityClient.
19-
2016
Usage:
2117
python msi_v2_sample.py
2218
"""
@@ -42,13 +38,16 @@
4238
msal.SystemAssignedManagedIdentity(),
4339
http_client=requests.Session(),
4440
token_cache=global_token_cache,
45-
msi_v2_enabled=True, # Enable MSI v2 (mTLS PoP) flow
4641
)
4742

4843

4944
def acquire_and_use_token():
5045
"""Acquire an mtls_pop token via MSI v2 and optionally call an API."""
51-
result = client.acquire_token_for_client(resource=RESOURCE)
46+
result = client.acquire_token_for_client(
47+
resource=RESOURCE,
48+
mtls_proof_of_possession=True, # Use MSI v2 (mTLS PoP) flow
49+
with_attestation_support=True, # Enable KeyGuard attestation (Windows)
50+
)
5251

5352
if "access_token" in result:
5453
print("Token acquired successfully")

tests/test_msi_v2.py

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def test_client_msi_v2_disabled_via_env_var(self):
435435
class TestMsiV2TokenAcquisitionIntegration(unittest.TestCase):
436436
"""Integration tests for MSI v2 token acquisition flow with mocked IMDS."""
437437

438-
def _make_client(self, msi_v2_enabled=True):
438+
def _make_client(self, msi_v2_enabled=False):
439439
import requests
440440
return msal.ManagedIdentityClient(
441441
msal.SystemAssignedManagedIdentity(),
@@ -466,7 +466,48 @@ def _make_mock_responses(self, client_id, cu_id, cert_pem, token_endpoint,
466466

467467
@patch("msal.msi_v2._acquire_token_via_mtls")
468468
def test_msi_v2_happy_path(self, mock_mtls):
469-
"""MSI v2 succeeds end-to-end (mTLS call is mocked)."""
469+
"""MSI v2 succeeds end-to-end via mtls_proof_of_possession=True."""
470+
import requests
471+
472+
key = _generate_rsa_key()
473+
cert_pem = _make_self_signed_cert(key, "test-client-id")
474+
access_token = "MSI_V2_ACCESS_TOKEN"
475+
expires_in = 3600
476+
token_endpoint = "https://login.microsoftonline.com/tenant/oauth2/token"
477+
478+
platform_metadata, credential, token_response = self._make_mock_responses(
479+
"test-client-id", "test-cu-id", cert_pem, token_endpoint,
480+
access_token, expires_in)
481+
482+
mock_mtls.return_value = token_response
483+
484+
client = self._make_client()
485+
486+
def _mock_get(url, **kwargs):
487+
if "getplatformmetadata" in url:
488+
return MinimalResponse(
489+
status_code=200, text=json.dumps(platform_metadata))
490+
raise ValueError("Unexpected GET: {}".format(url))
491+
492+
def _mock_post(url, **kwargs):
493+
if "issuecredential" in url:
494+
return MinimalResponse(
495+
status_code=200, text=json.dumps(credential))
496+
raise ValueError("Unexpected POST: {}".format(url))
497+
498+
with patch.object(client._http_client, "get", side_effect=_mock_get), \
499+
patch.object(client._http_client, "post", side_effect=_mock_post):
500+
result = client.acquire_token_for_client(
501+
resource="https://management.azure.com/",
502+
mtls_proof_of_possession=True)
503+
504+
self.assertEqual(result["access_token"], access_token)
505+
self.assertEqual(result["token_type"], "mtls_pop")
506+
self.assertEqual(result["token_source"], "identity_provider")
507+
508+
@patch("msal.msi_v2._acquire_token_via_mtls")
509+
def test_msi_v2_happy_path_via_constructor_flag(self, mock_mtls):
510+
"""MSI v2 also works when enabled via the msi_v2_enabled constructor param."""
470511
import requests
471512

472513
key = _generate_rsa_key()
@@ -497,18 +538,18 @@ def _mock_post(url, **kwargs):
497538

498539
with patch.object(client._http_client, "get", side_effect=_mock_get), \
499540
patch.object(client._http_client, "post", side_effect=_mock_post):
541+
# No mtls_proof_of_possession kwarg; relies on constructor flag
500542
result = client.acquire_token_for_client(
501543
resource="https://management.azure.com/")
502544

503545
self.assertEqual(result["access_token"], access_token)
504546
self.assertEqual(result["token_type"], "mtls_pop")
505-
self.assertEqual(result["token_source"], "identity_provider")
506547

507548
@patch("msal.msi_v2._acquire_token_via_mtls")
508549
def test_msi_v2_fallback_to_v1_on_metadata_failure(self, mock_mtls):
509550
"""MSI v2 falls back to MSI v1 if IMDS metadata call fails."""
510551
import requests
511-
client = self._make_client(msi_v2_enabled=True)
552+
client = self._make_client()
512553

513554
def _mock_get(url, **kwargs):
514555
if "getplatformmetadata" in url:
@@ -523,18 +564,19 @@ def _mock_get(url, **kwargs):
523564
raise ValueError("Unexpected GET: {}".format(url))
524565

525566
with patch.object(client._http_client, "get", side_effect=_mock_get):
526-
result = client.acquire_token_for_client(resource="R")
567+
result = client.acquire_token_for_client(
568+
resource="R", mtls_proof_of_possession=True)
527569

528570
# Should have fallen back to MSI v1
529571
self.assertEqual(result["access_token"], "V1_TOKEN")
530572
mock_mtls.assert_not_called()
531573

532574
@patch("msal.msi_v2._acquire_token_via_mtls")
533-
def test_msi_v2_not_attempted_when_disabled(self, mock_mtls):
534-
"""MSI v2 is not attempted when msi_v2_enabled=False."""
575+
def test_msi_v2_not_attempted_when_not_requested(self, mock_mtls):
576+
"""MSI v2 is not attempted when mtls_proof_of_possession=False (default)."""
535577
import requests
536578

537-
client = self._make_client(msi_v2_enabled=False)
579+
client = self._make_client()
538580

539581
def _mock_get(url, **kwargs):
540582
return MinimalResponse(status_code=200, text=json.dumps({
@@ -544,6 +586,7 @@ def _mock_get(url, **kwargs):
544586
}))
545587

546588
with patch.object(client._http_client, "get", side_effect=_mock_get):
589+
# No mtls_proof_of_possession — uses v1 by default
547590
result = client.acquire_token_for_client(resource="R")
548591

549592
mock_mtls.assert_not_called()
@@ -553,7 +596,7 @@ def _mock_get(url, **kwargs):
553596
def test_msi_v2_fallback_on_unexpected_error(self, mock_mtls):
554597
"""MSI v2 falls back to MSI v1 on unexpected errors."""
555598
import requests
556-
client = self._make_client(msi_v2_enabled=True)
599+
client = self._make_client()
557600

558601
platform_metadata = {
559602
"clientId": "client-id",
@@ -583,12 +626,85 @@ def _mock_post(url, **kwargs):
583626

584627
with patch.object(client._http_client, "get", side_effect=_mock_get), \
585628
patch.object(client._http_client, "post", side_effect=_mock_post):
586-
result = client.acquire_token_for_client(resource="R")
629+
result = client.acquire_token_for_client(
630+
resource="R", mtls_proof_of_possession=True)
587631

588632
# Should fall back to MSI v1
589633
self.assertEqual(result["access_token"], "V1_FALLBACK")
590634
mock_mtls.assert_not_called()
591635

636+
@patch("msal.msi_v2_attestation.get_attestation_jwt")
637+
@patch("msal.msi_v2._acquire_token_via_mtls")
638+
def test_with_attestation_support_triggers_attestation(
639+
self, mock_mtls, mock_attest
640+
):
641+
"""with_attestation_support=True calls attestation; False skips it."""
642+
import requests
643+
644+
key = _generate_rsa_key()
645+
cert_pem = _make_self_signed_cert(key, "test-client-id")
646+
token_endpoint = "https://login.microsoftonline.com/tenant/oauth2/token"
647+
access_token = "MSI_V2_ATTEST_TOKEN"
648+
expires_in = 3600
649+
650+
platform_metadata = {
651+
"clientId": "test-client-id",
652+
"tenantId": "tenant-id",
653+
"cuId": "test-cu-id",
654+
"attestationEndpoint": "https://attest.example.com",
655+
}
656+
credential = {
657+
"certificate": cert_pem,
658+
"tokenEndpoint": token_endpoint,
659+
}
660+
token_response = {
661+
"access_token": access_token,
662+
"expires_in": str(expires_in),
663+
"token_type": "mtls_pop",
664+
}
665+
666+
mock_attest.return_value = "fake.attestation.jwt"
667+
mock_mtls.return_value = token_response
668+
669+
client = self._make_client()
670+
671+
def _mock_get(url, **kwargs):
672+
if "getplatformmetadata" in url:
673+
return MinimalResponse(
674+
status_code=200, text=json.dumps(platform_metadata))
675+
raise ValueError("Unexpected GET: {}".format(url))
676+
677+
def _mock_post(url, **kwargs):
678+
if "issuecredential" in url:
679+
return MinimalResponse(
680+
status_code=200, text=json.dumps(credential))
681+
raise ValueError("Unexpected POST: {}".format(url))
682+
683+
# --- with_attestation_support=True: attestation should be called ---
684+
with patch.object(client._http_client, "get", side_effect=_mock_get), \
685+
patch.object(client._http_client, "post", side_effect=_mock_post):
686+
result = client.acquire_token_for_client(
687+
resource="https://management.azure.com/",
688+
mtls_proof_of_possession=True,
689+
with_attestation_support=True,
690+
)
691+
mock_attest.assert_called_once()
692+
self.assertEqual(result["access_token"], access_token)
693+
694+
mock_attest.reset_mock()
695+
mock_mtls.reset_mock()
696+
697+
# --- with_attestation_support=False (default): attestation NOT called ---
698+
with patch.object(client._http_client, "get", side_effect=_mock_get), \
699+
patch.object(client._http_client, "post", side_effect=_mock_post):
700+
result = client.acquire_token_for_client(
701+
resource="https://management.azure.com/",
702+
mtls_proof_of_possession=True,
703+
with_attestation_support=False,
704+
)
705+
mock_attest.assert_not_called()
706+
self.assertEqual(result["access_token"], access_token)
707+
592708

593709
# ---------------------------------------------------------------------------
594710
# Attestation module

0 commit comments

Comments
 (0)