diff --git a/src/azure-cli-core/azure/cli/core/_debug.py b/src/azure-cli-core/azure/cli/core/_debug.py index b8028791ce9..65703261bff 100644 --- a/src/azure-cli-core/azure/cli/core/_debug.py +++ b/src/azure-cli-core/azure/cli/core/_debug.py @@ -45,3 +45,32 @@ def change_ssl_cert_verification_track2(): logger.debug("Using CA bundle file at '%s'.", ca_bundle_file) client_kwargs['connection_verify'] = ca_bundle_file return client_kwargs + + +def get_msal_http_client(): + """ + Create an HTTP client (requests.Session) for MSAL that respects certificate verification settings. + + This ensures MSAL applications use the same certificate verification settings as the rest of Azure CLI, + including custom CA bundles specified via REQUESTS_CA_BUNDLE environment variable. + + Returns: + requests.Session: A configured Session object with appropriate certificate verification settings. + """ + import requests + + session = requests.Session() + + if should_disable_connection_verify(): + logger.warning("Connection verification disabled by environment variable %s", + DISABLE_VERIFY_VARIABLE_NAME) + os.environ[ADAL_PYTHON_SSL_NO_VERIFY] = '1' + session.verify = False + elif REQUESTS_CA_BUNDLE in os.environ: + ca_bundle_file = os.environ[REQUESTS_CA_BUNDLE] + if not os.path.isfile(ca_bundle_file): + raise CLIError('REQUESTS_CA_BUNDLE environment variable is specified with an invalid file path') + logger.debug("MSAL: Using CA bundle file at '%s'.", ca_bundle_file) + session.verify = ca_bundle_file + + return session diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py index 91629e89441..d83e687a4bb 100644 --- a/src/azure-cli-core/azure/cli/core/auth/identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -99,10 +99,14 @@ def _msal_app_kwargs(self): if self._use_msal_http_cache and not Identity._msal_http_cache: Identity._msal_http_cache = self._load_msal_http_cache() + # Import here to avoid circular dependency + from azure.cli.core._debug import get_msal_http_client + return { "authority": self._msal_authority, "token_cache": Identity._msal_token_cache, "http_cache": Identity._msal_http_cache, + "http_client": get_msal_http_client(), "instance_discovery": self._instance_discovery, # CP1 means we can handle claims challenges (CAE) "client_capabilities": None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"] diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index 00da593c8ae..9d930da16c6 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -139,13 +139,14 @@ class ManagedIdentityCredential: # pylint: disable=too-few-public-methods """ def __init__(self, client_id=None, resource_id=None, object_id=None): - import requests + # Use the configured HTTP client that respects certificate settings + from azure.cli.core._debug import get_msal_http_client if client_id or resource_id or object_id: managed_identity = UserAssignedManagedIdentity( client_id=client_id, resource_id=resource_id, object_id=object_id) else: managed_identity = SystemAssignedManagedIdentity() - self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session()) + self._msal_client = ManagedIdentityClient(managed_identity, http_client=get_msal_http_client()) def acquire_token(self, scopes, **kwargs): logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_connection_verify.py b/src/azure-cli-core/azure/cli/core/tests/test_connection_verify.py index c301a9d5b0e..72de4fe85ec 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_connection_verify.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_connection_verify.py @@ -6,7 +6,8 @@ import os import logging import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import tempfile import azure.cli.core._debug as _debug import azure.cli.core.util as cli_util @@ -35,6 +36,52 @@ def test_verify_client_connection(self): clientMock = _debug.change_ssl_cert_verification(clientMock) self.assertFalse(clientMock.config.connection.verify) + def test_get_msal_http_client_respects_ca_bundle(self): + """Test that get_msal_http_client() respects REQUESTS_CA_BUNDLE environment variable.""" + # Save original environment + original_ca_bundle = os.environ.get(_debug.REQUESTS_CA_BUNDLE) + original_disable_verify = os.environ.get(cli_util.DISABLE_VERIFY_VARIABLE_NAME) + + try: + # Create a temporary file to act as a CA bundle + with tempfile.NamedTemporaryFile(delete=False, suffix='.pem') as tmp_file: + tmp_file.write(b'# Test CA Bundle') + tmp_file_path = tmp_file.name + + # Test 1: With REQUESTS_CA_BUNDLE set + os.environ[_debug.REQUESTS_CA_BUNDLE] = tmp_file_path + if cli_util.DISABLE_VERIFY_VARIABLE_NAME in os.environ: + del os.environ[cli_util.DISABLE_VERIFY_VARIABLE_NAME] + + session = _debug.get_msal_http_client() + self.assertEqual(session.verify, tmp_file_path) + + # Test 2: With connection verification disabled + del os.environ[_debug.REQUESTS_CA_BUNDLE] + os.environ[cli_util.DISABLE_VERIFY_VARIABLE_NAME] = "1" + + session = _debug.get_msal_http_client() + self.assertFalse(session.verify) + + # Test 3: With neither set (default behavior) + del os.environ[cli_util.DISABLE_VERIFY_VARIABLE_NAME] + + session = _debug.get_msal_http_client() + self.assertTrue(session.verify) # Default is True + + finally: + # Cleanup + os.unlink(tmp_file_path) + # Restore original environment + if original_ca_bundle: + os.environ[_debug.REQUESTS_CA_BUNDLE] = original_ca_bundle + elif _debug.REQUESTS_CA_BUNDLE in os.environ: + del os.environ[_debug.REQUESTS_CA_BUNDLE] + if original_disable_verify: + os.environ[cli_util.DISABLE_VERIFY_VARIABLE_NAME] = original_disable_verify + elif cli_util.DISABLE_VERIFY_VARIABLE_NAME in os.environ: + del os.environ[cli_util.DISABLE_VERIFY_VARIABLE_NAME] + if __name__ == '__main__': unittest.main()