diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index ff3bf7c0dde..7909bb342cf 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -49,6 +49,8 @@ _USER_TYPE = 'type' _USER = 'user' _SERVICE_PRINCIPAL = 'servicePrincipal' +_ACCESS_TOKEN_IDENTITY_TYPE = 'accessToken' +_ACCESS_TOKEN_IDENTITY_NAME = 'ACCESS_TOKEN_ACCOUNT' _SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH = 'useCertSNIssuerAuth' _TOKEN_ENTRY_USER_ID = 'userId' _TOKEN_ENTRY_TOKEN_TYPE = 'tokenType' @@ -62,6 +64,11 @@ _AZ_LOGIN_MESSAGE = "Please run 'az login' to setup account." +_AZURE_CLI_SUBSCRIPTION_ID = 'AZURE_CLI_SUBSCRIPTION_ID' +_AZURE_CLI_TENANT_ID = 'AZURE_CLI_TENANT_ID' +_AZURE_CLI_ACCESS_TOKEN = 'AZURE_CLI_ACCESS_TOKEN' + + def load_subscriptions(cli_ctx, all_clouds=False, refresh=False): profile = Profile(cli_ctx=cli_ctx) if refresh: @@ -306,6 +313,13 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au account = self.get_subscription(subscription_id) + if account[_USER_ENTITY][_USER_TYPE] == _ACCESS_TOKEN_IDENTITY_TYPE: + from .auth.credentials import AccessTokenCredential + sdk_cred = CredentialAdaptor(AccessTokenCredential(os.environ[_AZURE_CLI_ACCESS_TOKEN])) + return (sdk_cred, + str(account[_SUBSCRIPTION_ID]), + str(account[_TENANT_ID])) + managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account) if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): @@ -358,6 +372,26 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No account = self.get_subscription(subscription) + if account[_USER_ENTITY][_USER_TYPE] == _ACCESS_TOKEN_IDENTITY_TYPE: + access_token = os.environ[_AZURE_CLI_ACCESS_TOKEN] + from .auth.util import _now_timestamp + expires_on = _now_timestamp() + 3600 + import datetime + expiresOn = datetime.datetime.fromtimestamp(expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") + token_entry = { + 'accessToken': os.environ[_AZURE_CLI_ACCESS_TOKEN], + 'expires_on': expires_on, + 'expiresOn': expiresOn + } + + # Build a tuple of (token_type, token, token_entry) + token_tuple = 'Bearer', access_token, token_entry + + # Return a tuple of (token_tuple, subscription, tenant) + return (token_tuple, + account[_SUBSCRIPTION_ID], + str(tenant if tenant else account[_TENANT_ID])) + managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account) if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): @@ -539,6 +573,25 @@ def get_current_account_user(self): return active_account[_USER_ENTITY][_USER_NAME] def get_subscription(self, subscription=None): # take id or name + if _env_vars_configured(): + subscription_id = subscription if subscription else os.environ.get(_AZURE_CLI_SUBSCRIPTION_ID) + from .auth.credentials import AccessTokenCredential + sdk_cred = AccessTokenCredential(os.environ[_AZURE_CLI_ACCESS_TOKEN]) + subscription_finder = SubscriptionFinder(self.cli_ctx) + tenant_id = subscription_finder.find_tenant_for_subscription(subscription_id, sdk_cred) + return { + # Subscription ID is not required for data-plane operations + _SUBSCRIPTION_ID: subscription if subscription else os.environ.get(_AZURE_CLI_SUBSCRIPTION_ID), + # Tenant ID is required by some operations. + # For example, "Vaults - Create Or Update" requires tenantId property. + # https://learn.microsoft.com/en-us/rest/api/keyvault/keyvault/vaults/create-or-update + _TENANT_ID: tenant_id, + _USER_ENTITY: { + _USER_NAME: _ACCESS_TOKEN_IDENTITY_NAME, + _USER_TYPE: _ACCESS_TOKEN_IDENTITY_TYPE + }, + } + subscriptions = self.load_cached_subscriptions() if not subscriptions: raise CLIError(_AZ_LOGIN_MESSAGE) @@ -751,6 +804,12 @@ def __init__(self, cli_ctx): self._authority = self.cli_ctx.cloud.endpoints.active_directory self.tenants = [] + def find_tenant_for_subscription(self, subscription_id, credential=None): + # pylint: disable=too-many-statements + client = self._create_subscription_client(credential) + subscription = client.subscriptions.get(subscription_id) + return subscription.tenant_id + def find_using_common_tenant(self, username, credential=None): # pylint: disable=too-many-statements all_subscriptions = [] @@ -902,3 +961,15 @@ def _create_identity_instance(cli_ctx, authority, tenant_id=None, client_id=None use_msal_http_cache=use_msal_http_cache, enable_broker_on_windows=enable_broker_on_windows, instance_discovery=instance_discovery) + +def _use_msal_managed_identity(cli_ctx): + from azure.cli.core.telemetry import set_use_msal_managed_identity + # Use core.use_msal_managed_identity=false to use the old msrestazure implementation + use_msal_managed_identity = cli_ctx.config.getboolean('core', 'use_msal_managed_identity', fallback=True) + set_use_msal_managed_identity(use_msal_managed_identity) + return use_msal_managed_identity + + +def _env_vars_configured(): + if _AZURE_CLI_ACCESS_TOKEN in os.environ: + return True diff --git a/src/azure-cli-core/azure/cli/core/auth/credentials.py b/src/azure-cli-core/azure/cli/core/auth/credentials.py new file mode 100644 index 00000000000..d08d29a4590 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/credentials.py @@ -0,0 +1,26 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Credentials to acquire tokens from MSAL. +""" + +from knack.log import get_logger + +logger = get_logger(__name__) + + +class AccessTokenCredential: # pylint: disable=too-few-public-methods + + def __init__(self, access_token): + self.access_token = access_token + + def acquire_token(self, scopes, **kwargs): + logger.debug("AccessTokenCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs) + return { + 'access_token': self.access_token, + # The caller is responsible for providing a valid token + 'expires_in': 3600 + } diff --git a/src/azure-cli-core/azure/cli/core/commands/arm.py b/src/azure-cli-core/azure/cli/core/commands/arm.py index 0769e4759dc..3ff654f319b 100644 --- a/src/azure-cli-core/azure/cli/core/commands/arm.py +++ b/src/azure-cli-core/azure/cli/core/commands/arm.py @@ -366,7 +366,7 @@ def __call__(self, parser, namespace, value, option_string=None): sub_id = sub['id'] break if not sub_id: - logger.warning("Subscription '%s' not recognized.", value) + # logger.warning("Subscription '%s' not recognized.", value) sub_id = value namespace._subscription = sub_id # pylint: disable=protected-access