Skip to content

Commit 0a35d5f

Browse files
authored
{Core} Refactor code for MSAL authentication (#29439)
1 parent 074cb62 commit 0a35d5f

File tree

3 files changed

+229
-147
lines changed

3 files changed

+229
-147
lines changed

src/azure-cli-core/azure/cli/core/auth/identity.py

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,20 @@
1313
from knack.util import CLIError
1414
from msal import PublicClientApplication, ConfidentialClientApplication
1515

16-
# Service principal entry properties
17-
from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \
18-
_USE_CERT_SN_ISSUER
19-
from .msal_authentication import UserCredential, ServicePrincipalCredential
16+
from .msal_credentials import UserCredential, ServicePrincipalCredential
2017
from .persistence import load_persisted_token_cache, file_extensions, load_secret_store
2118
from .util import check_result
2219

2320
AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
2421

22+
# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters:
23+
# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow
24+
_TENANT = 'tenant'
25+
_CLIENT_ID = 'client_id'
26+
_CLIENT_SECRET = 'client_secret'
27+
_CERTIFICATE = 'certificate'
28+
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'
29+
_CLIENT_ASSERTION = 'client_assertion'
2530

2631
# For environment credential
2732
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
@@ -187,10 +192,9 @@ def login_with_service_principal(self, client_id, credential, scopes):
187192
`credential` is a dict returned by ServicePrincipalAuth.build_credential
188193
"""
189194
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
190-
191-
# This cred means SDK credential object
192-
cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
193-
result = cred.acquire_token_for_client(scopes)
195+
client_credential = sp_auth.get_msal_client_credential()
196+
cca = ConfidentialClientApplication(client_id, client_credential, **self._msal_app_kwargs)
197+
result = cca.acquire_token_for_client(scopes)
194198
check_result(result)
195199

196200
# Only persist the service principal after a successful login
@@ -246,32 +250,47 @@ def get_user_credential(self, username):
246250

247251
def get_service_principal_credential(self, client_id):
248252
entry = self._service_principal_store.load_entry(client_id, self.tenant_id)
249-
sp_auth = ServicePrincipalAuth(entry)
250-
return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
253+
client_credential = ServicePrincipalAuth(entry).get_msal_client_credential()
254+
return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)
251255

252256
def get_managed_identity_credential(self, client_id=None):
253257
raise NotImplementedError
254258

255259

256-
class ServicePrincipalAuth:
257-
260+
class ServicePrincipalAuth: # pylint: disable=too-many-instance-attributes
258261
def __init__(self, entry):
262+
# Initialize all attributes first, so that we don't need to call getattr to check their existence
263+
self.client_id = None
264+
self.tenant = None
265+
# secret
266+
self.client_secret = None
267+
# certificate
268+
self.certificate = None
269+
self.use_cert_sn_issuer = None
270+
# federated identity credential
271+
self.client_assertion = None
272+
273+
# Internal attributes for certificate
274+
# They are computed at runtime and not persisted in the service principal entry.
275+
self._certificate_string = None
276+
self._thumbprint = None
277+
self._public_certificate = None
278+
259279
self.__dict__.update(entry)
260280

261-
if _CERTIFICATE in entry:
281+
if self.certificate:
262282
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error
263-
self.public_certificate = None
264283
try:
265284
with open(self.certificate, 'r') as file_reader:
266-
self.certificate_string = file_reader.read()
267-
cert = load_certificate(FILETYPE_PEM, self.certificate_string)
268-
self.thumbprint = cert.digest("sha1").decode().replace(':', '')
285+
self._certificate_string = file_reader.read()
286+
cert = load_certificate(FILETYPE_PEM, self._certificate_string)
287+
self._thumbprint = cert.digest("sha1").decode().replace(':', '')
269288
if entry.get(_USE_CERT_SN_ISSUER):
270289
# low-tech but safe parsing based on
271290
# https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h
272291
match = re.search(r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
273-
self.certificate_string, re.I)
274-
self.public_certificate = match.group()
292+
self._certificate_string, re.I)
293+
self._public_certificate = match.group()
275294
except (UnicodeDecodeError, Error) as ex:
276295
raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex))
277296

@@ -307,8 +326,42 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use
307326
return entry
308327

309328
def get_entry_to_persist(self):
329+
"""Get a service principal entry that can be persisted by ServicePrincipalStore."""
310330
persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION]
311-
return {k: v for k, v in self.__dict__.items() if k in persisted_keys}
331+
# Only persist certain attributes whose values are not None
332+
return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v}
333+
334+
def get_msal_client_credential(self):
335+
"""Get a client_credential that can be consumed by msal.ConfidentialClientApplication."""
336+
client_credential = None
337+
338+
# client_secret
339+
# "your client secret"
340+
if self.client_secret:
341+
client_credential = self.client_secret
342+
343+
# certificate
344+
# {
345+
# "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format",
346+
# "thumbprint": "A1B2C3D4E5F6...",
347+
# "public_certificate": "...-----BEGIN CERTIFICATE-----...",
348+
# }
349+
if self.certificate:
350+
client_credential = {
351+
"private_key": self._certificate_string,
352+
"thumbprint": self._thumbprint
353+
}
354+
if self._public_certificate:
355+
client_credential['public_certificate'] = self._public_certificate
356+
357+
# client_assertion
358+
# {
359+
# "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
360+
# }
361+
if self.client_assertion:
362+
client_credential = {'client_assertion': self.client_assertion}
363+
364+
return client_credential
312365

313366

314367
class ServicePrincipalStore:

src/azure-cli-core/azure/cli/core/auth/msal_authentication.py renamed to src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,23 @@
2222

2323
from .util import check_result, build_sdk_access_token
2424

25-
# OAuth 2.0 client credentials flow parameter
26-
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
27-
_TENANT = 'tenant'
28-
_CLIENT_ID = 'client_id'
29-
_CLIENT_SECRET = 'client_secret'
30-
_CERTIFICATE = 'certificate'
31-
_CLIENT_ASSERTION = 'client_assertion'
32-
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'
33-
3425
logger = get_logger(__name__)
3526

3627

37-
class UserCredential(PublicClientApplication):
28+
class UserCredential: # pylint: disable=too-few-public-methods
3829

3930
def __init__(self, client_id, username, **kwargs):
4031
"""User credential implementing get_token interface.
4132
4233
:param client_id: Client ID of the CLI.
4334
:param username: The username for user credential.
4435
"""
45-
super().__init__(client_id, **kwargs)
36+
self._msal_app = PublicClientApplication(client_id, **kwargs)
4637

4738
# Make sure username is specified, otherwise MSAL returns all accounts
4839
assert username, "username must be specified, got {!r}".format(username)
4940

50-
accounts = self.get_accounts(username)
41+
accounts = self._msal_app.get_accounts(username)
5142

5243
# Usernames are usually unique. We are collecting corner cases to better understand its behavior.
5344
if len(accounts) > 1:
@@ -65,8 +56,9 @@ def get_token(self, *scopes, claims=None, **kwargs):
6556

6657
if claims:
6758
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
68-
self.authority.tenant, claims)
69-
result = self.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, **kwargs)
59+
self._msal_app.authority.tenant, claims)
60+
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
61+
**kwargs)
7062

7163
from azure.cli.core.azclierror import AuthenticationError
7264
try:
@@ -82,13 +74,14 @@ def get_token(self, *scopes, claims=None, **kwargs):
8274
logger.warning(ex)
8375
logger.warning("\nThe default web browser has been opened at %s for scope '%s'. "
8476
"Please continue the login in the web browser.",
85-
self.authority.authorization_endpoint, ' '.join(scopes))
77+
self._msal_app.authority.authorization_endpoint, ' '.join(scopes))
8678

8779
from .util import read_response_templates
8880
success_template, error_template = read_response_templates()
8981

90-
result = self.acquire_token_interactive(
91-
list(scopes), login_hint=self._account['username'], port=8400 if self.authority.is_adfs else None,
82+
result = self._msal_app.acquire_token_interactive(
83+
list(scopes), login_hint=self._account['username'],
84+
port=8400 if self._msal_app.authority.is_adfs else None,
9285
success_template=success_template, error_template=error_template, **kwargs)
9386
check_result(result)
9487

@@ -99,42 +92,19 @@ def get_token(self, *scopes, claims=None, **kwargs):
9992
return build_sdk_access_token(result)
10093

10194

102-
class ServicePrincipalCredential(ConfidentialClientApplication):
95+
class ServicePrincipalCredential: # pylint: disable=too-few-public-methods
10396

104-
def __init__(self, service_principal_auth, **kwargs):
97+
def __init__(self, client_id, client_credential, **kwargs):
10598
"""Service principal credential implementing get_token interface.
10699
107-
:param service_principal_auth: An instance of ServicePrincipalAuth.
100+
:param client_id: The service principal's client ID.
101+
:param client_credential: client_credential that will be passed to MSAL.
108102
"""
109-
client_credential = None
110-
111-
# client_secret
112-
client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None)
113-
if client_secret:
114-
client_credential = client_secret
115-
116-
# certificate
117-
certificate = getattr(service_principal_auth, _CERTIFICATE, None)
118-
if certificate:
119-
client_credential = {
120-
"private_key": getattr(service_principal_auth, 'certificate_string'),
121-
"thumbprint": getattr(service_principal_auth, 'thumbprint')
122-
}
123-
public_certificate = getattr(service_principal_auth, 'public_certificate', None)
124-
if public_certificate:
125-
client_credential['public_certificate'] = public_certificate
126-
127-
# client_assertion
128-
client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None)
129-
if client_assertion:
130-
client_credential = {'client_assertion': client_assertion}
131-
132-
super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs)
103+
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)
133104

134105
def get_token(self, *scopes, **kwargs):
135106
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
136107

137-
scopes = list(scopes)
138-
result = self.acquire_token_for_client(scopes, **kwargs)
108+
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
139109
check_result(result)
140110
return build_sdk_access_token(result)

0 commit comments

Comments
 (0)