Skip to content

Commit e7a866f

Browse files
committed
auth
1 parent 8616453 commit e7a866f

File tree

4 files changed

+39
-45
lines changed

4 files changed

+39
-45
lines changed

src/azure-cli-core/azure/cli/core/_profile.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from azure.cli.core._session import ACCOUNT
1212
from azure.cli.core.azclierror import AuthenticationError
1313
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription
14+
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
1415
from azure.cli.core.util import in_cloud_console, can_launch_browser, is_github_codespaces
1516
from knack.log import get_logger
1617
from knack.util import CLIError
@@ -477,7 +478,7 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
477478
else:
478479
cred = self._create_credential(account, tenant_id=tenant)
479480

480-
sdk_token = cred.get_token(*scopes)
481+
sdk_token = CredentialAdaptor(cred).get_token(*scopes)
481482
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
482483
# WARNING: expiresOn is deprecated and will be removed in future release.
483484
import datetime
@@ -856,7 +857,6 @@ def find_using_common_tenant(self, username, credential=None):
856857
specific_tenant_credential = identity.get_user_credential(username)
857858

858859
try:
859-
860860
subscriptions = self.find_using_specific_tenant(tenant_id, specific_tenant_credential,
861861
tenant_id_description=t)
862862
except AuthenticationError as ex:
@@ -927,9 +927,11 @@ def _create_subscription_client(self, credential):
927927
raise CLIInternalError("Unable to get '{}' in profile '{}'"
928928
.format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, self.cli_ctx.cloud.profile))
929929
api_version = get_api_version(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS)
930-
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, credential)
931930

932-
client = client_type(credential, api_version=api_version,
931+
sdk_credential = CredentialAdaptor(credential)
932+
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, sdk_credential)
933+
934+
client = client_type(sdk_credential, api_version=api_version,
933935
base_url=self.cli_ctx.cloud.endpoints.resource_manager,
934936
**client_kwargs)
935937
return client

src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
# --------------------------------------------------------------------------------------------
55

66
from knack.log import get_logger
7+
from .util import build_sdk_access_token
78

89
logger = get_logger(__name__)
910

1011

1112
class CredentialAdaptor:
1213
def __init__(self, credential, auxiliary_credentials=None):
13-
"""Cross-tenant credential adaptor. It takes a main credential and auxiliary credentials.
14-
14+
"""Credential adaptor between MSAL credential and SDK credential.
1515
It implements Track 2 SDK's azure.core.credentials.TokenCredential by exposing get_token.
1616
17-
:param credential: Main credential from .msal_authentication
18-
:param auxiliary_credentials: Credentials from .msal_authentication for cross tenant authentication.
19-
Details about cross tenant authentication:
17+
:param credential: MSAL credential from ._msal_credentials
18+
:param auxiliary_credentials: MSAL credentials for cross-tenant authentication.
19+
Details about cross-tenant authentication:
2020
https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
2121
"""
2222

@@ -32,11 +32,12 @@ def get_token(self, *scopes, **kwargs):
3232
if 'data' in kwargs:
3333
filtered_kwargs['data'] = kwargs['data']
3434

35-
return self._credential.get_token(*scopes, **filtered_kwargs)
35+
return build_sdk_access_token(self._credential.acquire_token(list(scopes), **filtered_kwargs))
3636

3737
def get_auxiliary_tokens(self, *scopes, **kwargs):
3838
"""Get access tokens from auxiliary credentials."""
3939
# To test cross-tenant authentication, see https://github.com/Azure/azure-cli/issues/16691
4040
if self._auxiliary_credentials:
41-
return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
41+
return [build_sdk_access_token(cred.acquire_token(list(scopes), **kwargs))
42+
for cred in self._auxiliary_credentials]
4243
return None

src/azure-cli-core/azure/cli/core/auth/identity.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,8 @@ def login_with_service_principal(self, client_id, credential, scopes):
192192
"""
193193
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
194194
client_credential = sp_auth.get_msal_client_credential()
195-
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
196-
result = cca.acquire_token_for_client(scopes)
197-
check_result(result)
195+
cred = ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)
196+
cred.acquire_token(scopes)
198197

199198
# Only persist the service principal after a successful login
200199
entry = sp_auth.get_entry_to_persist()

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,7 @@
44
# --------------------------------------------------------------------------------------------
55

66
"""
7-
Credentials defined in this module are alternative implementations of credentials provided by Azure Identity.
8-
9-
These credentials implement azure.core.credentials.TokenCredential by exposing `get_token` method for Track 2
10-
SDK invocation.
11-
12-
If you want to implement your own credential, the credential must also expose `get_token` method.
13-
14-
`get_token` method takes `scopes` as positional arguments and other optional `kwargs`, such as `claims`, `data`.
15-
The return value should be a named tuple containing two elements: token (str), expires_on (int). You may simply use
16-
azure.cli.core.auth.util.AccessToken to build the return value. See below credentials as examples.
7+
Credentials to acquire tokens from MSAL.
178
"""
189

1910
from knack.log import get_logger
@@ -22,15 +13,15 @@
2213
ManagedIdentityClient, SystemAssignedManagedIdentity)
2314

2415
from .constants import AZURE_CLI_CLIENT_ID
25-
from .util import check_result, build_sdk_access_token
16+
from .util import check_result
2617

2718
logger = get_logger(__name__)
2819

2920

3021
class UserCredential: # pylint: disable=too-few-public-methods
3122

3223
def __init__(self, client_id, username, **kwargs):
33-
"""User credential implementing get_token interface.
24+
"""User credential wrapping msal.application.PublicClientApplication
3425
3526
:param client_id: Client ID of the CLI.
3627
:param username: The username for user credential.
@@ -52,14 +43,15 @@ def __init__(self, client_id, username, **kwargs):
5243

5344
self._account = accounts[0]
5445

55-
def get_token(self, *scopes, claims=None, **kwargs):
56-
# scopes = ['https://pas.windows.net/CheckMyAccess/Linux/.default']
57-
logger.debug("UserCredential.get_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
46+
def acquire_token(self, scopes, claims=None, **kwargs):
47+
# scopes must be a list.
48+
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
49+
logger.debug("UserCredential.acquire_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
5850

5951
if claims:
6052
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
6153
self._msal_app.authority.tenant, claims)
62-
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
54+
result = self._msal_app.acquire_token_silent_with_error(scopes, self._account, claims_challenge=claims,
6355
**kwargs)
6456

6557
from azure.cli.core.azclierror import AuthenticationError
@@ -82,7 +74,7 @@ def get_token(self, *scopes, claims=None, **kwargs):
8274
success_template, error_template = read_response_templates()
8375

8476
result = self._msal_app.acquire_token_interactive(
85-
list(scopes), login_hint=self._account['username'],
77+
scopes, login_hint=self._account['username'],
8678
port=8400 if self._msal_app.authority.is_adfs else None,
8779
success_template=success_template, error_template=error_template, **kwargs)
8880
check_result(result)
@@ -91,25 +83,25 @@ def get_token(self, *scopes, claims=None, **kwargs):
9183
# launch browser, but show the error message and `az login` command instead.
9284
else:
9385
raise
94-
return build_sdk_access_token(result)
86+
return result
9587

9688

9789
class ServicePrincipalCredential: # pylint: disable=too-few-public-methods
9890

9991
def __init__(self, client_id, client_credential, **kwargs):
100-
"""Service principal credential implementing get_token interface.
92+
"""Service principal credential wrapping msal.application.ConfidentialClientApplication.
10193
10294
:param client_id: The service principal's client ID.
10395
:param client_credential: client_credential that will be passed to MSAL.
10496
"""
10597
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)
10698

107-
def get_token(self, *scopes, **kwargs):
108-
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
109-
110-
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
99+
def acquire_token(self, scopes, **kwargs):
100+
# scopes must be a list
101+
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
102+
result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
111103
check_result(result)
112-
return build_sdk_access_token(result)
104+
return result
113105

114106

115107
class CloudShellCredential: # pylint: disable=too-few-public-methods
@@ -126,12 +118,12 @@ def __init__(self):
126118
# token_cache=...
127119
)
128120

129-
def get_token(self, *scopes, **kwargs):
130-
logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
121+
def acquire_token(self, scopes, **kwargs):
122+
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
131123
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
132-
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
124+
result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
133125
check_result(result, scopes=scopes)
134-
return build_sdk_access_token(result)
126+
return result
135127

136128

137129
class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
@@ -143,10 +135,10 @@ def __init__(self):
143135
import requests
144136
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())
145137

146-
def get_token(self, *scopes, **kwargs):
147-
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
138+
def acquire_token(self, scopes, **kwargs):
139+
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
148140

149141
from .util import scopes_to_resource
150142
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
151143
check_result(result)
152-
return build_sdk_access_token(result)
144+
return result

0 commit comments

Comments
 (0)