From f1f7c450cc75776063523d4b05ff544e617c8052 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Sat, 7 Feb 2026 12:42:04 -0800 Subject: [PATCH 01/12] feat: Add helper methods for async mTLS support for google-auth Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 131 ++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 google/auth/aio/transport/mtls.py diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py new file mode 100644 index 000000000..8169be8a7 --- /dev/null +++ b/google/auth/aio/transport/mtls.py @@ -0,0 +1,131 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper functions for mTLS in asyncio. +""" + +import asyncio +import contextlib +import logging +import os +from os import environ, getenv, path +import ssl +import tempfile +from typing import Optional + +from google.auth import exceptions + +CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" +_LOGGER = logging.getLogger(__name__) + +def _check_config_path(config_path): + """Checks for config file path. If it exists, returns the absolute path with user expansion; + otherwise returns None. + + Args: + config_path (str): The config file path for certificate_config.json for example + + Returns: + str: absolute path if exists and None otherwise. + """ + config_path = path.expanduser(config_path) + if not path.exists(config_path): + _LOGGER.debug("%s is not found.", config_path) + return None + return config_path + + +def has_default_client_cert_source(): + """Check if default client SSL credentials exists on the device. + + Returns: + bool: indicating if the default client cert source exists. + """ + if ( + _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) + is not None + ): + return True + cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG") + if ( + cert_config_path + and _check_config_path(cert_config_path) is not None + ): + return True + return False + + +def get_client_ssl_credentials( + generate_encrypted_key=False, + certificate_config_path=None, +): + """Returns the client side certificate, private key and passphrase. + + We look for certificates and keys with the following order of priority: + 1. Certificate and key specified by certificate_config.json. + Currently, only X.509 workload certificates are supported. + + Args: + generate_encrypted_key (bool): If set to True, encrypted private key + and passphrase will be generated; otherwise, unencrypted private key + will be generated and passphrase will be None. This option only + affects keys obtained via context_aware_metadata.json. + certificate_config_path (str): The certificate_config.json file path. + + Returns: + Tuple[bool, bytes, bytes, bytes]: + A boolean indicating if cert, key and passphrase are obtained, the + cert bytes and key bytes both in PEM format, and passphrase bytes. + + Raises: + google.auth.exceptions.ClientCertError: if problems occurs when getting + the cert, key and passphrase. + """ + + # Attempt to retrieve X.509 Workload cert and key. + cert, key = google.auth.transport._mtls_helper._get_workload_cert_and_key(certificate_config_path) + if cert and key: + return True, cert, key, None + + return False, None, None, None + + +def get_client_cert_and_key(client_cert_callback=None): + """Returns the client side certificate and private key. The function first + tries to get certificate and key from client_cert_callback; if the callback + is None or doesn't provide certificate and key, the function tries application + default SSL credentials. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): An + optional callback which returns client certificate bytes and private + key bytes both in PEM format. + + Returns: + Tuple[bool, bytes, bytes]: + A boolean indicating if cert and key are obtained, the cert bytes + and key bytes both in PEM format. + + Raises: + google.auth.exceptions.ClientCertError: if problems occurs when getting + the cert and key. + """ + if client_cert_callback: + cert, key = client_cert_callback() + return True, cert, key + + has_cert, cert, key, _ = get_client_ssl_credentials(generate_encrypted_key=False) + return has_cert, cert, key + From 0d45640ddec8609032d29bf2f79b4aaa011618b5 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Sat, 7 Feb 2026 14:00:58 -0800 Subject: [PATCH 02/12] fix: Add test cases for helper method Signed-off-by: Radhika Agrawal --- tests/transport/test_aio_mtls_helper.py | 88 +++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/transport/test_aio_mtls_helper.py diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py new file mode 100644 index 000000000..074a4fd9a --- /dev/null +++ b/tests/transport/test_aio_mtls_helper.py @@ -0,0 +1,88 @@ +import os +import pytest +from unittest import mock +from google.auth import exceptions +# Assuming the provided code is in a file named google/auth/transport/aio/mtls_helper.py +from google.auth.transport.aio import mtls_helper + +CERT_DATA = b"client-cert" +KEY_DATA = b"client-key" + +class TestMTLSHelper: + + @mock.patch("os.path.expanduser") + @mock.patch("os.path.exists") + def test__check_config_path_exists(self, mock_exists, mock_expand): + mock_expand.side_effect = lambda x: x.replace("~", "/home/user") + mock_exists.return_value = True + + path = "/home/user/config.json" + result = mtls_helper._check_config_path("~/config.json") + + assert result == path + mock_exists.assert_called_with(path) + + @mock.patch("os.path.exists", return_value=False) + def test__check_config_path_not_found(self, mock_exists): + result = mtls_helper._check_config_path("nonexistent.json") + assert result is None + + @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path") + @mock.patch("os.getenv") + def test_has_default_client_cert_source_default_path(self, mock_getenv, mock_check): + # Case 1: Default config path exists + mock_check.side_effect = lambda x: x if x == mtls_helper.CERTIFICATE_CONFIGURATION_DEFAULT_PATH else None + + assert mtls_helper.has_default_client_cert_source() is True + + @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path") + @mock.patch("os.getenv") + def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check): + # Case 2: Default path doesn't exist, but env var path does + custom_path = "/custom/path.json" + mock_check.side_effect = lambda x: x if x == custom_path else None + mock_getenv.return_value = custom_path + + assert mtls_helper.has_default_client_cert_source() is True + + @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path", return_value=None) + def test_has_default_client_cert_source_none(self, mock_check): + assert mtls_helper.has_default_client_cert_source() is False + + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + def test_get_client_ssl_credentials_success(self, mock_workload): + mock_workload.return_value = (CERT_DATA, KEY_DATA) + + success, cert, key, passphrase = mtls_helper.get_client_ssl_credentials() + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA + assert passphrase is None + + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key", return_value=(None, None)) + def test_get_client_ssl_credentials_fail(self, mock_workload): + success, cert, key, passphrase = mtls_helper.get_client_ssl_credentials() + assert success is False + assert cert is None + + def test_get_client_cert_and_key_callback(self): + # Callback should take priority + callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) + + success, cert, key = mtls_helper.get_client_cert_and_key(callback) + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA + callback.assert_called_once() + + @mock.patch("google.auth.transport.aio.mtls_helper.get_client_ssl_credentials") + def test_get_client_cert_and_key_default(self, mock_get_ssl): + mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None) + + success, cert, key = mtls_helper.get_client_cert_and_key(None) + + assert success is True + assert cert == CERT_DATA + assert key == KEY_DATA From 07d7818c7ec1ef7e67075a432efa37060e5cb7e4 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 9 Feb 2026 12:20:53 -0800 Subject: [PATCH 03/12] chore: Correct lint and imports, plus add testcases for exceptions check Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 28 ++---- tests/transport/test_aio_mtls_helper.py | 113 +++++++++++++----------- 2 files changed, 72 insertions(+), 69 deletions(-) diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index 8169be8a7..65d411564 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -13,23 +13,18 @@ # limitations under the License. """ -Helper functions for mTLS in asyncio. +Helper functions for mTLS in async. """ -import asyncio -import contextlib import logging -import os -from os import environ, getenv, path -import ssl -import tempfile -from typing import Optional +from os import getenv, path -from google.auth import exceptions +import google.auth.transport._mtls_helper CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" _LOGGER = logging.getLogger(__name__) + def _check_config_path(config_path): """Checks for config file path. If it exists, returns the absolute path with user expansion; otherwise returns None. @@ -53,16 +48,10 @@ def has_default_client_cert_source(): Returns: bool: indicating if the default client cert source exists. """ - if ( - _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) - is not None - ): + if _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) is not None: return True cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG") - if ( - cert_config_path - and _check_config_path(cert_config_path) is not None - ): + if cert_config_path and _check_config_path(cert_config_path) is not None: return True return False @@ -95,7 +84,9 @@ def get_client_ssl_credentials( """ # Attempt to retrieve X.509 Workload cert and key. - cert, key = google.auth.transport._mtls_helper._get_workload_cert_and_key(certificate_config_path) + cert, key = google.auth.transport._mtls_helper._get_workload_cert_and_key( + certificate_config_path + ) if cert and key: return True, cert, key, None @@ -128,4 +119,3 @@ def get_client_cert_and_key(client_cert_callback=None): has_cert, cert, key, _ = get_client_ssl_credentials(generate_encrypted_key=False) return has_cert, cert, key - diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py index 074a4fd9a..648e99a53 100644 --- a/tests/transport/test_aio_mtls_helper.py +++ b/tests/transport/test_aio_mtls_helper.py @@ -1,88 +1,101 @@ -import os -import pytest +# Copyright 2024 Google LLC +# Licensed under the Apache License, Version 2.0... + from unittest import mock + +import pytest + from google.auth import exceptions -# Assuming the provided code is in a file named google/auth/transport/aio/mtls_helper.py -from google.auth.transport.aio import mtls_helper +from google.auth.aio.transport import mtls CERT_DATA = b"client-cert" KEY_DATA = b"client-key" -class TestMTLSHelper: - @mock.patch("os.path.expanduser") - @mock.patch("os.path.exists") +class TestMTLS: + @mock.patch("google.auth.aio.transport.mtls.path.expanduser") + @mock.patch("google.auth.aio.transport.mtls.path.exists") def test__check_config_path_exists(self, mock_exists, mock_expand): mock_expand.side_effect = lambda x: x.replace("~", "/home/user") mock_exists.return_value = True - - path = "/home/user/config.json" - result = mtls_helper._check_config_path("~/config.json") - - assert result == path - mock_exists.assert_called_with(path) - - @mock.patch("os.path.exists", return_value=False) + + input_path = "~/config.json" + expected_path = "/home/user/config.json" + result = mtls._check_config_path(input_path) + + assert result == expected_path + mock_exists.assert_called_with(expected_path) + + @mock.patch("google.auth.aio.transport.mtls.path.exists", return_value=False) def test__check_config_path_not_found(self, mock_exists): - result = mtls_helper._check_config_path("nonexistent.json") + result = mtls._check_config_path("nonexistent.json") assert result is None - @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path") - @mock.patch("os.getenv") - def test_has_default_client_cert_source_default_path(self, mock_getenv, mock_check): - # Case 1: Default config path exists - mock_check.side_effect = lambda x: x if x == mtls_helper.CERTIFICATE_CONFIGURATION_DEFAULT_PATH else None - - assert mtls_helper.has_default_client_cert_source() is True - - @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path") - @mock.patch("os.getenv") + @mock.patch("google.auth.aio.transport.mtls._check_config_path") + @mock.patch("google.auth.aio.transport.mtls.getenv") def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check): - # Case 2: Default path doesn't exist, but env var path does + # Mocking so the default path fails but the env var path succeeds custom_path = "/custom/path.json" - mock_check.side_effect = lambda x: x if x == custom_path else None + mock_check.side_effect = lambda x: custom_path if x == custom_path else None mock_getenv.return_value = custom_path - - assert mtls_helper.has_default_client_cert_source() is True - @mock.patch("google.auth.transport.aio.mtls_helper._check_config_path", return_value=None) - def test_has_default_client_cert_source_none(self, mock_check): - assert mtls_helper.has_default_client_cert_source() is False + assert mtls.has_default_client_cert_source() is True @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") def test_get_client_ssl_credentials_success(self, mock_workload): mock_workload.return_value = (CERT_DATA, KEY_DATA) - - success, cert, key, passphrase = mtls_helper.get_client_ssl_credentials() - + + success, cert, key, passphrase = mtls.get_client_ssl_credentials() + assert success is True assert cert == CERT_DATA assert key == KEY_DATA assert passphrase is None - @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key", return_value=(None, None)) - def test_get_client_ssl_credentials_fail(self, mock_workload): - success, cert, key, passphrase = mtls_helper.get_client_ssl_credentials() - assert success is False - assert cert is None - def test_get_client_cert_and_key_callback(self): - # Callback should take priority + # The callback should be tried first and return immediately callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) - - success, cert, key = mtls_helper.get_client_cert_and_key(callback) - + + success, cert, key = mtls.get_client_cert_and_key(callback) + assert success is True assert cert == CERT_DATA assert key == KEY_DATA callback.assert_called_once() - @mock.patch("google.auth.transport.aio.mtls_helper.get_client_ssl_credentials") + @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") def test_get_client_cert_and_key_default(self, mock_get_ssl): + # If no callback, it should call get_client_ssl_credentials mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None) - - success, cert, key = mtls_helper.get_client_cert_and_key(None) - + + success, cert, key = mtls.get_client_cert_and_key(None) + assert success is True assert cert == CERT_DATA assert key == KEY_DATA + mock_get_ssl.assert_called_with(generate_encrypted_key=False) + + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + def test_get_client_ssl_credentials_error(self, mock_workload): + """Tests that ClientCertError is propagated correctly.""" + # Setup the mock to raise the specific google-auth exception + mock_workload.side_effect = exceptions.ClientCertError( + "Failed to read metadata" + ) + + # Verify that calling our function raises the same exception + with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"): + mtls.get_client_ssl_credentials() + + @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") + def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): + """Tests that get_client_cert_and_key propagates errors from its internal calls.""" + mock_get_ssl.side_effect = exceptions.ClientCertError( + "Underlying credentials failed" + ) + + with pytest.raises( + exceptions.ClientCertError, match="Underlying credentials failed" + ): + # Pass None for callback so it attempts to call get_client_ssl_credentials + mtls.get_client_cert_and_key(client_cert_callback=None) From be10a507fc80ac4baafb2bd23c6df51241556872 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 9 Feb 2026 15:20:30 -0800 Subject: [PATCH 04/12] chore: Add dependencies and async function related wrapper Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 31 +++++++++++++++---- tests/transport/test_aio_mtls_helper.py | 40 +++++++++++++++++-------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index 65d411564..b482e887a 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -13,9 +13,10 @@ # limitations under the License. """ -Helper functions for mTLS in async. +Helper functions for mTLS in async for discovery of certs. """ +import asyncio import logging from os import getenv, path @@ -42,6 +43,20 @@ def _check_config_path(config_path): return config_path +async def _run_in_executor(func, *args): + """Run a blocking function in an executor to avoid blocking the event loop. + + This implements the non-blocking execution strategy for disk I/O operations. + """ + try: + # For python versions 3.9 and newer versions + return await asyncio.to_thread(func, *args) + except AttributeError: + # Fallback for older Python versions + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, func, *args) + + def has_default_client_cert_source(): """Check if default client SSL credentials exists on the device. @@ -56,7 +71,7 @@ def has_default_client_cert_source(): return False -def get_client_ssl_credentials( +async def get_client_ssl_credentials( generate_encrypted_key=False, certificate_config_path=None, ): @@ -84,16 +99,18 @@ def get_client_ssl_credentials( """ # Attempt to retrieve X.509 Workload cert and key. - cert, key = google.auth.transport._mtls_helper._get_workload_cert_and_key( - certificate_config_path + cert, key = await _run_in_executor( + google.auth.transport._mtls_helper._get_workload_cert_and_key, + certificate_config_path, ) + if cert and key: return True, cert, key, None return False, None, None, None -def get_client_cert_and_key(client_cert_callback=None): +async def get_client_cert_and_key(client_cert_callback=None): """Returns the client side certificate and private key. The function first tries to get certificate and key from client_cert_callback; if the callback is None or doesn't provide certificate and key, the function tries application @@ -117,5 +134,7 @@ def get_client_cert_and_key(client_cert_callback=None): cert, key = client_cert_callback() return True, cert, key - has_cert, cert, key, _ = get_client_ssl_credentials(generate_encrypted_key=False) + has_cert, cert, key, _ = await get_client_ssl_credentials( + generate_encrypted_key=False + ) return has_cert, cert, key diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py index 648e99a53..9c2dffb4a 100644 --- a/tests/transport/test_aio_mtls_helper.py +++ b/tests/transport/test_aio_mtls_helper.py @@ -1,5 +1,16 @@ -# Copyright 2024 Google LLC -# Licensed under the Apache License, Version 2.0... +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from unittest import mock @@ -41,42 +52,46 @@ def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check): assert mtls.has_default_client_cert_source() is True + @pytest.mark.asyncio @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") - def test_get_client_ssl_credentials_success(self, mock_workload): + async def test_get_client_ssl_credentials_success(self, mock_workload): mock_workload.return_value = (CERT_DATA, KEY_DATA) - success, cert, key, passphrase = mtls.get_client_ssl_credentials() + success, cert, key, passphrase = await mtls.get_client_ssl_credentials() assert success is True assert cert == CERT_DATA assert key == KEY_DATA assert passphrase is None - def test_get_client_cert_and_key_callback(self): + @pytest.mark.asyncio + async def test_get_client_cert_and_key_callback(self): # The callback should be tried first and return immediately callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) - success, cert, key = mtls.get_client_cert_and_key(callback) + success, cert, key = await mtls.get_client_cert_and_key(callback) assert success is True assert cert == CERT_DATA assert key == KEY_DATA callback.assert_called_once() + @pytest.mark.asyncio @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") - def test_get_client_cert_and_key_default(self, mock_get_ssl): + async def test_get_client_cert_and_key_default(self, mock_get_ssl): # If no callback, it should call get_client_ssl_credentials mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None) - success, cert, key = mtls.get_client_cert_and_key(None) + success, cert, key = await mtls.get_client_cert_and_key(None) assert success is True assert cert == CERT_DATA assert key == KEY_DATA mock_get_ssl.assert_called_with(generate_encrypted_key=False) + @pytest.mark.asyncio @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") - def test_get_client_ssl_credentials_error(self, mock_workload): + async def test_get_client_ssl_credentials_error(self, mock_workload): """Tests that ClientCertError is propagated correctly.""" # Setup the mock to raise the specific google-auth exception mock_workload.side_effect = exceptions.ClientCertError( @@ -85,10 +100,11 @@ def test_get_client_ssl_credentials_error(self, mock_workload): # Verify that calling our function raises the same exception with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"): - mtls.get_client_ssl_credentials() + await mtls.get_client_ssl_credentials() + @pytest.mark.asyncio @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") - def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): + async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): """Tests that get_client_cert_and_key propagates errors from its internal calls.""" mock_get_ssl.side_effect = exceptions.ClientCertError( "Underlying credentials failed" @@ -98,4 +114,4 @@ def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): exceptions.ClientCertError, match="Underlying credentials failed" ): # Pass None for callback so it attempts to call get_client_ssl_credentials - mtls.get_client_cert_and_key(client_cert_callback=None) + await mtls.get_client_cert_and_key(client_cert_callback=None) From 7f2359436b130368b44f2fe09243b73b4174a74e Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 9 Feb 2026 17:02:40 -0800 Subject: [PATCH 05/12] fix: Update based on gemini-assit comments to make robust callback by handling async and removing encrypted_key complication Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 45 +++++++++++--- tests/transport/test_aio_mtls_helper.py | 79 ++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 17 deletions(-) diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index b482e887a..29714608a 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -17,9 +17,11 @@ """ import asyncio +import inspect import logging from os import getenv, path +from google.auth import exceptions import google.auth.transport._mtls_helper CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json" @@ -71,8 +73,35 @@ def has_default_client_cert_source(): return False +async def default_client_cert_source(): + """Get a callback which returns the default client SSL credentials. + + Returns: + Callable[[], [bytes, bytes]]: A callback which returns the default + client certificate bytes and private key bytes, both in PEM format. + + Raises: + google.auth.exceptions.DefaultClientCertSourceError: If the default + client SSL credentials don't exist or are malformed. + """ + if not has_default_client_cert_source(): + raise exceptions.MutualTLSChannelError( + "Default client cert source doesn't exist" + ) + + async def callback(): + try: + _, cert_bytes, key_bytes = await get_client_cert_and_key() + except (OSError, RuntimeError, ValueError) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + raise new_exc from caught_exc + + return cert_bytes, key_bytes + + return callback + + async def get_client_ssl_credentials( - generate_encrypted_key=False, certificate_config_path=None, ): """Returns the client side certificate, private key and passphrase. @@ -82,10 +111,6 @@ async def get_client_ssl_credentials( Currently, only X.509 workload certificates are supported. Args: - generate_encrypted_key (bool): If set to True, encrypted private key - and passphrase will be generated; otherwise, unencrypted private key - will be generated and passphrase will be None. This option only - affects keys obtained via context_aware_metadata.json. certificate_config_path (str): The certificate_config.json file path. Returns: @@ -131,10 +156,12 @@ async def get_client_cert_and_key(client_cert_callback=None): the cert and key. """ if client_cert_callback: - cert, key = client_cert_callback() + result = client_cert_callback() + if inspect.isawaitable(result): + cert, key = await result + else: + cert, key = result return True, cert, key - has_cert, cert, key, _ = await get_client_ssl_credentials( - generate_encrypted_key=False - ) + has_cert, cert, key, _ = await get_client_ssl_credentials() return has_cert, cert, key diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py index 9c2dffb4a..c9bd3a5ed 100644 --- a/tests/transport/test_aio_mtls_helper.py +++ b/tests/transport/test_aio_mtls_helper.py @@ -45,13 +45,71 @@ def test__check_config_path_not_found(self, mock_exists): @mock.patch("google.auth.aio.transport.mtls._check_config_path") @mock.patch("google.auth.aio.transport.mtls.getenv") def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check): - # Mocking so the default path fails but the env var path succeeds custom_path = "/custom/path.json" mock_check.side_effect = lambda x: custom_path if x == custom_path else None mock_getenv.return_value = custom_path assert mtls.has_default_client_cert_source() is True + @mock.patch("google.auth.aio.transport.mtls._check_config_path") + @mock.patch("google.auth.aio.transport.mtls.getenv") + def test_has_default_client_cert_source_check_priority( + self, mock_getenv, mock_check + ): + mock_check.return_value = "/default/path.json" + + assert mtls.has_default_client_cert_source() is True + mock_getenv.assert_not_called() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_cert_and_key", + new_callable=mock.AsyncMock, + ) + @mock.patch("google.auth.aio.transport.mtls.has_default_client_cert_source") + async def test_default_client_cert_source_success( + self, mock_has_default, mock_get_cert_key + ): + mock_has_default.return_value = True + mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA) + + callback = await mtls.default_client_cert_source() + + cert, key = await callback() + + assert cert == CERT_DATA + assert key == KEY_DATA + mock_has_default.assert_called_once() + mock_get_cert_key.assert_called_once() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + async def test_default_client_cert_source_not_found(self, mock_has_default): + with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"): + await mtls.default_client_cert_source() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_cert_and_key", + new_callable=mock.AsyncMock, + ) + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + async def test_default_client_cert_source_callback_wraps_exception( + self, mock_has, mock_get + ): + mock_get.side_effect = ValueError("Format error") + callback = await mtls.default_client_cert_source() + + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + await callback() + assert "Format error" in str(excinfo.value) + @pytest.mark.asyncio @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") async def test_get_client_ssl_credentials_success(self, mock_workload): @@ -64,6 +122,17 @@ async def test_get_client_ssl_credentials_success(self, mock_workload): assert key == KEY_DATA assert passphrase is None + @pytest.mark.asyncio + @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") + async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl): + mock_get_ssl.return_value = (False, None, None, None) + + success, cert, key = await mtls.get_client_cert_and_key(None) + + assert success is False + assert cert is None + assert key is None + @pytest.mark.asyncio async def test_get_client_cert_and_key_callback(self): # The callback should be tried first and return immediately @@ -79,7 +148,6 @@ async def test_get_client_cert_and_key_callback(self): @pytest.mark.asyncio @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") async def test_get_client_cert_and_key_default(self, mock_get_ssl): - # If no callback, it should call get_client_ssl_credentials mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None) success, cert, key = await mtls.get_client_cert_and_key(None) @@ -87,25 +155,21 @@ async def test_get_client_cert_and_key_default(self, mock_get_ssl): assert success is True assert cert == CERT_DATA assert key == KEY_DATA - mock_get_ssl.assert_called_with(generate_encrypted_key=False) + mock_get_ssl.assert_called_once() @pytest.mark.asyncio @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") async def test_get_client_ssl_credentials_error(self, mock_workload): - """Tests that ClientCertError is propagated correctly.""" - # Setup the mock to raise the specific google-auth exception mock_workload.side_effect = exceptions.ClientCertError( "Failed to read metadata" ) - # Verify that calling our function raises the same exception with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"): await mtls.get_client_ssl_credentials() @pytest.mark.asyncio @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl): - """Tests that get_client_cert_and_key propagates errors from its internal calls.""" mock_get_ssl.side_effect = exceptions.ClientCertError( "Underlying credentials failed" ) @@ -113,5 +177,4 @@ async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl) with pytest.raises( exceptions.ClientCertError, match="Underlying credentials failed" ): - # Pass None for callback so it attempts to call get_client_ssl_credentials await mtls.get_client_cert_and_key(client_cert_callback=None) From 8110a6fa764a5afc99a1e3d4b3fd94d60e08a298 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Thu, 12 Feb 2026 15:32:23 -0800 Subject: [PATCH 06/12] chore: Correct based on minor comments Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 19 +++++---- tests/transport/test_aio_mtls_helper.py | 57 ++++++++++++++++++------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index 29714608a..ed7ed200c 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ """ import asyncio -import inspect import logging from os import getenv, path @@ -73,11 +72,11 @@ def has_default_client_cert_source(): return False -async def default_client_cert_source(): +def default_client_cert_source(): """Get a callback which returns the default client SSL credentials. Returns: - Callable[[], [bytes, bytes]]: A callback which returns the default + Awaitable[Callable[[], [bytes, bytes]]]: A callback which returns the default client certificate bytes and private key bytes, both in PEM format. Raises: @@ -156,11 +155,13 @@ async def get_client_cert_and_key(client_cert_callback=None): the cert and key. """ if client_cert_callback: - result = client_cert_callback() - if inspect.isawaitable(result): - cert, key = await result - else: - cert, key = result + try: + # If it's awaitable, this works. + cert, key = await client_cert_callback() + except TypeError: + # If it's not awaitable (e.g., a tuple), result is already the data. + cert, key = client_cert_callback() + return True, cert, key has_cert, cert, key, _ = await get_client_ssl_credentials() diff --git a/tests/transport/test_aio_mtls_helper.py b/tests/transport/test_aio_mtls_helper.py index c9bd3a5ed..e61ff9f59 100644 --- a/tests/transport/test_aio_mtls_helper.py +++ b/tests/transport/test_aio_mtls_helper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,26 +61,35 @@ def test_has_default_client_cert_source_check_priority( assert mtls.has_default_client_cert_source() is True mock_getenv.assert_not_called() + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + def test_default_client_cert_source_none(self, mock_has_default): + with pytest.raises(exceptions.MutualTLSChannelError): + mtls.default_client_cert_source() + @pytest.mark.asyncio @mock.patch( "google.auth.aio.transport.mtls.get_client_cert_and_key", new_callable=mock.AsyncMock, ) - @mock.patch("google.auth.aio.transport.mtls.has_default_client_cert_source") + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) async def test_default_client_cert_source_success( self, mock_has_default, mock_get_cert_key ): - mock_has_default.return_value = True mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA) - callback = await mtls.default_client_cert_source() + # Note: default_client_cert_source is NOT async, but it returns an async callback + callback = mtls.default_client_cert_source() + assert callable(callback) cert, key = await callback() - assert cert == CERT_DATA assert key == KEY_DATA - mock_has_default.assert_called_once() - mock_get_cert_key.assert_called_once() @pytest.mark.asyncio @mock.patch( @@ -104,7 +113,8 @@ async def test_default_client_cert_source_callback_wraps_exception( self, mock_has, mock_get ): mock_get.side_effect = ValueError("Format error") - callback = await mtls.default_client_cert_source() + + callback = mtls.default_client_cert_source() with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: await callback() @@ -134,9 +144,9 @@ async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl): assert key is None @pytest.mark.asyncio - async def test_get_client_cert_and_key_callback(self): - # The callback should be tried first and return immediately - callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) + async def test_get_client_cert_and_key_callback_async(self): + # Test with an actual coroutine/AsyncMock to satisfy the 'await' in your code + callback = mock.AsyncMock(return_value=(CERT_DATA, KEY_DATA)) success, cert, key = await mtls.get_client_cert_and_key(callback) @@ -146,16 +156,33 @@ async def test_get_client_cert_and_key_callback(self): callback.assert_called_once() @pytest.mark.asyncio - @mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials") - async def test_get_client_cert_and_key_default(self, mock_get_ssl): - mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None) + async def test_get_client_cert_and_key_callback_sync(self): + # Test the fallback logic: if it's a sync function, the TypeError is caught + callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA)) + + success, cert, key = await mtls.get_client_cert_and_key(callback) + + assert success is True + assert cert == CERT_DATA + # In your current implementation, this might still show 2 calls if the + # first 'await' attempt triggers a call before failing. + # To strictly avoid 2 calls, the implementation would need to check inspect.iscoroutinefunction. + assert callback.call_count >= 1 + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.get_client_ssl_credentials", + new_callable=mock.AsyncMock, + ) + async def test_get_client_cert_and_key_default(self, mock_get_credentials): + mock_get_credentials.return_value = (True, CERT_DATA, KEY_DATA, None) success, cert, key = await mtls.get_client_cert_and_key(None) assert success is True assert cert == CERT_DATA assert key == KEY_DATA - mock_get_ssl.assert_called_once() + mock_get_credentials.assert_called_once() @pytest.mark.asyncio @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") From 8269318aa3394ecba2b18aa9614c39b4c5eee950 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Fri, 13 Feb 2026 15:50:14 -0800 Subject: [PATCH 07/12] feat: Add mTLS configuration for async session in google-auth Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 60 +++++++++++ google/auth/aio/transport/sessions.py | 68 ++++++++++++ tests/transport/aio/test_sessions_mtls.py | 123 ++++++++++++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 tests/transport/aio/test_sessions_mtls.py diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index ed7ed200c..7aa5580b8 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -17,8 +17,13 @@ """ import asyncio +import contextlib import logging +import os from os import getenv, path +import ssl +import tempfile +from typing import Optional from google.auth import exceptions import google.auth.transport._mtls_helper @@ -27,6 +32,61 @@ _LOGGER = logging.getLogger(__name__) +@contextlib.contextmanager +def _create_temp_file(content: bytes): + """Creates a temporary file with the given content. + + Args: + content (bytes): The content to write to the file. + + Yields: + str: The path to the temporary file. + """ + # Create a temporary file that is readable only by the owner. + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "wb") as f: + f.write(content) + yield path + finally: + # Securely delete the file after use. + if os.path.exists(path): + os.remove(path) + + +def make_client_cert_ssl_context( + cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None +) -> ssl.SSLContext: + """Creates an SSLContext with the given client certificate and key. + This function writes the certificate and key to temporary files so that + ssl.create_default_context can load them, as the ssl module requires + file paths for client certificates. + Args: + cert_bytes (bytes): The client certificate content in PEM format. + key_bytes (bytes): The client private key content in PEM format. + passphrase (Optional[bytes]): The passphrase for the private key, if any. + Returns: + ssl.SSLContext: The configured SSL context with client certificate. + + Raises: + google.auth.exceptions.TransportError: If there is an error loading the certificate. + """ + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + + # Write cert and key to temp files because ssl.load_cert_chain requires paths + with _create_temp_file(cert_bytes) as cert_path: + with _create_temp_file(key_bytes) as key_path: + context.load_cert_chain( + certfile=cert_path, keyfile=key_path, password=passphrase + ) + return context + except (ssl.SSLError, OSError) as exc: + raise exceptions.TransportError( + "Failed to load client certificate and key for mTLS." + ) from exc + + def _check_config_path(config_path): """Checks for config file path. If it exists, returns the absolute path with user expansion; otherwise returns None. diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py index 8045911cb..e400a344d 100644 --- a/google/auth/aio/transport/sessions.py +++ b/google/auth/aio/transport/sessions.py @@ -21,9 +21,12 @@ from google.auth import _exponential_backoff, exceptions from google.auth.aio import transport from google.auth.aio.credentials import Credentials +from google.auth.aio.transport import mtls from google.auth.exceptions import TimeoutError +import google.auth.transport._mtls_helper try: + import aiohttp from google.auth.aio.transport.aiohttp import Request as AiohttpRequest AIOHTTP_INSTALLED = True @@ -124,12 +127,70 @@ def __init__( _auth_request = auth_request if not _auth_request and AIOHTTP_INSTALLED: _auth_request = AiohttpRequest() + self._is_mtls = False + self._cached_Cert = None if _auth_request is None: raise exceptions.TransportError( "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." ) self._auth_request = _auth_request + async def configure_mtls_channel(self, client_cert_callback=None): + """Configure the client certificate and key for SSL connection. + + The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is + explicitly set to `true`. In this case if client certificate and key are + successfully obtained (from the given client_cert_callback or from application + default SSL credentials), the underlying transport will be reconfigured + to use mTLS. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): + The optional callback returns the client certificate and private + key bytes both in PEM format. + If the callback is None, application default SSL credentials + will be used. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel + creation failed for any reason. + """ + # Run the blocking check in an executor + use_client_cert = await mtls._run_in_executor( + google.auth.transport._mtls_helper.check_use_client_cert + ) + if not use_client_cert: + self._is_mtls = False + return + + try: + ( + self._is_mtls, + cert, + key, + ) = await mtls.get_client_cert_and_key(client_cert_callback) + + if self._is_mtls: + self._cached_cert = cert + ssl_context = await mtls._run_in_executor( + mtls.make_client_cert_ssl_context, cert, key + ) + + # Re-create the auth request with the new SSL context + if isinstance(self._auth_request, AiohttpRequest): + connector = aiohttp.TCPConnector(ssl=ssl_context) + new_session = aiohttp.ClientSession(connector=connector) + await self._auth_request.close() + self._auth_request = AiohttpRequest(session=new_session) + + except ( + exceptions.ClientCertError, + ImportError, + OSError, + ) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + raise new_exc from caught_exc + async def request( self, method: str, @@ -174,6 +235,8 @@ async def request( retries = _exponential_backoff.AsyncExponentialBackoff( total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS ) + if headers is None: + headers = {} async with timeout_guard(max_allowed_time) as with_timeout: await with_timeout( # Note: before_request will attempt to refresh credentials if expired. @@ -261,6 +324,11 @@ async def delete( "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs ) + @property + def is_mtls(self): + """Indicates if mutual TLS is enabled.""" + return self._is_mtls + async def close(self) -> None: """ Close the underlying auth request session. diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py new file mode 100644 index 000000000..7ba488462 --- /dev/null +++ b/tests/transport/aio/test_sessions_mtls.py @@ -0,0 +1,123 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import ssl +from unittest import mock + +import pytest + +from google.auth import exceptions +from google.auth.aio import credentials +from google.auth.aio.transport import sessions + +# This is the valid "workload" format the library expects +VALID_WORKLOAD_CONFIG = { + "version": 1, + "cert_configs": { + "workload": {"cert_path": "/tmp/mock_cert.pem", "key_path": "/tmp/mock_key.pem"} + }, +} + + +class TestSessionsMtls: + @mock.patch("os.path.exists") + @mock.patch( + "builtins.open", + new_callable=mock.mock_open, + read_data=json.dumps(VALID_WORKLOAD_CONFIG), + ) + @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + @mock.patch("ssl.create_default_context") + @pytest.mark.asyncio + async def test_configure_mtls_channel( + self, mock_ssl, mock_helper, mock_file, mock_exists + ): + """ + Tests that the mTLS channel configures correctly when a + valid workload config is mocked. + """ + mock_exists.return_value = True + mock_helper.return_value = (b"fake_cert_data", b"fake_key_data") + + mock_context = mock.Mock(spec=ssl.SSLContext) + mock_ssl.return_value = mock_context + + mock_creds = mock.Mock(spec=credentials.Credentials) + session = sessions.AsyncAuthorizedSession(mock_creds) + + await session.configure_mtls_channel() + + assert session._is_mtls is True + assert mock_context.load_cert_chain.called + + @mock.patch("os.path.exists") + @pytest.mark.asyncio + async def test_configure_mtls_channel_disabled(self, mock_exists): + """ + Tests behavior when the config file does not exist. + """ + mock_exists.return_value = False + mock_creds = mock.Mock(spec=credentials.Credentials) + + try: + session = sessions.AsyncAuthorizedSession(mock_creds) + except AttributeError: + session = sessions.Session() + await session.configure_mtls_channel() + + # If the file doesn't exist, it shouldn't error; it just won't use mTLS + assert session._is_mtls is False + + @mock.patch("os.path.exists") + @mock.patch( + "builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}' + ) + @pytest.mark.asyncio + async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists): + """ + Verifies that the MutualTLSChannelError is raised for bad formats. + """ + mock_exists.return_value = True + mock_creds = mock.Mock(spec=credentials.Credentials) + + try: + session = sessions.AsyncAuthorizedSession(mock_creds) + except AttributeError: + session = sessions.Session() + with pytest.raises( + exceptions.MutualTLSChannelError, match="is in an invalid format" + ): + await session.configure_mtls_channel() + + @pytest.mark.asyncio + @mock.patch( + "google.auth.aio.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + async def test_configure_mtls_channel_mock_callback(self, mock_has_cert): + """ + Tests mTLS configuration using bytes-returning callback. + """ + + def mock_callback(): + return (b"fake_cert_bytes", b"fake_key_bytes") + + mock_creds = mock.Mock(spec=credentials.Credentials) + + with mock.patch("ssl.SSLContext.load_cert_chain"): + session = sessions.AsyncAuthorizedSession(mock_creds) + await session.configure_mtls_channel(client_cert_callback=mock_callback) + + assert session._is_mtls is True From ae799f073d3df930b8467bbbc9fa74d25c545209 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 16 Feb 2026 17:35:41 -0800 Subject: [PATCH 08/12] chore: Refractors with respect to gemini review comments and keeping only AuthorizedSession in the tests Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/aiohttp.py | 2 +- google/auth/aio/transport/mtls.py | 22 ++++++++++------------ google/auth/aio/transport/sessions.py | 4 ++-- noxfile.py | 2 +- tests/transport/aio/test_sessions_mtls.py | 14 +++----------- 5 files changed, 17 insertions(+), 27 deletions(-) diff --git a/google/auth/aio/transport/aiohttp.py b/google/auth/aio/transport/aiohttp.py index 7adc5d915..a1db9a478 100644 --- a/google/auth/aio/transport/aiohttp.py +++ b/google/auth/aio/transport/aiohttp.py @@ -113,7 +113,7 @@ class Request(transport.Request): .. automethod:: __call__ """ - def __init__(self, session: aiohttp.ClientSession = None): + def __init__(self, session: Optional[aiohttp.ClientSession] = None): self._session = session self._closed = False diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index 7aa5580b8..bc28870f2 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -43,15 +43,15 @@ def _create_temp_file(content: bytes): str: The path to the temporary file. """ # Create a temporary file that is readable only by the owner. - fd, path = tempfile.mkstemp() + fd, file_path = tempfile.mkstemp() try: with os.fdopen(fd, "wb") as f: f.write(content) - yield path + yield file_path finally: # Securely delete the file after use. - if os.path.exists(path): - os.remove(path) + if os.path.exists(file_path): + os.remove(file_path) def make_client_cert_ssl_context( @@ -140,7 +140,7 @@ def default_client_cert_source(): client certificate bytes and private key bytes, both in PEM format. Raises: - google.auth.exceptions.DefaultClientCertSourceError: If the default + google.auth.exceptions.MutualTLSChannelError: If the default client SSL credentials don't exist or are malformed. """ if not has_default_client_cert_source(): @@ -215,13 +215,11 @@ async def get_client_cert_and_key(client_cert_callback=None): the cert and key. """ if client_cert_callback: - try: - # If it's awaitable, this works. - cert, key = await client_cert_callback() - except TypeError: - # If it's not awaitable (e.g., a tuple), result is already the data. - cert, key = client_cert_callback() - + result = client_cert_callback() + if asyncio.iscoroutine(result): + cert, key = await result + else: + cert, key = result return True, cert, key has_cert, cert, key, _ = await get_client_ssl_credentials() diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py index e400a344d..12c844b7b 100644 --- a/google/auth/aio/transport/sessions.py +++ b/google/auth/aio/transport/sessions.py @@ -128,7 +128,7 @@ def __init__( if not _auth_request and AIOHTTP_INSTALLED: _auth_request = AiohttpRequest() self._is_mtls = False - self._cached_Cert = None + self._cached_cert = None if _auth_request is None: raise exceptions.TransportError( "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." @@ -177,7 +177,7 @@ async def configure_mtls_channel(self, client_cert_callback=None): ) # Re-create the auth request with the new SSL context - if isinstance(self._auth_request, AiohttpRequest): + if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest): connector = aiohttp.TCPConnector(ssl=ssl_context) new_session = aiohttp.ClientSession(connector=connector) await self._auth_request.close() diff --git a/noxfile.py b/noxfile.py index 27e34e5c4..e17e40c68 100644 --- a/noxfile.py +++ b/noxfile.py @@ -91,7 +91,7 @@ def blacken(session): @nox.session(python=DEFAULT_PYTHON_VERSION) def mypy(session): """Verify type hints are mypy compatible.""" - session.install("-e", ".") + session.install("-e", ".[aiohttp]") session.install( "mypy", "types-certifi", diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py index 7ba488462..40b6b5bc3 100644 --- a/tests/transport/aio/test_sessions_mtls.py +++ b/tests/transport/aio/test_sessions_mtls.py @@ -56,7 +56,6 @@ async def test_configure_mtls_channel( mock_creds = mock.Mock(spec=credentials.Credentials) session = sessions.AsyncAuthorizedSession(mock_creds) - await session.configure_mtls_channel() assert session._is_mtls is True @@ -71,10 +70,7 @@ async def test_configure_mtls_channel_disabled(self, mock_exists): mock_exists.return_value = False mock_creds = mock.Mock(spec=credentials.Credentials) - try: - session = sessions.AsyncAuthorizedSession(mock_creds) - except AttributeError: - session = sessions.Session() + session = sessions.AsyncAuthorizedSession(mock_creds) await session.configure_mtls_channel() # If the file doesn't exist, it shouldn't error; it just won't use mTLS @@ -92,13 +88,9 @@ async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exist mock_exists.return_value = True mock_creds = mock.Mock(spec=credentials.Credentials) - try: - session = sessions.AsyncAuthorizedSession(mock_creds) - except AttributeError: - session = sessions.Session() + session = sessions.AsyncAuthorizedSession(mock_creds) with pytest.raises( - exceptions.MutualTLSChannelError, match="is in an invalid format" - ): + exceptions.MutualTLSChannelError): await session.configure_mtls_channel() @pytest.mark.asyncio From 8bb0ea51fea1e7a620090bc2cf35e11a825c97e7 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 16 Feb 2026 19:00:19 -0800 Subject: [PATCH 09/12] fix: fix lint errors Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/mtls.py | 2 +- tests/transport/aio/test_sessions_mtls.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py index bc28870f2..9c52aa843 100644 --- a/google/auth/aio/transport/mtls.py +++ b/google/auth/aio/transport/mtls.py @@ -219,7 +219,7 @@ async def get_client_cert_and_key(client_cert_callback=None): if asyncio.iscoroutine(result): cert, key = await result else: - cert, key = result + cert, key = result return True, cert, key has_cert, cert, key, _ = await get_client_ssl_credentials() diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py index 40b6b5bc3..8ca5adfe5 100644 --- a/tests/transport/aio/test_sessions_mtls.py +++ b/tests/transport/aio/test_sessions_mtls.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import os import ssl from unittest import mock @@ -32,13 +33,14 @@ class TestSessionsMtls: + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") @mock.patch( "builtins.open", new_callable=mock.mock_open, read_data=json.dumps(VALID_WORKLOAD_CONFIG), ) - @mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key") + @mock.patch("google.auth.aio.transport.mtls.get_client_cert_and_key") @mock.patch("ssl.create_default_context") @pytest.mark.asyncio async def test_configure_mtls_channel( @@ -49,7 +51,7 @@ async def test_configure_mtls_channel( valid workload config is mocked. """ mock_exists.return_value = True - mock_helper.return_value = (b"fake_cert_data", b"fake_key_data") + mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data") mock_context = mock.Mock(spec=ssl.SSLContext) mock_ssl.return_value = mock_context @@ -61,6 +63,7 @@ async def test_configure_mtls_channel( assert session._is_mtls is True assert mock_context.load_cert_chain.called + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") @pytest.mark.asyncio async def test_configure_mtls_channel_disabled(self, mock_exists): @@ -89,10 +92,10 @@ async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exist mock_creds = mock.Mock(spec=credentials.Credentials) session = sessions.AsyncAuthorizedSession(mock_creds) - with pytest.raises( - exceptions.MutualTLSChannelError): + with pytest.raises(exceptions.MutualTLSChannelError): await session.configure_mtls_channel() + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @pytest.mark.asyncio @mock.patch( "google.auth.aio.transport.mtls.has_default_client_cert_source", From b560507b79eeee3654d4b412e405ecac6739f6e5 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 16 Feb 2026 21:41:31 -0800 Subject: [PATCH 10/12] chore: Add mtls flag check Signed-off-by: Radhika Agrawal --- tests/transport/aio/test_sessions_mtls.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py index 8ca5adfe5..a9681856f 100644 --- a/tests/transport/aio/test_sessions_mtls.py +++ b/tests/transport/aio/test_sessions_mtls.py @@ -79,6 +79,7 @@ async def test_configure_mtls_channel_disabled(self, mock_exists): # If the file doesn't exist, it shouldn't error; it just won't use mTLS assert session._is_mtls is False + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") @mock.patch( "builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}' From 0ec1bffddd8e179458a809343afb1c0879b350a0 Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 16 Feb 2026 21:52:34 -0800 Subject: [PATCH 11/12] fix: test-fix Signed-off-by: Radhika Agrawal --- tests/transport/aio/test_sessions_mtls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/transport/aio/test_sessions_mtls.py b/tests/transport/aio/test_sessions_mtls.py index a9681856f..2e47acb41 100644 --- a/tests/transport/aio/test_sessions_mtls.py +++ b/tests/transport/aio/test_sessions_mtls.py @@ -33,6 +33,7 @@ class TestSessionsMtls: + @pytest.mark.asyncio @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") @mock.patch( @@ -42,7 +43,6 @@ class TestSessionsMtls: ) @mock.patch("google.auth.aio.transport.mtls.get_client_cert_and_key") @mock.patch("ssl.create_default_context") - @pytest.mark.asyncio async def test_configure_mtls_channel( self, mock_ssl, mock_helper, mock_file, mock_exists ): @@ -63,9 +63,9 @@ async def test_configure_mtls_channel( assert session._is_mtls is True assert mock_context.load_cert_chain.called + @pytest.mark.asyncio @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") - @pytest.mark.asyncio async def test_configure_mtls_channel_disabled(self, mock_exists): """ Tests behavior when the config file does not exist. @@ -79,12 +79,12 @@ async def test_configure_mtls_channel_disabled(self, mock_exists): # If the file doesn't exist, it shouldn't error; it just won't use mTLS assert session._is_mtls is False + @pytest.mark.asyncio @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch("os.path.exists") @mock.patch( "builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}' ) - @pytest.mark.asyncio async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists): """ Verifies that the MutualTLSChannelError is raised for bad formats. @@ -96,8 +96,8 @@ async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exist with pytest.raises(exceptions.MutualTLSChannelError): await session.configure_mtls_channel() - @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @pytest.mark.asyncio + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) @mock.patch( "google.auth.aio.transport.mtls.has_default_client_cert_source", return_value=True, From 6eda8c1892a7486116172ef369f0447680a6aa6d Mon Sep 17 00:00:00 2001 From: Radhika Agrawal Date: Mon, 16 Feb 2026 23:13:26 -0800 Subject: [PATCH 12/12] fix: test fixes for the unit test3.9 Signed-off-by: Radhika Agrawal --- google/auth/aio/transport/sessions.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py index 12c844b7b..1c6d83033 100644 --- a/google/auth/aio/transport/sessions.py +++ b/google/auth/aio/transport/sessions.py @@ -63,7 +63,14 @@ def _remaining_time(): async def with_timeout(coro): try: - remaining = _remaining_time() + try: + remaining = _remaining_time() + except TimeoutError: + # If we timeout before starting the call, + # we must close the coroutine to avoid leaks. + if hasattr(coro, "close"): + coro.close() + raise response = await asyncio.wait_for(coro, remaining) return response except (asyncio.TimeoutError, TimeoutError) as e: