Skip to content

Commit 83b9611

Browse files
committed
mi-msal
1 parent a6798ae commit 83b9611

File tree

3 files changed

+432
-21
lines changed

3 files changed

+432
-21
lines changed

src/azure-cli-core/azure/cli/core/_profile.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,13 @@ def login(self,
222222

223223
def login_with_managed_identity(self, identity_id=None, client_id=None, object_id=None, resource_id=None,
224224
allow_no_subscriptions=None):
225-
if _on_azure_arc():
226-
return self.login_with_managed_identity_azure_arc(
227-
identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions)
225+
if _use_msal_managed_identity(self.cli_ctx):
226+
if identity_id:
227+
raise CLIError('--username is not supported by MSAL managed identity. '
228+
'Use --client-id, --object-id or --resource-id instead.')
229+
return self.login_with_managed_identity_msal(
230+
client_id=client_id, object_id=object_id, resource_id=resource_id,
231+
allow_no_subscriptions=allow_no_subscriptions)
228232

229233
import jwt
230234
from azure.mgmt.core.tools import is_valid_resource_id
@@ -304,22 +308,23 @@ def login_with_managed_identity(self, identity_id=None, client_id=None, object_i
304308
self._set_subscriptions(consolidated)
305309
return deepcopy(consolidated)
306310

307-
def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subscriptions=None):
311+
def login_with_managed_identity_msal(self, client_id=None, object_id=None, resource_id=None,
312+
allow_no_subscriptions=None):
308313
import jwt
309-
identity_type = MsiAccountTypes.system_assigned
310-
from .auth.msal_credentials import ManagedIdentityCredential
311314
from .auth.constants import ACCESS_TOKEN
312315

313-
cred = ManagedIdentityCredential()
316+
identity_id_type, identity_id_value = MsiAccountTypes.parse_ids(
317+
client_id=client_id, object_id=object_id, resource_id=resource_id)
318+
cred = MsiAccountTypes.msal_credential_factory(identity_id_type, identity_id_value)
314319
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
315320
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
316321
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
317322
tenant = decode['tid']
318323

319324
subscription_finder = SubscriptionFinder(self.cli_ctx)
320325
subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred)
321-
base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type)
322-
user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY
326+
base_name = ('{}-{}'.format(identity_id_type, identity_id_value) if identity_id_value else identity_id_type)
327+
user = _USER_ASSIGNED_IDENTITY if identity_id_value else _SYSTEM_ASSIGNED_IDENTITY
323328
if not subscriptions:
324329
if allow_no_subscriptions:
325330
subscriptions = self._build_tenant_level_accounts([tenant])
@@ -399,10 +404,10 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
399404

400405
elif managed_identity_type:
401406
# managed identity
402-
if _on_azure_arc():
403-
from .auth.msal_credentials import ManagedIdentityCredential
407+
if _use_msal_managed_identity(self.cli_ctx):
404408
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
405-
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
409+
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
410+
sdk_cred = CredentialAdaptor(cred)
406411
else:
407412
# The resource is merely used by msrestazure to get the first access token.
408413
# It is not actually used in an API invocation.
@@ -432,7 +437,8 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
432437
str(account[_SUBSCRIPTION_ID]),
433438
str(account[_TENANT_ID]))
434439

435-
def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=None):
440+
def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=None, credential_out=None):
441+
# credential_out is only used by unit tests to inspect the credential. Do not use it!
436442
# Convert resource to scopes
437443
if resource and not scopes:
438444
from .auth.util import resource_to_scopes
@@ -460,9 +466,11 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
460466
# managed identity
461467
if tenant:
462468
raise CLIError("Tenant shouldn't be specified for managed identity account")
463-
if _on_azure_arc():
464-
from .auth.msal_credentials import ManagedIdentityCredential
465-
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
469+
if _use_msal_managed_identity(self.cli_ctx):
470+
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
471+
if credential_out:
472+
credential_out['credential'] = cred
473+
sdk_cred = CredentialAdaptor(cred)
466474
else:
467475
from .auth.util import scopes_to_resource
468476
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
@@ -810,6 +818,41 @@ def msi_auth_factory(cli_account_name, identity, resource):
810818
return MSIAuthenticationWrapper(resource=resource, msi_res_id=identity)
811819
raise ValueError("unrecognized msi account name '{}'".format(cli_account_name))
812820

821+
@staticmethod
822+
def parse_ids(client_id=None, object_id=None, resource_id=None):
823+
id_arg_count = len([arg for arg in (client_id, object_id, resource_id) if arg])
824+
if id_arg_count > 1:
825+
raise CLIError('Usage error: Provide only one of --client-id, --object-id, --resource-id.')
826+
827+
id_type = None
828+
id_value = None
829+
if id_arg_count == 0:
830+
id_type = MsiAccountTypes.system_assigned
831+
id_value = None
832+
elif client_id:
833+
id_type = MsiAccountTypes.user_assigned_client_id
834+
id_value = client_id
835+
elif object_id:
836+
id_type = MsiAccountTypes.user_assigned_object_id
837+
id_value = object_id
838+
elif resource_id:
839+
id_type = MsiAccountTypes.user_assigned_resource_id
840+
id_value = resource_id
841+
return id_type, id_value
842+
843+
@staticmethod
844+
def msal_credential_factory(id_type, id_value):
845+
from azure.cli.core.auth.msal_credentials import ManagedIdentityCredential
846+
if id_type == MsiAccountTypes.system_assigned:
847+
return ManagedIdentityCredential()
848+
if id_type == MsiAccountTypes.user_assigned_client_id:
849+
return ManagedIdentityCredential(client_id=id_value)
850+
if id_type == MsiAccountTypes.user_assigned_object_id:
851+
return ManagedIdentityCredential(object_id=id_value)
852+
if id_type == MsiAccountTypes.user_assigned_resource_id:
853+
return ManagedIdentityCredential(resource_id=id_value)
854+
raise ValueError("Unrecognized managed identity ID type '{}'".format(id_type))
855+
813856

814857
class SubscriptionFinder:
815858
# An ARM client. It finds subscriptions for a user or service principal. It shouldn't do any
@@ -976,7 +1019,9 @@ def _create_identity_instance(cli_ctx, authority, tenant_id=None, client_id=None
9761019
instance_discovery=instance_discovery)
9771020

9781021

979-
def _on_azure_arc():
1022+
def _use_msal_managed_identity(cli_ctx):
9801023
# This indicates an Azure Arc-enabled server
9811024
from msal.managed_identity import get_managed_identity_source, AZURE_ARC
982-
return get_managed_identity_source() == AZURE_ARC
1025+
# PREVIEW: Use core.use_msal_managed_identity=true to enable managed identity authentication with MSAL
1026+
use_msal_managed_identity = cli_ctx.config.getboolean('core', 'use_msal_managed_identity', fallback=False)
1027+
return use_msal_managed_identity or get_managed_identity_source() == AZURE_ARC

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from knack.log import get_logger
1111
from knack.util import CLIError
1212
from msal import (PublicClientApplication, ConfidentialClientApplication,
13-
ManagedIdentityClient, SystemAssignedManagedIdentity)
13+
ManagedIdentityClient, SystemAssignedManagedIdentity, UserAssignedManagedIdentity)
1414

1515
from .constants import AZURE_CLI_CLIENT_ID
1616
from .util import check_result
@@ -131,9 +131,14 @@ class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
131131
Currently, only Azure Arc's system-assigned managed identity is supported.
132132
"""
133133

134-
def __init__(self):
134+
def __init__(self, client_id=None, resource_id=None, object_id=None):
135135
import requests
136-
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())
136+
if client_id or resource_id or object_id:
137+
managed_identity = UserAssignedManagedIdentity(
138+
client_id=client_id, resource_id=resource_id, object_id=object_id)
139+
else:
140+
managed_identity = SystemAssignedManagedIdentity()
141+
self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session())
137142

138143
def acquire_token(self, scopes, **kwargs):
139144
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)

0 commit comments

Comments
 (0)