Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit af21156

Browse files
chore: Add support for other error types in async/mtls and take care of race conditions
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent 806c329 commit af21156

4 files changed

Lines changed: 90 additions & 21 deletions

File tree

google/auth/aio/transport/mtls.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,19 @@ def make_client_cert_ssl_context(
7070
Raises:
7171
google.auth.exceptions.TransportError: If there is an error loading the certificate.
7272
"""
73-
try:
74-
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
75-
76-
# Write cert and key to temp files because ssl.load_cert_chain requires paths
77-
with _create_temp_file(cert_bytes) as cert_path:
78-
with _create_temp_file(key_bytes) as key_path:
79-
context.load_cert_chain(
80-
certfile=cert_path, keyfile=key_path, password=passphrase
81-
)
82-
return context
83-
except (ssl.SSLError, OSError) as exc:
84-
raise exceptions.TransportError(
85-
"Failed to load client certificate and key for mTLS."
86-
) from exc
73+
with _create_temp_file(cert_bytes) as cert_path, _create_temp_file(
74+
key_bytes
75+
) as key_path:
76+
try:
77+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
78+
context.load_cert_chain(
79+
certfile=cert_path, keyfile=key_path, password=passphrase
80+
)
81+
return context
82+
except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc:
83+
raise exceptions.TransportError(
84+
"Failed to load client certificate and key for mTLS."
85+
) from exc
8786

8887

8988
async def _run_in_executor(func, *args):
@@ -104,7 +103,7 @@ def default_client_cert_source():
104103
"""Get a callback which returns the default client SSL credentials.
105104
106105
Returns:
107-
Awaitable[Callable[[], [bytes, bytes]]]: A callback which returns the default
106+
Awaitable[Callable[[], Tuple[bytes, bytes]]]: A callback which returns the default
108107
client certificate bytes and private key bytes, both in PEM format.
109108
110109
Raises:

google/auth/aio/transport/sessions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ async def configure_mtls_channel(self, client_cert_callback=None):
143143
successfully obtained (from the given client_cert_callback or from application
144144
default SSL credentials), the underlying transport will be reconfigured
145145
to use mTLS.
146+
Note: This function does nothing if the `aiohttp` library is not
147+
installed.
146148
147149
Args:
148150
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
@@ -180,8 +182,9 @@ async def configure_mtls_channel(self, client_cert_callback=None):
180182
if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest):
181183
connector = aiohttp.TCPConnector(ssl=ssl_context)
182184
new_session = aiohttp.ClientSession(connector=connector)
183-
await self._auth_request.close()
185+
old_auth_request = self._auth_request
184186
self._auth_request = AiohttpRequest(session=new_session)
187+
await old_auth_request.close()
185188

186189
except (
187190
exceptions.ClientCertError,

tests/transport/aio/test_sessions_mtls.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,23 @@ async def test_configure_mtls_channel(self):
4646
), mock.patch(
4747
"google.auth.aio.transport.mtls.get_client_cert_and_key"
4848
) as mock_helper, mock.patch(
49-
"ssl.create_default_context"
50-
) as mock_ssl:
49+
"google.auth.aio.transport.mtls.make_client_cert_ssl_context"
50+
) as mock_make_context:
5151
mock_exists.return_value = True
5252
mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data")
5353

5454
mock_context = mock.Mock(spec=ssl.SSLContext)
55-
mock_ssl.return_value = mock_context
55+
mock_make_context.return_value = mock_context
5656

57-
# Use AsyncMock for credentials to avoid "coroutine never awaited" warnings
5857
mock_creds = mock.AsyncMock(spec=credentials.Credentials)
5958
session = sessions.AsyncAuthorizedSession(mock_creds)
6059

6160
await session.configure_mtls_channel()
6261

6362
assert session._is_mtls is True
64-
assert mock_context.load_cert_chain.called
63+
mock_make_context.assert_called_once_with(
64+
b"fake_cert_data", b"fake_key_data"
65+
)
6566

6667
@pytest.mark.asyncio
6768
async def test_configure_mtls_channel_disabled(self):

tests/transport/test_aio_mtls_helper.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
import ssl
1517
from unittest import mock
1618

1719
import pytest
@@ -24,6 +26,70 @@
2426

2527

2628
class TestMTLS:
29+
@pytest.mark.asyncio
30+
async def test__create_temp_file(self):
31+
"""Tests that _create_temp_file creates a file with correct content and deletes it."""
32+
content = b"test cert data"
33+
34+
# Test file creation and content
35+
with mtls._create_temp_file(content) as file_path:
36+
assert os.path.exists(file_path)
37+
# Verify file is not readable by others (mkstemp default)
38+
if os.name == "posix":
39+
assert (os.stat(file_path).st_mode & 0o777) == 0o600
40+
41+
with open(file_path, "rb") as f:
42+
assert f.read() == content
43+
44+
# Test file deletion after context exit
45+
assert not os.path.exists(file_path)
46+
47+
@pytest.mark.asyncio
48+
async def test_make_client_cert_ssl_context_success(self):
49+
"""Tests successful creation of an SSLContext with client certificates."""
50+
cert_bytes = b"cert_data"
51+
key_bytes = b"key_data"
52+
passphrase = b"password"
53+
54+
mock_context = mock.Mock(spec=ssl.SSLContext)
55+
56+
with mock.patch(
57+
"ssl.create_default_context", return_value=mock_context
58+
) as mock_create:
59+
context = mtls.make_client_cert_ssl_context(
60+
cert_bytes, key_bytes, passphrase=passphrase
61+
)
62+
63+
assert context == mock_context
64+
mock_create.assert_called_once_with(ssl.Purpose.SERVER_AUTH)
65+
66+
# Verify load_cert_chain was called
67+
assert mock_context.load_cert_chain.called
68+
kwargs = mock_context.load_cert_chain.call_args.kwargs
69+
assert "certfile" in kwargs
70+
assert "keyfile" in kwargs
71+
assert kwargs["password"] == passphrase
72+
73+
assert not os.path.exists(kwargs["certfile"])
74+
assert not os.path.exists(kwargs["keyfile"])
75+
76+
@pytest.mark.asyncio
77+
async def test_make_client_cert_ssl_context_error(self):
78+
"""Verifies that TransportError is raised when SSL loading fails."""
79+
cert_bytes = b"cert_data"
80+
key_bytes = b"key_data"
81+
82+
mock_context = mock.Mock(spec=ssl.SSLContext)
83+
# Mocking an SSLError to trigger the exception handler in make_client_cert_ssl_context
84+
mock_context.load_cert_chain.side_effect = ssl.SSLError("Mock SSL Error")
85+
86+
with mock.patch("ssl.create_default_context", return_value=mock_context):
87+
with pytest.raises(exceptions.TransportError) as exc_info:
88+
mtls.make_client_cert_ssl_context(cert_bytes, key_bytes)
89+
90+
assert "Failed to load client certificate" in str(exc_info.value)
91+
assert isinstance(exc_info.value.__cause__, ssl.SSLError)
92+
2793
@pytest.mark.asyncio
2894
@mock.patch(
2995
"google.auth.transport.mtls.has_default_client_cert_source", return_value=False

0 commit comments

Comments
 (0)