diff --git a/google/auth/aio/transport/aiohttp.py b/google/auth/aio/transport/aiohttp.py index 7adc5d915..a1db9a478 100644 --- a/google/auth/aio/transport/aiohttp.py +++ b/google/auth/aio/transport/aiohttp.py @@ -113,7 +113,7 @@ class Request(transport.Request): .. automethod:: __call__ """ - def __init__(self, session: aiohttp.ClientSession = None): + def __init__(self, session: Optional[aiohttp.ClientSession] = None): self._session = session self._closed = False diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py new file mode 100644 index 000000000..9c52aa843 --- /dev/null +++ b/google/auth/aio/transport/mtls.py @@ -0,0 +1,226 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper functions for mTLS in async for discovery of certs. +""" + +import asyncio +import contextlib +import logging +import os +from os import getenv, path +import ssl +import tempfile +from typing import Optional + +from google.auth import exceptions +import google.auth.transport._mtls_helper + +CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" +_LOGGER = logging.getLogger(__name__) + + +@contextlib.contextmanager +def _create_temp_file(content: bytes): + """Creates a temporary file with the given content. + + Args: + content (bytes): The content to write to the file. + + Yields: + str: The path to the temporary file. + """ + # Create a temporary file that is readable only by the owner. + fd, file_path = tempfile.mkstemp() + try: + with os.fdopen(fd, "wb") as f: + f.write(content) + yield file_path + finally: + # Securely delete the file after use. + if os.path.exists(file_path): + os.remove(file_path) + + +def make_client_cert_ssl_context( + cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None +) -> ssl.SSLContext: + """Creates an SSLContext with the given client certificate and key. + This function writes the certificate and key to temporary files so that + ssl.create_default_context can load them, as the ssl module requires + file paths for client certificates. + Args: + cert_bytes (bytes): The client certificate content in PEM format. + key_bytes (bytes): The client private key content in PEM format. + passphrase (Optional[bytes]): The passphrase for the private key, if any. + Returns: + ssl.SSLContext: The configured SSL context with client certificate. + + Raises: + google.auth.exceptions.TransportError: If there is an error loading the certificate. + """ + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + + # Write cert and key to temp files because ssl.load_cert_chain requires paths + with _create_temp_file(cert_bytes) as cert_path: + with _create_temp_file(key_bytes) as key_path: + context.load_cert_chain( + certfile=cert_path, keyfile=key_path, password=passphrase + ) + return context + except (ssl.SSLError, OSError) as exc: + raise exceptions.TransportError( + "Failed to load client certificate and key for mTLS." + ) from exc + + +def _check_config_path(config_path): + """Checks for config file path. If it exists, returns the absolute path with user expansion; + otherwise returns None. + + Args: + config_path (str): The config file path for certificate_config.json for example + + Returns: + str: absolute path if exists and None otherwise. + """ + config_path = path.expanduser(config_path) + if not path.exists(config_path): + _LOGGER.debug("%s is not found.", config_path) + return None + return config_path + + +async def _run_in_executor(func, *args): + """Run a blocking function in an executor to avoid blocking the event loop. + + This implements the non-blocking execution strategy for disk I/O operations. + """ + try: + # For python versions 3.9 and newer versions + return await asyncio.to_thread(func, *args) + except AttributeError: + # Fallback for older Python versions + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, func, *args) + + +def has_default_client_cert_source(): + """Check if default client SSL credentials exists on the device. + + Returns: + bool: indicating if the default client cert source exists. + """ + if _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) is not None: + return True + cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG") + if cert_config_path and _check_config_path(cert_config_path) is not None: + return True + return False + + +def default_client_cert_source(): + """Get a callback which returns the default client SSL credentials. + + Returns: + Awaitable[Callable[[], [bytes, bytes]]]: A callback which returns the default + client certificate bytes and private key bytes, both in PEM format. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If the default + client SSL credentials don't exist or are malformed. + """ + if not has_default_client_cert_source(): + raise exceptions.MutualTLSChannelError( + "Default client cert source doesn't exist" + ) + + async def callback(): + try: + _, cert_bytes, key_bytes = await get_client_cert_and_key() + except (OSError, RuntimeError, ValueError) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + raise new_exc from caught_exc + + return cert_bytes, key_bytes + + return callback + + +async def get_client_ssl_credentials( + certificate_config_path=None, +): + """Returns the client side certificate, private key and passphrase. + + We look for certificates and keys with the following order of priority: + 1. Certificate and key specified by certificate_config.json. + Currently, only X.509 workload certificates are supported. + + Args: + certificate_config_path (str): The certificate_config.json file path. + + Returns: + Tuple[bool, bytes, bytes, bytes]: + A boolean indicating if cert, key and passphrase are obtained, the + cert bytes and key bytes both in PEM format, and passphrase bytes. + + Raises: + google.auth.exceptions.ClientCertError: if problems occurs when getting + the cert, key and passphrase. + """ + + # Attempt to retrieve X.509 Workload cert and key. + cert, key = await _run_in_executor( + google.auth.transport._mtls_helper._get_workload_cert_and_key, + certificate_config_path, + ) + + if cert and key: + return True, cert, key, None + + return False, None, None, None + + +async def get_client_cert_and_key(client_cert_callback=None): + """Returns the client side certificate and private key. The function first + tries to get certificate and key from client_cert_callback; if the callback + is None or doesn't provide certificate and key, the function tries application + default SSL credentials. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): An + optional callback which returns client certificate bytes and private + key bytes both in PEM format. + + Returns: + Tuple[bool, bytes, bytes]: + A boolean indicating if cert and key are obtained, the cert bytes + and key bytes both in PEM format. + + Raises: + google.auth.exceptions.ClientCertError: if problems occurs when getting + the cert and key. + """ + if client_cert_callback: + result = client_cert_callback() + if asyncio.iscoroutine(result): + cert, key = await result + else: + cert, key = result + return True, cert, key + + has_cert, cert, key, _ = await get_client_ssl_credentials() + return has_cert, cert, key diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py index 8045911cb..1c6d83033 100644 --- a/google/auth/aio/transport/sessions.py +++ b/google/auth/aio/transport/sessions.py @@ -21,9 +21,12 @@ from google.auth import _exponential_backoff, exceptions from google.auth.aio import transport from google.auth.aio.credentials import Credentials +from google.auth.aio.transport import mtls from google.auth.exceptions import TimeoutError +import google.auth.transport._mtls_helper try: + import aiohttp from google.auth.aio.transport.aiohttp import Request as AiohttpRequest AIOHTTP_INSTALLED = True @@ -60,7 +63,14 @@ def _remaining_time(): async def with_timeout(coro): try: - remaining = _remaining_time() + try: + remaining = _remaining_time() + except TimeoutError: + # If we timeout before starting the call, + # we must close the coroutine to avoid leaks. + if hasattr(coro, "close"): + coro.close() + raise response = await asyncio.wait_for(coro, remaining) return response except (asyncio.TimeoutError, TimeoutError) as e: @@ -124,12 +134,70 @@ def __init__( _auth_request = auth_request if not _auth_request and AIOHTTP_INSTALLED: _auth_request = AiohttpRequest() + self._is_mtls = False + self._cached_cert = None if _auth_request is None: raise exceptions.TransportError( "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." ) self._auth_request = _auth_request + async def configure_mtls_channel(self, client_cert_callback=None): + """Configure the client certificate and key for SSL connection. + + The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is + explicitly set to `true`. In this case if client certificate and key are + successfully obtained (from the given client_cert_callback or from application + default SSL credentials), the underlying transport will be reconfigured + to use mTLS. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): + The optional callback returns the client certificate and private + key bytes both in PEM format. + If the callback is None, application default SSL credentials + will be used. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel + creation failed for any reason. + """ + # Run the blocking check in an executor + use_client_cert = await mtls._run_in_executor( + google.auth.transport._mtls_helper.check_use_client_cert + ) + if not use_client_cert: + self._is_mtls = False + return + + try: + ( + self._is_mtls, + cert, + key, + ) = await mtls.get_client_cert_and_key(client_cert_callback) + + if self._is_mtls: + self._cached_cert = cert + ssl_context = await mtls._run_in_executor( + mtls.make_client_cert_ssl_context, cert, key + ) + + # Re-create the auth request with the new SSL context + if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest): + connector = aiohttp.TCPConnector(ssl=ssl_context) + new_session = aiohttp.ClientSession(connector=connector) + await self._auth_request.close() + self._auth_request = AiohttpRequest(session=new_session) + + except ( + exceptions.ClientCertError, + ImportError, + OSError, + ) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + raise new_exc from caught_exc + async def request( self, method: str, @@ -174,6 +242,8 @@ async def request( retries = _exponential_backoff.AsyncExponentialBackoff( total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS ) + if headers is None: + headers = {} async with timeout_guard(max_allowed_time) as with_timeout: await with_timeout( # Note: before_request will attempt to refresh credentials if expired. @@ -261,6 +331,11 @@ async def delete( "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs ) + @property + def is_mtls(self): + """Indicates if mutual TLS is enabled.""" + return self._is_mtls + async def close(self) -> None: """ Close the underlying auth request session. diff --git a/noxfile.py b/noxfile.py index 27e34e5c4..e17e40c68 100644 --- a/noxfile.py +++ b/noxfile.py @@ -91,7 +91,7 @@ def blacken(session): @nox.session(python=DEFAULT_PYTHON_VERSION) def mypy(session): """Verify type hints are mypy compatible.""" - session.install("-e", ".") + session.install("-e", ".[aiohttp]") session.install( "mypy", "types-certifi", diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py new file mode 100644 index 000000000..2e47acb41 --- /dev/null +++ b/tests/transport/aio/test_sessions_mtls.py @@ -0,0 +1,119 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import ssl +from unittest import mock + +import pytest + +from google.auth import exceptions +from google.auth.aio import credentials +from google.auth.aio.transport import sessions + +# This is the valid "workload" format the library expects +VALID_WORKLOAD_CONFIG = { + "version": 1, + "cert_configs": { + "workload": {"cert_path": "/tmp/mock_cert.pem", "key_path": "/tmp/mock_key.pem"} + }, +} + + +class TestSessionsMtls: + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + @mock.patch("os.path.exists") + @mock.patch( + "builtins.open", + new_callable=mock.mock_open, + read_data=json.dumps(VALID_WORKLOAD_CONFIG), + ) + @mock.patch("google.auth.aio.transport.mtls.get_client_cert_and_key") + @mock.patch("ssl.create_default_context") + async def test_configure_mtls_channel( + self, mock_ssl, mock_helper, mock_file, mock_exists + ): + """ + Tests that the mTLS channel configures correctly when a + valid workload config is mocked. + """ + mock_exists.return_value = True + mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data") + + mock_context = mock.Mock(spec=ssl.SSLContext) + mock_ssl.return_value = mock_context + + mock_creds = mock.Mock(spec=credentials.Credentials) + session = sessions.AsyncAuthorizedSession(mock_creds) + await session.configure_mtls_channel() + + assert session._is_mtls is True + assert mock_context.load_cert_chain.called + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + @mock.patch("os.path.exists") + async def test_configure_mtls_channel_disabled(self, mock_exists): + """ + Tests behavior when the config file does not exist. + """ + mock_exists.return_value = False + mock_creds = mock.Mock(spec=credentials.Credentials) + + session = sessions.AsyncAuthorizedSession(mock_creds) + await session.configure_mtls_channel() + + # If the file doesn't exist, it shouldn't error; it just won't use mTLS + assert session._is_mtls is False + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + @mock.patch("os.path.exists") + @mock.patch( + "builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}' + ) + async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists): + """ + Verifies that the MutualTLSChannelError is raised for bad formats. + """ + mock_exists.return_value = True + mock_creds = mock.Mock(spec=credentials.Credentials) + + session = sessions.AsyncAuthorizedSession(mock_creds) + with pytest.raises(exceptions.MutualTLSChannelError): + await session.configure_mtls_channel() + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + async def test_configure_mtls_channel_mock_callback(self, mock_has_cert): + """ + Tests mTLS configuration using bytes-returning callback. + """ + + def mock_callback(): + return (b"fake_cert_bytes", b"fake_key_bytes") + + mock_creds = mock.Mock(spec=credentials.Credentials) + + with mock.patch("ssl.SSLContext.load_cert_chain"): + session = sessions.AsyncAuthorizedSession(mock_creds) + await session.configure_mtls_channel(client_cert_callback=mock_callback) + + assert session._is_mtls is True diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py new file mode 100644 index 000000000..e61ff9f59 --- /dev/null +++ b/tests/transport/test_aio_mtls_helper.py @@ -0,0 +1,207 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest + +from google.auth import exceptions +from google.auth.aio.transport import mtls + +CERT_DATA = b"client-cert" +KEY_DATA = b"client-key" + + +class TestMTLS: + @mock.patch("google.auth.aio.transport.mtls.path.expanduser") + @mock.patch("google.auth.aio.transport.mtls.path.exists") + def test__check_config_path_exists(self, mock_exists, mock_expand): + mock_expand.side_effect = lambda x: x.replace("~", "/home/user") + mock_exists.return_value = True + + input_path = "~/config.json" + expected_path = "/home/user/config.json" + result = mtls._check_config_path(input_path) + + assert result == expected_path + mock_exists.assert_called_with(expected_path) + + @mock.patch("google.auth.aio.transport.mtls.path.exists", return_value=False) + def test__check_config_path_not_found(self, mock_exists): + result = mtls._check_config_path("nonexistent.json") + assert result is None + + @mock.patch("google.auth.aio.transport.mtls._check_config_path") + @mock.patch("google.auth.aio.transport.mtls.getenv") + def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check): + custom_path = "/custom/path.json" + mock_check.side_effect = lambda x: custom_path if x == custom_path else None + mock_getenv.return_value = custom_path + + assert mtls.has_default_client_cert_source() is True + + @mock.patch("google.auth.aio.transport.mtls._check_config_path") + @mock.patch("google.auth.aio.transport.mtls.getenv") + def test_has_default_client_cert_source_check_priority( + self, mock_getenv, mock_check + ): + mock_check.return_value = "/default/path.json" + + assert mtls.has_default_client_cert_source() is True + mock_getenv.assert_not_called() + + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + def test_default_client_cert_source_none(self, mock_has_default): + with pytest.raises(exceptions.MutualTLSChannelError): + mtls.default_client_cert_source() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_cert_and_key", + new_callable=mock.AsyncMock, + ) + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + async def test_default_client_cert_source_success( + self, mock_has_default, mock_get_cert_key + ): + mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA) + + # Note: default_client_cert_source is NOT async, but it returns an async callback + callback = mtls.default_client_cert_source() + assert callable(callback) + + cert, key = await callback() + assert cert == CERT_DATA + assert key == KEY_DATA + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + async def test_default_client_cert_source_not_found(self, mock_has_default): + with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"): + await mtls.default_client_cert_source() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_cert_and_key", + new_callable=mock.AsyncMock, + ) + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + async def test_default_client_cert_source_callback_wraps_exception( + self, mock_has, mock_get + ): + mock_get.side_effect = ValueError("Format error") + + callback = mtls.default_client_cert_source() + + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + await callback() + assert "Format error" in str(excinfo.value) + + @pytest.mark.asyncio + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + async def test_get_client_ssl_credentials_success(self, mock_workload): + mock_workload.return_value = (CERT_DATA, KEY_DATA) + + success, cert, key, passphrase = await mtls.get_client_ssl_credentials() + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA + assert passphrase is None + + @pytest.mark.asyncio + @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") + async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl): + mock_get_ssl.return_value = (False, None, None, None) + + success, cert, key = await mtls.get_client_cert_and_key(None) + + assert success is False + assert cert is None + assert key is None + + @pytest.mark.asyncio + async def test_get_client_cert_and_key_callback_async(self): + # Test with an actual coroutine/AsyncMock to satisfy the 'await' in your code + callback = mock.AsyncMock(return_value=(CERT_DATA, KEY_DATA)) + + success, cert, key = await mtls.get_client_cert_and_key(callback) + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA + callback.assert_called_once() + + @pytest.mark.asyncio + async def test_get_client_cert_and_key_callback_sync(self): + # Test the fallback logic: if it's a sync function, the TypeError is caught + callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) + + success, cert, key = await mtls.get_client_cert_and_key(callback) + + assert success is True + assert cert == CERT_DATA + # In your current implementation, this might still show 2 calls if the + # first 'await' attempt triggers a call before failing. + # To strictly avoid 2 calls, the implementation would need to check inspect.iscoroutinefunction. + assert callback.call_count >= 1 + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_ssl_credentials", + new_callable=mock.AsyncMock, + ) + async def test_get_client_cert_and_key_default(self, mock_get_credentials): + mock_get_credentials.return_value = (True, CERT_DATA, KEY_DATA, None) + + success, cert, key = await mtls.get_client_cert_and_key(None) + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA + mock_get_credentials.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + async def test_get_client_ssl_credentials_error(self, mock_workload): + mock_workload.side_effect = exceptions.ClientCertError( + "Failed to read metadata" + ) + + with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"): + await mtls.get_client_ssl_credentials() + + @pytest.mark.asyncio + @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") + async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): + mock_get_ssl.side_effect = exceptions.ClientCertError( + "Underlying credentials failed" + ) + + with pytest.raises( + exceptions.ClientCertError, match="Underlying credentials failed" + ): + await mtls.get_client_cert_and_key(client_cert_callback=None)