Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.cli.core._session import ACCOUNT
from azure.cli.core.azclierror import AuthenticationError
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
from azure.cli.core.util import in_cloud_console, can_launch_browser, is_github_codespaces
from knack.log import get_logger
from knack.util import CLIError
Expand Down Expand Up @@ -313,9 +314,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
import jwt
identity_type = MsiAccountTypes.system_assigned
from .auth.msal_credentials import ManagedIdentityCredential
from .auth.constants import ACCESS_TOKEN

cred = ManagedIdentityCredential()
token = cred.get_token(*self._arm_scope).token
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']
Expand All @@ -339,9 +341,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
def login_in_cloud_shell(self):
import jwt
from .auth.msal_credentials import CloudShellCredential
from .auth.constants import ACCESS_TOKEN

cred = CloudShellCredential()
token = cred.get_token(*self._arm_scope).token
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
logger.info('Cloud Shell token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']
Expand Down Expand Up @@ -397,21 +400,19 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
# Cloud Shell
from .auth.msal_credentials import CloudShellCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(CloudShellCredential())
sdk_cred = CredentialAdaptor(CloudShellCredential())

elif managed_identity_type:
# managed identity
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(ManagedIdentityCredential())
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
else:
# The resource is merely used by msrestazure to get the first access token.
# It is not actually used in an API invocation.
cred = MsiAccountTypes.msi_auth_factory(
sdk_cred = MsiAccountTypes.msi_auth_factory(
managed_identity_type, managed_identity_id,
self.cli_ctx.cloud.endpoints.active_directory_resource_id)

Expand All @@ -431,10 +432,9 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
external_credentials = []
for external_tenant in external_tenants:
external_credentials.append(self._create_credential(account, tenant_id=external_tenant))
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)
sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)

return (cred,
return (sdk_cred,
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

Expand All @@ -460,24 +460,24 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
if tenant:
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
from .auth.msal_credentials import CloudShellCredential
cred = CloudShellCredential()
sdk_cred = CredentialAdaptor(CloudShellCredential())

elif managed_identity_type:
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for managed identity account")
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
cred = ManagedIdentityCredential()
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
else:
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))

else:
cred = self._create_credential(account, tenant_id=tenant)
sdk_cred = CredentialAdaptor(self._create_credential(account, tenant_id=tenant))

sdk_token = cred.get_token(*scopes)
sdk_token = sdk_cred.get_token(*scopes)
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
# WARNING: expiresOn is deprecated and will be removed in future release.
import datetime
Expand Down Expand Up @@ -856,7 +856,6 @@ def find_using_common_tenant(self, username, credential=None):
specific_tenant_credential = identity.get_user_credential(username)

try:

subscriptions = self.find_using_specific_tenant(tenant_id, specific_tenant_credential,
tenant_id_description=t)
except AuthenticationError as ex:
Expand Down Expand Up @@ -927,9 +926,9 @@ def _create_subscription_client(self, credential):
raise CLIInternalError("Unable to get '{}' in profile '{}'"
.format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, self.cli_ctx.cloud.profile))
api_version = get_api_version(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS)
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, credential)

client = client_type(credential, api_version=api_version,
sdk_cred = CredentialAdaptor(credential)
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, sdk_cred)
client = client_type(sdk_cred, api_version=api_version,
base_url=self.cli_ctx.cloud.endpoints.resource_manager,
**client_kwargs)
return client
Expand Down
3 changes: 3 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
# --------------------------------------------------------------------------------------------

AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

ACCESS_TOKEN = 'access_token'
EXPIRES_IN = "expires_in"
15 changes: 8 additions & 7 deletions src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
# --------------------------------------------------------------------------------------------

from knack.log import get_logger
from .util import build_sdk_access_token

logger = get_logger(__name__)


class CredentialAdaptor:
def __init__(self, credential, auxiliary_credentials=None):
"""Cross-tenant credential adaptor. It takes a main credential and auxiliary credentials.
"""Credential adaptor between MSAL credential and SDK credential.
It implements Track 2 SDK's azure.core.credentials.TokenCredential by exposing get_token.
:param credential: Main credential from .msal_authentication
:param auxiliary_credentials: Credentials from .msal_authentication for cross tenant authentication.
Details about cross tenant authentication:
:param credential: MSAL credential from ._msal_credentials
:param auxiliary_credentials: MSAL credentials for cross-tenant authentication.
Details about cross-tenant authentication:
https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
"""

Expand All @@ -32,11 +32,12 @@ def get_token(self, *scopes, **kwargs):
if 'data' in kwargs:
filtered_kwargs['data'] = kwargs['data']

return self._credential.get_token(*scopes, **filtered_kwargs)
return build_sdk_access_token(self._credential.acquire_token(list(scopes), **filtered_kwargs))

def get_auxiliary_tokens(self, *scopes, **kwargs):
"""Get access tokens from auxiliary credentials."""
# To test cross-tenant authentication, see https://github.com/Azure/azure-cli/issues/16691
if self._auxiliary_credentials:
return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
return [build_sdk_access_token(cred.acquire_token(list(scopes), **kwargs))
for cred in self._auxiliary_credentials]
return None
5 changes: 2 additions & 3 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,8 @@ def login_with_service_principal(self, client_id, credential, scopes):
"""
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
client_credential = sp_auth.get_msal_client_credential()
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
result = cca.acquire_token_for_client(scopes)
check_result(result)
cred = ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)
cred.acquire_token(scopes)

# Only persist the service principal after a successful login
entry = sp_auth.get_entry_to_persist()
Expand Down
57 changes: 24 additions & 33 deletions src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am considering defining an abstract base classes MsalCredential and let all MSAL credentials inherit from it.

However, Azure Identity only defines TokenCredential as a Protocol

https://github.com/Azure/azure-sdk-for-python/blob/87d4a3273cc36f2e6cf2c8797a59060a5e6d275b/sdk/core/azure-core/azure/core/credentials.py#L72

class TokenCredential(Protocol):
    """Protocol for classes able to provide OAuth tokens."""

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:

A real credential looks like:

https://github.com/Azure/azure-sdk-for-python/blob/87d4a3273cc36f2e6cf2c8797a59060a5e6d275b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py#L39

class AzureCliCredential:

There is no hard constraints that it must implement get_token method.

Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,7 @@
# --------------------------------------------------------------------------------------------

"""
Credentials defined in this module are alternative implementations of credentials provided by Azure Identity.

These credentials implement azure.core.credentials.TokenCredential by exposing `get_token` method for Track 2
SDK invocation.

If you want to implement your own credential, the credential must also expose `get_token` method.

`get_token` method takes `scopes` as positional arguments and other optional `kwargs`, such as `claims`, `data`.
The return value should be a named tuple containing two elements: token (str), expires_on (int). You may simply use
azure.cli.core.auth.util.AccessToken to build the return value. See below credentials as examples.
Credentials to acquire tokens from MSAL.
"""

from knack.log import get_logger
Expand All @@ -22,15 +13,15 @@
ManagedIdentityClient, SystemAssignedManagedIdentity)

from .constants import AZURE_CLI_CLIENT_ID
from .util import check_result, build_sdk_access_token
from .util import check_result

logger = get_logger(__name__)


class UserCredential: # pylint: disable=too-few-public-methods

def __init__(self, client_id, username, **kwargs):
"""User credential implementing get_token interface.
"""User credential wrapping msal.application.PublicClientApplication

:param client_id: Client ID of the CLI.
:param username: The username for user credential.
Expand All @@ -52,14 +43,16 @@ def __init__(self, client_id, username, **kwargs):

self._account = accounts[0]

def get_token(self, *scopes, claims=None, **kwargs):
# scopes = ['https://pas.windows.net/CheckMyAccess/Linux/.default']
logger.debug("UserCredential.get_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
def acquire_token(self, scopes, claims=None, **kwargs):
# scopes must be a list.
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
logger.debug("UserCredential.acquire_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)

if claims:
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
self._msal_app.authority.tenant, claims)
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
result = self._msal_app.acquire_token_silent_with_error(scopes, self._account, claims_challenge=claims,
**kwargs)

from azure.cli.core.azclierror import AuthenticationError
Expand All @@ -82,7 +75,7 @@ def get_token(self, *scopes, claims=None, **kwargs):
success_template, error_template = read_response_templates()

result = self._msal_app.acquire_token_interactive(
list(scopes), login_hint=self._account['username'],
scopes, login_hint=self._account['username'],
port=8400 if self._msal_app.authority.is_adfs else None,
success_template=success_template, error_template=error_template, **kwargs)
check_result(result)
Expand All @@ -91,25 +84,24 @@ def get_token(self, *scopes, claims=None, **kwargs):
# launch browser, but show the error message and `az login` command instead.
else:
raise
return build_sdk_access_token(result)
return result


class ServicePrincipalCredential: # pylint: disable=too-few-public-methods

def __init__(self, client_id, client_credential, **kwargs):
"""Service principal credential implementing get_token interface.
"""Service principal credential wrapping msal.application.ConfidentialClientApplication.

:param client_id: The service principal's client ID.
:param client_credential: client_credential that will be passed to MSAL.
"""
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)

def get_token(self, *scopes, **kwargs):
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs)

result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
check_result(result)
return build_sdk_access_token(result)
return result


class CloudShellCredential: # pylint: disable=too-few-public-methods
Expand All @@ -126,12 +118,11 @@ def __init__(self):
# token_cache=...
)

def get_token(self, *scopes, **kwargs):
logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
check_result(result, scopes=scopes)
return build_sdk_access_token(result)
return result


class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
Expand All @@ -143,10 +134,10 @@ def __init__(self):
import requests
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())

def get_token(self, *scopes, **kwargs):
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
def acquire_token(self, scopes, **kwargs):
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)

from .util import scopes_to_resource
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
check_result(result)
return build_sdk_access_token(result)
return result
11 changes: 7 additions & 4 deletions src/azure-cli-core/azure/cli/core/auth/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ def check_result(result, **kwargs):


def build_sdk_access_token(token_entry):
import time
request_time = int(time.time())

# MSAL token entry sample:
# {
# 'access_token': 'eyJ0eXAiOiJKV...',
Expand All @@ -153,7 +150,8 @@ def build_sdk_access_token(token_entry):
# Importing azure.core.credentials.AccessToken is expensive.
# This can slow down commands that doesn't need azure.core, like `az account get-access-token`.
# So We define our own AccessToken.
return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"])
from .constants import ACCESS_TOKEN, EXPIRES_IN
return AccessToken(token_entry[ACCESS_TOKEN], _now_timestamp() + token_entry[EXPIRES_IN])


def decode_access_token(access_token):
Expand All @@ -177,3 +175,8 @@ def read_response_templates():
error_template = f.read()

return success_template, error_template


def _now_timestamp():
import time
return int(time.time())
Loading