|
13 | 13 | from knack.util import CLIError |
14 | 14 | from msal import PublicClientApplication, ConfidentialClientApplication |
15 | 15 |
|
16 | | -# Service principal entry properties |
17 | | -from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \ |
18 | | - _USE_CERT_SN_ISSUER |
19 | | -from .msal_authentication import UserCredential, ServicePrincipalCredential |
| 16 | +from .msal_credentials import UserCredential, ServicePrincipalCredential |
20 | 17 | from .persistence import load_persisted_token_cache, file_extensions, load_secret_store |
21 | 18 | from .util import check_result |
22 | 19 |
|
23 | 20 | AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' |
24 | 21 |
|
| 22 | +# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters: |
| 23 | +# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow |
| 24 | +_TENANT = 'tenant' |
| 25 | +_CLIENT_ID = 'client_id' |
| 26 | +_CLIENT_SECRET = 'client_secret' |
| 27 | +_CERTIFICATE = 'certificate' |
| 28 | +_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' |
| 29 | +_CLIENT_ASSERTION = 'client_assertion' |
25 | 30 |
|
26 | 31 | # For environment credential |
27 | 32 | AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" |
@@ -187,10 +192,9 @@ def login_with_service_principal(self, client_id, credential, scopes): |
187 | 192 | `credential` is a dict returned by ServicePrincipalAuth.build_credential |
188 | 193 | """ |
189 | 194 | sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential) |
190 | | - |
191 | | - # This cred means SDK credential object |
192 | | - cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) |
193 | | - result = cred.acquire_token_for_client(scopes) |
| 195 | + client_credential = sp_auth.get_msal_client_credential() |
| 196 | + cca = ConfidentialClientApplication(client_id, client_credential, **self._msal_app_kwargs) |
| 197 | + result = cca.acquire_token_for_client(scopes) |
194 | 198 | check_result(result) |
195 | 199 |
|
196 | 200 | # Only persist the service principal after a successful login |
@@ -246,32 +250,47 @@ def get_user_credential(self, username): |
246 | 250 |
|
247 | 251 | def get_service_principal_credential(self, client_id): |
248 | 252 | entry = self._service_principal_store.load_entry(client_id, self.tenant_id) |
249 | | - sp_auth = ServicePrincipalAuth(entry) |
250 | | - return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) |
| 253 | + client_credential = ServicePrincipalAuth(entry).get_msal_client_credential() |
| 254 | + return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs) |
251 | 255 |
|
252 | 256 | def get_managed_identity_credential(self, client_id=None): |
253 | 257 | raise NotImplementedError |
254 | 258 |
|
255 | 259 |
|
256 | | -class ServicePrincipalAuth: |
257 | | - |
| 260 | +class ServicePrincipalAuth: # pylint: disable=too-many-instance-attributes |
258 | 261 | def __init__(self, entry): |
| 262 | + # Initialize all attributes first, so that we don't need to call getattr to check their existence |
| 263 | + self.client_id = None |
| 264 | + self.tenant = None |
| 265 | + # secret |
| 266 | + self.client_secret = None |
| 267 | + # certificate |
| 268 | + self.certificate = None |
| 269 | + self.use_cert_sn_issuer = None |
| 270 | + # federated identity credential |
| 271 | + self.client_assertion = None |
| 272 | + |
| 273 | + # Internal attributes for certificate |
| 274 | + # They are computed at runtime and not persisted in the service principal entry. |
| 275 | + self._certificate_string = None |
| 276 | + self._thumbprint = None |
| 277 | + self._public_certificate = None |
| 278 | + |
259 | 279 | self.__dict__.update(entry) |
260 | 280 |
|
261 | | - if _CERTIFICATE in entry: |
| 281 | + if self.certificate: |
262 | 282 | from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error |
263 | | - self.public_certificate = None |
264 | 283 | try: |
265 | 284 | with open(self.certificate, 'r') as file_reader: |
266 | | - self.certificate_string = file_reader.read() |
267 | | - cert = load_certificate(FILETYPE_PEM, self.certificate_string) |
268 | | - self.thumbprint = cert.digest("sha1").decode().replace(':', '') |
| 285 | + self._certificate_string = file_reader.read() |
| 286 | + cert = load_certificate(FILETYPE_PEM, self._certificate_string) |
| 287 | + self._thumbprint = cert.digest("sha1").decode().replace(':', '') |
269 | 288 | if entry.get(_USE_CERT_SN_ISSUER): |
270 | 289 | # low-tech but safe parsing based on |
271 | 290 | # https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h |
272 | 291 | match = re.search(r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----', |
273 | | - self.certificate_string, re.I) |
274 | | - self.public_certificate = match.group() |
| 292 | + self._certificate_string, re.I) |
| 293 | + self._public_certificate = match.group() |
275 | 294 | except (UnicodeDecodeError, Error) as ex: |
276 | 295 | raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex)) |
277 | 296 |
|
@@ -307,8 +326,42 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use |
307 | 326 | return entry |
308 | 327 |
|
309 | 328 | def get_entry_to_persist(self): |
| 329 | + """Get a service principal entry that can be persisted by ServicePrincipalStore.""" |
310 | 330 | persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION] |
311 | | - return {k: v for k, v in self.__dict__.items() if k in persisted_keys} |
| 331 | + # Only persist certain attributes whose values are not None |
| 332 | + return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v} |
| 333 | + |
| 334 | + def get_msal_client_credential(self): |
| 335 | + """Get a client_credential that can be consumed by msal.ConfidentialClientApplication.""" |
| 336 | + client_credential = None |
| 337 | + |
| 338 | + # client_secret |
| 339 | + # "your client secret" |
| 340 | + if self.client_secret: |
| 341 | + client_credential = self.client_secret |
| 342 | + |
| 343 | + # certificate |
| 344 | + # { |
| 345 | + # "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format", |
| 346 | + # "thumbprint": "A1B2C3D4E5F6...", |
| 347 | + # "public_certificate": "...-----BEGIN CERTIFICATE-----...", |
| 348 | + # } |
| 349 | + if self.certificate: |
| 350 | + client_credential = { |
| 351 | + "private_key": self._certificate_string, |
| 352 | + "thumbprint": self._thumbprint |
| 353 | + } |
| 354 | + if self._public_certificate: |
| 355 | + client_credential['public_certificate'] = self._public_certificate |
| 356 | + |
| 357 | + # client_assertion |
| 358 | + # { |
| 359 | + # "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..." |
| 360 | + # } |
| 361 | + if self.client_assertion: |
| 362 | + client_credential = {'client_assertion': self.client_assertion} |
| 363 | + |
| 364 | + return client_credential |
312 | 365 |
|
313 | 366 |
|
314 | 367 | class ServicePrincipalStore: |
|
0 commit comments