Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
_USER_TYPE = 'type'
_USER = 'user'
_SERVICE_PRINCIPAL = 'servicePrincipal'
_ACCESS_TOKEN_IDENTITY_TYPE = 'accessToken'
_ACCESS_TOKEN_IDENTITY_NAME = 'ACCESS_TOKEN_ACCOUNT'
_SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH = 'useCertSNIssuerAuth'
_TOKEN_ENTRY_USER_ID = 'userId'
_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType'
Expand All @@ -62,6 +64,11 @@
_AZ_LOGIN_MESSAGE = "Please run 'az login' to setup account."


_AZURE_CLI_SUBSCRIPTION_ID = 'AZURE_CLI_SUBSCRIPTION_ID'
_AZURE_CLI_TENANT_ID = 'AZURE_CLI_TENANT_ID'
_AZURE_CLI_ACCESS_TOKEN = 'AZURE_CLI_ACCESS_TOKEN'


def load_subscriptions(cli_ctx, all_clouds=False, refresh=False):
profile = Profile(cli_ctx=cli_ctx)
if refresh:
Expand Down Expand Up @@ -306,6 +313,13 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au

account = self.get_subscription(subscription_id)

if account[_USER_ENTITY][_USER_TYPE] == _ACCESS_TOKEN_IDENTITY_TYPE:
from .auth.credentials import AccessTokenCredential
sdk_cred = CredentialAdaptor(AccessTokenCredential(os.environ[_AZURE_CLI_ACCESS_TOKEN]))
return (sdk_cred,
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account)

if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
Expand Down Expand Up @@ -358,6 +372,26 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No

account = self.get_subscription(subscription)

if account[_USER_ENTITY][_USER_TYPE] == _ACCESS_TOKEN_IDENTITY_TYPE:
access_token = os.environ[_AZURE_CLI_ACCESS_TOKEN]
from .auth.util import _now_timestamp
expires_on = _now_timestamp() + 3600
import datetime
expiresOn = datetime.datetime.fromtimestamp(expires_on).strftime("%Y-%m-%d %H:%M:%S.%f")
token_entry = {
'accessToken': os.environ[_AZURE_CLI_ACCESS_TOKEN],
'expires_on': expires_on,
'expiresOn': expiresOn
}

# Build a tuple of (token_type, token, token_entry)
token_tuple = 'Bearer', access_token, token_entry

# Return a tuple of (token_tuple, subscription, tenant)
return (token_tuple,
account[_SUBSCRIPTION_ID],
str(tenant if tenant else account[_TENANT_ID]))

managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account)

if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
Expand Down Expand Up @@ -539,6 +573,25 @@ def get_current_account_user(self):
return active_account[_USER_ENTITY][_USER_NAME]

def get_subscription(self, subscription=None): # take id or name
if _env_vars_configured():
subscription_id = subscription if subscription else os.environ.get(_AZURE_CLI_SUBSCRIPTION_ID)
from .auth.credentials import AccessTokenCredential
sdk_cred = AccessTokenCredential(os.environ[_AZURE_CLI_ACCESS_TOKEN])
subscription_finder = SubscriptionFinder(self.cli_ctx)
tenant_id = subscription_finder.find_tenant_for_subscription(subscription_id, sdk_cred)
return {
# Subscription ID is not required for data-plane operations
_SUBSCRIPTION_ID: subscription if subscription else os.environ.get(_AZURE_CLI_SUBSCRIPTION_ID),
# Tenant ID is required by some operations.
# For example, "Vaults - Create Or Update" requires tenantId property.
# https://learn.microsoft.com/en-us/rest/api/keyvault/keyvault/vaults/create-or-update
_TENANT_ID: tenant_id,
_USER_ENTITY: {
_USER_NAME: _ACCESS_TOKEN_IDENTITY_NAME,
_USER_TYPE: _ACCESS_TOKEN_IDENTITY_TYPE
},
}

subscriptions = self.load_cached_subscriptions()
if not subscriptions:
raise CLIError(_AZ_LOGIN_MESSAGE)
Expand Down Expand Up @@ -751,6 +804,12 @@ def __init__(self, cli_ctx):
self._authority = self.cli_ctx.cloud.endpoints.active_directory
self.tenants = []

def find_tenant_for_subscription(self, subscription_id, credential=None):
# pylint: disable=too-many-statements
client = self._create_subscription_client(credential)
subscription = client.subscriptions.get(subscription_id)
return subscription.tenant_id

def find_using_common_tenant(self, username, credential=None):
# pylint: disable=too-many-statements
all_subscriptions = []
Expand Down Expand Up @@ -902,3 +961,15 @@ def _create_identity_instance(cli_ctx, authority, tenant_id=None, client_id=None
use_msal_http_cache=use_msal_http_cache,
enable_broker_on_windows=enable_broker_on_windows,
instance_discovery=instance_discovery)

def _use_msal_managed_identity(cli_ctx):
from azure.cli.core.telemetry import set_use_msal_managed_identity
# Use core.use_msal_managed_identity=false to use the old msrestazure implementation
use_msal_managed_identity = cli_ctx.config.getboolean('core', 'use_msal_managed_identity', fallback=True)
set_use_msal_managed_identity(use_msal_managed_identity)
return use_msal_managed_identity


def _env_vars_configured():
if _AZURE_CLI_ACCESS_TOKEN in os.environ:
return True
26 changes: 26 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
Credentials to acquire tokens from MSAL.
"""

from knack.log import get_logger

logger = get_logger(__name__)


class AccessTokenCredential: # pylint: disable=too-few-public-methods

def __init__(self, access_token):
self.access_token = access_token

def acquire_token(self, scopes, **kwargs):
logger.debug("AccessTokenCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
return {
'access_token': self.access_token,
# The caller is responsible for providing a valid token
'expires_in': 3600
}
2 changes: 1 addition & 1 deletion src/azure-cli-core/azure/cli/core/commands/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def __call__(self, parser, namespace, value, option_string=None):
sub_id = sub['id']
break
if not sub_id:
logger.warning("Subscription '%s' not recognized.", value)
# logger.warning("Subscription '%s' not recognized.", value)
sub_id = value
namespace._subscription = sub_id # pylint: disable=protected-access

Expand Down
Loading