@@ -435,7 +435,7 @@ def test_client_msi_v2_disabled_via_env_var(self):
435435class TestMsiV2TokenAcquisitionIntegration (unittest .TestCase ):
436436 """Integration tests for MSI v2 token acquisition flow with mocked IMDS."""
437437
438- def _make_client (self , msi_v2_enabled = True ):
438+ def _make_client (self , msi_v2_enabled = False ):
439439 import requests
440440 return msal .ManagedIdentityClient (
441441 msal .SystemAssignedManagedIdentity (),
@@ -466,7 +466,48 @@ def _make_mock_responses(self, client_id, cu_id, cert_pem, token_endpoint,
466466
467467 @patch ("msal.msi_v2._acquire_token_via_mtls" )
468468 def test_msi_v2_happy_path (self , mock_mtls ):
469- """MSI v2 succeeds end-to-end (mTLS call is mocked)."""
469+ """MSI v2 succeeds end-to-end via mtls_proof_of_possession=True."""
470+ import requests
471+
472+ key = _generate_rsa_key ()
473+ cert_pem = _make_self_signed_cert (key , "test-client-id" )
474+ access_token = "MSI_V2_ACCESS_TOKEN"
475+ expires_in = 3600
476+ token_endpoint = "https://login.microsoftonline.com/tenant/oauth2/token"
477+
478+ platform_metadata , credential , token_response = self ._make_mock_responses (
479+ "test-client-id" , "test-cu-id" , cert_pem , token_endpoint ,
480+ access_token , expires_in )
481+
482+ mock_mtls .return_value = token_response
483+
484+ client = self ._make_client ()
485+
486+ def _mock_get (url , ** kwargs ):
487+ if "getplatformmetadata" in url :
488+ return MinimalResponse (
489+ status_code = 200 , text = json .dumps (platform_metadata ))
490+ raise ValueError ("Unexpected GET: {}" .format (url ))
491+
492+ def _mock_post (url , ** kwargs ):
493+ if "issuecredential" in url :
494+ return MinimalResponse (
495+ status_code = 200 , text = json .dumps (credential ))
496+ raise ValueError ("Unexpected POST: {}" .format (url ))
497+
498+ with patch .object (client ._http_client , "get" , side_effect = _mock_get ), \
499+ patch .object (client ._http_client , "post" , side_effect = _mock_post ):
500+ result = client .acquire_token_for_client (
501+ resource = "https://management.azure.com/" ,
502+ mtls_proof_of_possession = True )
503+
504+ self .assertEqual (result ["access_token" ], access_token )
505+ self .assertEqual (result ["token_type" ], "mtls_pop" )
506+ self .assertEqual (result ["token_source" ], "identity_provider" )
507+
508+ @patch ("msal.msi_v2._acquire_token_via_mtls" )
509+ def test_msi_v2_happy_path_via_constructor_flag (self , mock_mtls ):
510+ """MSI v2 also works when enabled via the msi_v2_enabled constructor param."""
470511 import requests
471512
472513 key = _generate_rsa_key ()
@@ -497,18 +538,18 @@ def _mock_post(url, **kwargs):
497538
498539 with patch .object (client ._http_client , "get" , side_effect = _mock_get ), \
499540 patch .object (client ._http_client , "post" , side_effect = _mock_post ):
541+ # No mtls_proof_of_possession kwarg; relies on constructor flag
500542 result = client .acquire_token_for_client (
501543 resource = "https://management.azure.com/" )
502544
503545 self .assertEqual (result ["access_token" ], access_token )
504546 self .assertEqual (result ["token_type" ], "mtls_pop" )
505- self .assertEqual (result ["token_source" ], "identity_provider" )
506547
507548 @patch ("msal.msi_v2._acquire_token_via_mtls" )
508549 def test_msi_v2_fallback_to_v1_on_metadata_failure (self , mock_mtls ):
509550 """MSI v2 falls back to MSI v1 if IMDS metadata call fails."""
510551 import requests
511- client = self ._make_client (msi_v2_enabled = True )
552+ client = self ._make_client ()
512553
513554 def _mock_get (url , ** kwargs ):
514555 if "getplatformmetadata" in url :
@@ -523,18 +564,19 @@ def _mock_get(url, **kwargs):
523564 raise ValueError ("Unexpected GET: {}" .format (url ))
524565
525566 with patch .object (client ._http_client , "get" , side_effect = _mock_get ):
526- result = client .acquire_token_for_client (resource = "R" )
567+ result = client .acquire_token_for_client (
568+ resource = "R" , mtls_proof_of_possession = True )
527569
528570 # Should have fallen back to MSI v1
529571 self .assertEqual (result ["access_token" ], "V1_TOKEN" )
530572 mock_mtls .assert_not_called ()
531573
532574 @patch ("msal.msi_v2._acquire_token_via_mtls" )
533- def test_msi_v2_not_attempted_when_disabled (self , mock_mtls ):
534- """MSI v2 is not attempted when msi_v2_enabled =False."""
575+ def test_msi_v2_not_attempted_when_not_requested (self , mock_mtls ):
576+ """MSI v2 is not attempted when mtls_proof_of_possession =False (default) ."""
535577 import requests
536578
537- client = self ._make_client (msi_v2_enabled = False )
579+ client = self ._make_client ()
538580
539581 def _mock_get (url , ** kwargs ):
540582 return MinimalResponse (status_code = 200 , text = json .dumps ({
@@ -544,6 +586,7 @@ def _mock_get(url, **kwargs):
544586 }))
545587
546588 with patch .object (client ._http_client , "get" , side_effect = _mock_get ):
589+ # No mtls_proof_of_possession — uses v1 by default
547590 result = client .acquire_token_for_client (resource = "R" )
548591
549592 mock_mtls .assert_not_called ()
@@ -553,7 +596,7 @@ def _mock_get(url, **kwargs):
553596 def test_msi_v2_fallback_on_unexpected_error (self , mock_mtls ):
554597 """MSI v2 falls back to MSI v1 on unexpected errors."""
555598 import requests
556- client = self ._make_client (msi_v2_enabled = True )
599+ client = self ._make_client ()
557600
558601 platform_metadata = {
559602 "clientId" : "client-id" ,
@@ -583,12 +626,85 @@ def _mock_post(url, **kwargs):
583626
584627 with patch .object (client ._http_client , "get" , side_effect = _mock_get ), \
585628 patch .object (client ._http_client , "post" , side_effect = _mock_post ):
586- result = client .acquire_token_for_client (resource = "R" )
629+ result = client .acquire_token_for_client (
630+ resource = "R" , mtls_proof_of_possession = True )
587631
588632 # Should fall back to MSI v1
589633 self .assertEqual (result ["access_token" ], "V1_FALLBACK" )
590634 mock_mtls .assert_not_called ()
591635
636+ @patch ("msal.msi_v2_attestation.get_attestation_jwt" )
637+ @patch ("msal.msi_v2._acquire_token_via_mtls" )
638+ def test_with_attestation_support_triggers_attestation (
639+ self , mock_mtls , mock_attest
640+ ):
641+ """with_attestation_support=True calls attestation; False skips it."""
642+ import requests
643+
644+ key = _generate_rsa_key ()
645+ cert_pem = _make_self_signed_cert (key , "test-client-id" )
646+ token_endpoint = "https://login.microsoftonline.com/tenant/oauth2/token"
647+ access_token = "MSI_V2_ATTEST_TOKEN"
648+ expires_in = 3600
649+
650+ platform_metadata = {
651+ "clientId" : "test-client-id" ,
652+ "tenantId" : "tenant-id" ,
653+ "cuId" : "test-cu-id" ,
654+ "attestationEndpoint" : "https://attest.example.com" ,
655+ }
656+ credential = {
657+ "certificate" : cert_pem ,
658+ "tokenEndpoint" : token_endpoint ,
659+ }
660+ token_response = {
661+ "access_token" : access_token ,
662+ "expires_in" : str (expires_in ),
663+ "token_type" : "mtls_pop" ,
664+ }
665+
666+ mock_attest .return_value = "fake.attestation.jwt"
667+ mock_mtls .return_value = token_response
668+
669+ client = self ._make_client ()
670+
671+ def _mock_get (url , ** kwargs ):
672+ if "getplatformmetadata" in url :
673+ return MinimalResponse (
674+ status_code = 200 , text = json .dumps (platform_metadata ))
675+ raise ValueError ("Unexpected GET: {}" .format (url ))
676+
677+ def _mock_post (url , ** kwargs ):
678+ if "issuecredential" in url :
679+ return MinimalResponse (
680+ status_code = 200 , text = json .dumps (credential ))
681+ raise ValueError ("Unexpected POST: {}" .format (url ))
682+
683+ # --- with_attestation_support=True: attestation should be called ---
684+ with patch .object (client ._http_client , "get" , side_effect = _mock_get ), \
685+ patch .object (client ._http_client , "post" , side_effect = _mock_post ):
686+ result = client .acquire_token_for_client (
687+ resource = "https://management.azure.com/" ,
688+ mtls_proof_of_possession = True ,
689+ with_attestation_support = True ,
690+ )
691+ mock_attest .assert_called_once ()
692+ self .assertEqual (result ["access_token" ], access_token )
693+
694+ mock_attest .reset_mock ()
695+ mock_mtls .reset_mock ()
696+
697+ # --- with_attestation_support=False (default): attestation NOT called ---
698+ with patch .object (client ._http_client , "get" , side_effect = _mock_get ), \
699+ patch .object (client ._http_client , "post" , side_effect = _mock_post ):
700+ result = client .acquire_token_for_client (
701+ resource = "https://management.azure.com/" ,
702+ mtls_proof_of_possession = True ,
703+ with_attestation_support = False ,
704+ )
705+ mock_attest .assert_not_called ()
706+ self .assertEqual (result ["access_token" ], access_token )
707+
592708
593709# ---------------------------------------------------------------------------
594710# Attestation module
0 commit comments