diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index 43caa4c0668..915cf0fcd27 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -43,16 +43,23 @@ def __init__(self, client_id, username, **kwargs): self._account = accounts[0] - def acquire_token(self, scopes, claims_challenge=None, **kwargs): + def acquire_token(self, scopes, claims_challenge=None, data=None, **kwargs): # scopes must be a list. # For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default'] + # data is only used for acquiring VM SSH certificate. DO NOT use it for other purposes. # kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL - logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, kwargs=%r", - scopes, claims_challenge, kwargs) + logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, data=%r, kwargs=%r", + scopes, claims_challenge, data, kwargs) if claims_challenge: logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s', self._msal_app.authority.tenant, claims_challenge) + + # Only pass data to MSAL if it is set. Passing data=None will cause failure in MSAL: + # AttributeError: 'NoneType' object has no attribute 'get' + if data is not None: + kwargs['data'] = data + result = self._msal_app.acquire_token_silent_with_error( scopes, self._account, claims_challenge=claims_challenge, **kwargs) @@ -105,8 +112,13 @@ def __init__(self, client_id, client_credential, **kwargs): """ self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs) - def acquire_token(self, scopes, **kwargs): - logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs) + def acquire_token(self, scopes, data=None, **kwargs): + logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, data=%r, kwargs=%r", + scopes, data, kwargs) + + if data is not None: + kwargs['data'] = data + result = self._msal_app.acquire_token_for_client(scopes, **kwargs) check_result(result) return result @@ -126,8 +138,13 @@ def __init__(self): # token_cache=... ) - def acquire_token(self, scopes, **kwargs): - logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs) + def acquire_token(self, scopes, data=None, **kwargs): + logger.debug("CloudShellCredential.acquire_token: scopes=%r, data=%r, kwargs=%r", + scopes, data, kwargs) + + if data is not None: + kwargs['data'] = data + result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs) check_result(result, scopes=scopes) return result @@ -147,8 +164,13 @@ def __init__(self, client_id=None, resource_id=None, object_id=None): managed_identity = SystemAssignedManagedIdentity() self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session()) - def acquire_token(self, scopes, **kwargs): - logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs) + def acquire_token(self, scopes, data=None, **kwargs): + logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, data=%r, kwargs=%r", + scopes, data, kwargs) + + if data is not None: + from azure.cli.core.azclierror import AuthenticationError + raise AuthenticationError("VM SSH currently doesn't support managed identity.") from .util import scopes_to_resource result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes)) diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_msal_credentials.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_msal_credentials.py new file mode 100644 index 00000000000..ae168d6b873 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_msal_credentials.py @@ -0,0 +1,115 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +import unittest +from unittest import mock + +from ..msal_credentials import UserCredential + +MOCK_ACCOUNT = { + 'account_source': 'authorization_code', + 'authority_type': 'MSSTS', + 'environment': 'login.microsoftonline.com', + # random GUID generated by uuid.uuid4() + 'home_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718.e4e8e73b-5f99-4bd5-bdac-60b916a7343b', + 'local_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718', + 'realm': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b', + 'username': 'test@microsoft.com' +} + +MOCK_SCOPES = ['https://management.core.windows.net//.default'] + +MOCK_ACCESS_TOKEN = "mock_access_token" +MOCK_MSAL_TOKEN = { + 'access_token': MOCK_ACCESS_TOKEN, + 'token_type': 'Bearer', + 'expires_in': 1800, + 'token_source': 'cache' +} + +MOCK_CLAIMS = {"test_claims": "value2"} + +MOCK_DATA = { + 'key_id': 'test', + 'req_cnf': 'test', + 'token_type': 'ssh-cert' +} +MOCK_CERTIFICATE= "mock_certificate" +MOCK_MSAL_CERTIFICATE = { + 'access_token': MOCK_CERTIFICATE, + 'client_info': 'test', + 'expires_in': 3599, + 'ext_expires_in': 3599, + 'foci': '1', + 'id_token': 'test', + 'id_token_claims': { + 'preferred_username': 'test@microsoft.com', + 'tid': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b' + }, + 'refresh_token': 'test', + 'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default', + 'token_source': 'identity_provider', + 'token_type': 'ssh-cert' +} + + +class AuthorityStub: + def __init__(self): + self.tenant = 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b' + +class PublicClientApplicationStub: + + def __init__(self, client_id, **kwargs): + self.client_id = client_id + self.authority = AuthorityStub() + self.kwargs = kwargs + self.acquire_token_silent_with_error_scopes = None + self.acquire_token_silent_with_error_claims_challenge = None + self.acquire_token_silent_with_error_kwargs = None + super().__init__() + + def get_accounts(self, username): + return [MOCK_ACCOUNT] + + def acquire_token_silent_with_error(self, scopes, account, **kwargs): + self.acquire_token_silent_with_error_scopes = scopes + self.acquire_token_silent_with_error_claims_challenge = scopes + self.acquire_token_silent_with_error_kwargs = kwargs + if 'data' in kwargs: + return MOCK_MSAL_CERTIFICATE + return MOCK_MSAL_TOKEN + + +class TestUserCredential(unittest.TestCase): + + @mock.patch('azure.cli.core.auth.msal_credentials.PublicClientApplication') + def test_get_token(self, public_client_application_mock): + public_client_application_mock.side_effect = PublicClientApplicationStub + + msal_credential = UserCredential('test_client_id', 'test_username') + msal_app = msal_credential._msal_app + assert msal_credential._account == MOCK_ACCOUNT + + result = msal_credential.acquire_token(MOCK_SCOPES) + assert result == MOCK_MSAL_TOKEN + assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES + # Make sure data is not passed to MSAL + assert 'data' not in msal_app.acquire_token_silent_with_error_kwargs + + result = msal_credential.acquire_token(MOCK_SCOPES, claims_challenge=MOCK_CLAIMS) + assert result == MOCK_MSAL_TOKEN + assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES + assert msal_app.acquire_token_silent_with_error_kwargs['claims_challenge'] == MOCK_CLAIMS + + result = msal_credential.acquire_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'], + data=MOCK_DATA) + assert result == MOCK_MSAL_CERTIFICATE + assert msal_app.acquire_token_silent_with_error_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default'] + assert msal_app.acquire_token_silent_with_error_kwargs['data'] == MOCK_DATA + + +if __name__ == '__main__': + unittest.main()