Skip to content

Commit 7d08bba

Browse files
Copilotgladjohn
andcommitted
Add MSI v2 (mTLS PoP) support: core module, attestation, integration, sample, and tests
Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com>
1 parent 7c4652e commit 7d08bba

File tree

6 files changed

+1171
-7
lines changed

6 files changed

+1171
-7
lines changed

msal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
3939
ManagedIdentityClient,
4040
ManagedIdentityError,
41+
MsiV2Error,
4142
ArcPlatformNotSupportedError,
4243
)
4344

msal/managed_identity.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ class ManagedIdentityError(ValueError):
2424
pass
2525

2626

27+
class MsiV2Error(ManagedIdentityError):
28+
"""Raised when the MSI v2 (mTLS PoP) flow fails."""
29+
pass
30+
31+
2732
class ManagedIdentity(UserDict):
2833
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
2934
to acquire token for the specified managed identity.
@@ -166,6 +171,7 @@ def __init__(
166171
token_cache=None,
167172
http_cache=None,
168173
client_capabilities: Optional[List[str]] = None,
174+
msi_v2_enabled: Optional[bool] = None,
169175
):
170176
"""Create a managed identity client.
171177
@@ -207,6 +213,17 @@ def __init__(
207213
Client capability in Managed Identity is relayed as-is
208214
via ``xms_cc`` parameter on the wire.
209215
216+
:param bool msi_v2_enabled: (optional)
217+
Enable MSI v2 (mTLS PoP) token acquisition.
218+
When True (or when the ``MSAL_ENABLE_MSI_V2`` environment variable
219+
is set to a truthy value), the client will attempt to acquire tokens
220+
using the MSI v2 flow (IMDS /issuecredential + mTLS PoP).
221+
If the MSI v2 flow fails, it automatically falls back to MSI v1.
222+
MSI v2 only applies to Azure VM (IMDS) environments; it is ignored
223+
in other managed identity environments (App Service, Service Fabric,
224+
Azure Arc, etc.).
225+
Defaults to None (disabled unless the env var is set).
226+
210227
Recipe 1: Hard code a managed identity for your app::
211228
212229
import msal, requests
@@ -253,6 +270,11 @@ def __init__(
253270
)
254271
self._token_cache = token_cache or TokenCache()
255272
self._client_capabilities = client_capabilities
273+
# MSI v2 is enabled by the constructor param or the MSAL_ENABLE_MSI_V2 env var
274+
if msi_v2_enabled is None:
275+
env_val = os.environ.get("MSAL_ENABLE_MSI_V2", "").lower()
276+
msi_v2_enabled = env_val in ("1", "true", "yes")
277+
self._msi_v2_enabled = msi_v2_enabled
256278

257279
def acquire_token_for_client(
258280
self,
@@ -326,13 +348,28 @@ def acquire_token_for_client(
326348
break # With a fallback in hand, we break here to go refresh
327349
return access_token_from_cache # It is still good as new
328350
try:
329-
result = _obtain_token(
330-
self._http_client, self._managed_identity, resource,
331-
access_token_sha256_to_refresh=hashlib.sha256(
332-
access_token_to_refresh.encode("utf-8")).hexdigest()
333-
if access_token_to_refresh else None,
334-
client_capabilities=self._client_capabilities,
335-
)
351+
result = None
352+
if self._msi_v2_enabled:
353+
try:
354+
from .msi_v2 import obtain_token as _obtain_token_v2
355+
result = _obtain_token_v2(
356+
self._http_client, self._managed_identity, resource)
357+
logger.debug("MSI v2 token acquisition succeeded")
358+
except MsiV2Error as exc:
359+
logger.warning(
360+
"MSI v2 flow failed, falling back to MSI v1: %s", exc)
361+
except Exception as exc: # pylint: disable=broad-except
362+
logger.warning(
363+
"MSI v2 encountered unexpected error, "
364+
"falling back to MSI v1: %s", exc)
365+
if result is None:
366+
result = _obtain_token(
367+
self._http_client, self._managed_identity, resource,
368+
access_token_sha256_to_refresh=hashlib.sha256(
369+
access_token_to_refresh.encode("utf-8")).hexdigest()
370+
if access_token_to_refresh else None,
371+
client_capabilities=self._client_capabilities,
372+
)
336373
if "access_token" in result:
337374
expires_in = result.get("expires_in", 3600)
338375
if "refresh_in" not in result and expires_in >= 7200:

0 commit comments

Comments
 (0)