Skip to content

Commit 3218d1a

Browse files
Copilotgladjohn
andcommitted
Remove v1 fallback when mtls_proof_of_possession=True; keep fallback only for legacy msi_v2_enabled path
Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com>
1 parent 69805ef commit 3218d1a

File tree

2 files changed

+53
-36
lines changed

2 files changed

+53
-36
lines changed

msal/managed_identity.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,19 +371,28 @@ def acquire_token_for_client(
371371
# default (msi_v2_enabled / MSAL_ENABLE_MSI_V2 env var).
372372
use_msi_v2 = mtls_proof_of_possession or self._msi_v2_enabled
373373
if use_msi_v2:
374-
try:
374+
if mtls_proof_of_possession:
375+
# Explicit per-call request: errors are raised, no fallback to v1
375376
from .msi_v2 import obtain_token as _obtain_token_v2
376377
result = _obtain_token_v2(
377378
self._http_client, self._managed_identity, resource,
378379
attestation_enabled=with_attestation_support)
379380
logger.debug("MSI v2 token acquisition succeeded")
380-
except MsiV2Error as exc:
381-
logger.warning(
382-
"MSI v2 flow failed, falling back to MSI v1: %s", exc)
383-
except Exception as exc: # pylint: disable=broad-except
384-
logger.warning(
385-
"MSI v2 encountered unexpected error, "
386-
"falling back to MSI v1: %s", exc)
381+
else:
382+
# Legacy constructor flag: swallow errors and fall back to v1
383+
try:
384+
from .msi_v2 import obtain_token as _obtain_token_v2
385+
result = _obtain_token_v2(
386+
self._http_client, self._managed_identity, resource,
387+
attestation_enabled=with_attestation_support)
388+
logger.debug("MSI v2 token acquisition succeeded")
389+
except MsiV2Error as exc:
390+
logger.warning(
391+
"MSI v2 flow failed, falling back to MSI v1: %s", exc)
392+
except Exception as exc: # pylint: disable=broad-except
393+
logger.warning(
394+
"MSI v2 encountered unexpected error, "
395+
"falling back to MSI v1: %s", exc)
387396
if result is None:
388397
result = _obtain_token(
389398
self._http_client, self._managed_identity, resource,

tests/test_msi_v2.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -546,29 +546,21 @@ def _mock_post(url, **kwargs):
546546
self.assertEqual(result["token_type"], "mtls_pop")
547547

548548
@patch("msal.msi_v2._acquire_token_via_mtls")
549-
def test_msi_v2_fallback_to_v1_on_metadata_failure(self, mock_mtls):
550-
"""MSI v2 falls back to MSI v1 if IMDS metadata call fails."""
549+
def test_msi_v2_raises_on_metadata_failure_when_pop_requested(self, mock_mtls):
550+
"""When mtls_proof_of_possession=True, errors are raised (no v1 fallback)."""
551551
import requests
552552
client = self._make_client()
553553

554554
def _mock_get(url, **kwargs):
555555
if "getplatformmetadata" in url:
556556
return MinimalResponse(status_code=404, text="Not Found")
557-
# MSI v1 fallback (VM endpoint)
558-
if "oauth2/token" in url:
559-
return MinimalResponse(status_code=200, text=json.dumps({
560-
"access_token": "V1_TOKEN",
561-
"expires_in": "3600",
562-
"resource": "R",
563-
}))
564557
raise ValueError("Unexpected GET: {}".format(url))
565558

566559
with patch.object(client._http_client, "get", side_effect=_mock_get):
567-
result = client.acquire_token_for_client(
568-
resource="R", mtls_proof_of_possession=True)
560+
with self.assertRaises(MsiV2Error):
561+
client.acquire_token_for_client(
562+
resource="R", mtls_proof_of_possession=True)
569563

570-
# Should have fallen back to MSI v1
571-
self.assertEqual(result["access_token"], "V1_TOKEN")
572564
mock_mtls.assert_not_called()
573565

574566
@patch("msal.msi_v2._acquire_token_via_mtls")
@@ -593,8 +585,8 @@ def _mock_get(url, **kwargs):
593585
self.assertEqual(result["access_token"], "V1_TOKEN")
594586

595587
@patch("msal.msi_v2._acquire_token_via_mtls")
596-
def test_msi_v2_fallback_on_unexpected_error(self, mock_mtls):
597-
"""MSI v2 falls back to MSI v1 on unexpected errors."""
588+
def test_msi_v2_raises_on_unexpected_error_when_pop_requested(self, mock_mtls):
589+
"""When mtls_proof_of_possession=True, unexpected errors are raised (no v1 fallback)."""
598590
import requests
599591
client = self._make_client()
600592

@@ -605,18 +597,10 @@ def test_msi_v2_fallback_on_unexpected_error(self, mock_mtls):
605597
"attestationEndpoint": None,
606598
}
607599

608-
call_count = [0]
609-
610600
def _mock_get(url, **kwargs):
611-
call_count[0] += 1
612601
if "getplatformmetadata" in url:
613602
return MinimalResponse(status_code=200, text=json.dumps(platform_metadata))
614-
# MSI v1 fallback
615-
return MinimalResponse(status_code=200, text=json.dumps({
616-
"access_token": "V1_FALLBACK",
617-
"expires_in": "3600",
618-
"resource": "R",
619-
}))
603+
raise ValueError("Unexpected GET: {}".format(url))
620604

621605
def _mock_post(url, **kwargs):
622606
if "issuecredential" in url:
@@ -626,11 +610,35 @@ def _mock_post(url, **kwargs):
626610

627611
with patch.object(client._http_client, "get", side_effect=_mock_get), \
628612
patch.object(client._http_client, "post", side_effect=_mock_post):
629-
result = client.acquire_token_for_client(
630-
resource="R", mtls_proof_of_possession=True)
613+
with self.assertRaises(MsiV2Error):
614+
client.acquire_token_for_client(
615+
resource="R", mtls_proof_of_possession=True)
616+
617+
mock_mtls.assert_not_called()
618+
619+
@patch("msal.msi_v2._acquire_token_via_mtls")
620+
def test_msi_v2_fallback_to_v1_via_constructor_flag_on_failure(self, mock_mtls):
621+
"""Legacy msi_v2_enabled constructor path still falls back to MSI v1 on error."""
622+
import requests
623+
client = self._make_client(msi_v2_enabled=True)
631624

632-
# Should fall back to MSI v1
633-
self.assertEqual(result["access_token"], "V1_FALLBACK")
625+
def _mock_get(url, **kwargs):
626+
if "getplatformmetadata" in url:
627+
return MinimalResponse(status_code=404, text="Not Found")
628+
# MSI v1 fallback (VM endpoint)
629+
if "oauth2/token" in url:
630+
return MinimalResponse(status_code=200, text=json.dumps({
631+
"access_token": "V1_TOKEN",
632+
"expires_in": "3600",
633+
"resource": "R",
634+
}))
635+
raise ValueError("Unexpected GET: {}".format(url))
636+
637+
with patch.object(client._http_client, "get", side_effect=_mock_get):
638+
result = client.acquire_token_for_client(resource="R")
639+
640+
# Legacy path: falls back to v1
641+
self.assertEqual(result["access_token"], "V1_TOKEN")
634642
mock_mtls.assert_not_called()
635643

636644
@patch("msal.msi_v2_attestation.get_attestation_jwt")

0 commit comments

Comments
 (0)