Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit dbd40d0

Browse files
feat: mTLS configuration via x.509 for asynchronous session in google-auth
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent fce3c71 commit dbd40d0

5 files changed

Lines changed: 248 additions & 3 deletions

File tree

google/auth/aio/transport/aiohttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class Request(transport.Request):
113113
.. automethod:: __call__
114114
"""
115115

116-
def __init__(self, session: aiohttp.ClientSession = None):
116+
def __init__(self, session: Optional[aiohttp.ClientSession] = None):
117117
self._session = session
118118
self._closed = False
119119

google/auth/aio/transport/mtls.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,72 @@
1717
"""
1818

1919
import asyncio
20+
import contextlib
2021
import logging
22+
import os
23+
from os import getenv, path
24+
import ssl
25+
import tempfile
26+
from typing import Optional
2127

2228
from google.auth import exceptions
2329
import google.auth.transport._mtls_helper
24-
import google.auth.transport.mtls
2530

2631
_LOGGER = logging.getLogger(__name__)
2732

33+
@contextlib.contextmanager
34+
def _create_temp_file(content: bytes):
35+
"""Creates a temporary file with the given content.
36+
37+
Args:
38+
content (bytes): The content to write to the file.
39+
40+
Yields:
41+
str: The path to the temporary file.
42+
"""
43+
# Create a temporary file that is readable only by the owner.
44+
fd, file_path = tempfile.mkstemp()
45+
try:
46+
with os.fdopen(fd, "wb") as f:
47+
f.write(content)
48+
yield file_path
49+
finally:
50+
# Securely delete the file after use.
51+
if os.path.exists(file_path):
52+
os.remove(file_path)
53+
54+
55+
def make_client_cert_ssl_context(
56+
cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None
57+
) -> ssl.SSLContext:
58+
"""Creates an SSLContext with the given client certificate and key.
59+
This function writes the certificate and key to temporary files so that
60+
ssl.create_default_context can load them, as the ssl module requires
61+
file paths for client certificates.
62+
Args:
63+
cert_bytes (bytes): The client certificate content in PEM format.
64+
key_bytes (bytes): The client private key content in PEM format.
65+
passphrase (Optional[bytes]): The passphrase for the private key, if any.
66+
Returns:
67+
ssl.SSLContext: The configured SSL context with client certificate.
68+
69+
Raises:
70+
google.auth.exceptions.TransportError: If there is an error loading the certificate.
71+
"""
72+
try:
73+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
74+
75+
# Write cert and key to temp files because ssl.load_cert_chain requires paths
76+
with _create_temp_file(cert_bytes) as cert_path:
77+
with _create_temp_file(key_bytes) as key_path:
78+
context.load_cert_chain(
79+
certfile=cert_path, keyfile=key_path, password=passphrase
80+
)
81+
return context
82+
except (ssl.SSLError, OSError) as exc:
83+
raise exceptions.TransportError(
84+
"Failed to load client certificate and key for mTLS."
85+
) from exc
2886

2987
async def _run_in_executor(func, *args):
3088
"""Run a blocking function in an executor to avoid blocking the event loop.

google/auth/aio/transport/sessions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from google.auth import _exponential_backoff, exceptions
2222
from google.auth.aio import transport
2323
from google.auth.aio.credentials import Credentials
24+
from google.auth.aio.transport import mtls
2425
from google.auth.exceptions import TimeoutError
26+
import google.auth.transport._mtls_helper
2527

2628
try:
29+
import aiohttp
2730
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
2831

2932
AIOHTTP_INSTALLED = True
@@ -124,12 +127,70 @@ def __init__(
124127
_auth_request = auth_request
125128
if not _auth_request and AIOHTTP_INSTALLED:
126129
_auth_request = AiohttpRequest()
130+
self._is_mtls = False
131+
self._cached_cert = None
127132
if _auth_request is None:
128133
raise exceptions.TransportError(
129134
"`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
130135
)
131136
self._auth_request = _auth_request
132137

138+
async def configure_mtls_channel(self, client_cert_callback=None):
139+
"""Configure the client certificate and key for SSL connection.
140+
141+
The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is
142+
explicitly set to `true`. In this case if client certificate and key are
143+
successfully obtained (from the given client_cert_callback or from application
144+
default SSL credentials), the underlying transport will be reconfigured
145+
to use mTLS.
146+
147+
Args:
148+
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
149+
The optional callback returns the client certificate and private
150+
key bytes both in PEM format.
151+
If the callback is None, application default SSL credentials
152+
will be used.
153+
154+
Raises:
155+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
156+
creation failed for any reason.
157+
"""
158+
# Run the blocking check in an executor
159+
use_client_cert = await mtls._run_in_executor(
160+
google.auth.transport._mtls_helper.check_use_client_cert
161+
)
162+
if not use_client_cert:
163+
self._is_mtls = False
164+
return
165+
166+
try:
167+
(
168+
self._is_mtls,
169+
cert,
170+
key,
171+
) = await mtls.get_client_cert_and_key(client_cert_callback)
172+
173+
if self._is_mtls:
174+
self._cached_cert = cert
175+
ssl_context = await mtls._run_in_executor(
176+
mtls.make_client_cert_ssl_context, cert, key
177+
)
178+
179+
# Re-create the auth request with the new SSL context
180+
if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest):
181+
connector = aiohttp.TCPConnector(ssl=ssl_context)
182+
new_session = aiohttp.ClientSession(connector=connector)
183+
await self._auth_request.close()
184+
self._auth_request = AiohttpRequest(session=new_session)
185+
186+
except (
187+
exceptions.ClientCertError,
188+
ImportError,
189+
OSError,
190+
) as caught_exc:
191+
new_exc = exceptions.MutualTLSChannelError(caught_exc)
192+
raise new_exc from caught_exc
193+
133194
async def request(
134195
self,
135196
method: str,
@@ -174,6 +235,8 @@ async def request(
174235
retries = _exponential_backoff.AsyncExponentialBackoff(
175236
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
176237
)
238+
if headers is None:
239+
headers = {}
177240
async with timeout_guard(max_allowed_time) as with_timeout:
178241
await with_timeout(
179242
# Note: before_request will attempt to refresh credentials if expired.
@@ -261,6 +324,11 @@ async def delete(
261324
"DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
262325
)
263326

327+
@property
328+
def is_mtls(self):
329+
"""Indicates if mutual TLS is enabled."""
330+
return self._is_mtls
331+
264332
async def close(self) -> None:
265333
"""
266334
Close the underlying auth request session.

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def blacken(session):
9191
@nox.session(python=DEFAULT_PYTHON_VERSION)
9292
def mypy(session):
9393
"""Verify type hints are mypy compatible."""
94-
session.install("-e", ".")
94+
session.install("-e", ".[aiohttp]")
9595
session.install(
9696
"mypy",
9797
"types-certifi",
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import os
17+
import ssl
18+
from unittest import mock
19+
20+
import pytest
21+
22+
from google.auth import exceptions
23+
from google.auth.aio import credentials
24+
from google.auth.aio.transport import sessions
25+
26+
# This is the valid "workload" format the library expects
27+
VALID_WORKLOAD_CONFIG = {
28+
"version": 1,
29+
"cert_configs": {
30+
"workload": {"cert_path": "/tmp/mock_cert.pem", "key_path": "/tmp/mock_key.pem"}
31+
},
32+
}
33+
34+
35+
class TestSessionsMtls:
36+
@pytest.mark.asyncio
37+
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
38+
@mock.patch("os.path.exists")
39+
@mock.patch(
40+
"builtins.open",
41+
new_callable=mock.mock_open,
42+
read_data=json.dumps(VALID_WORKLOAD_CONFIG),
43+
)
44+
@mock.patch("google.auth.aio.transport.mtls.get_client_cert_and_key")
45+
@mock.patch("ssl.create_default_context")
46+
async def test_configure_mtls_channel(
47+
self, mock_ssl, mock_helper, mock_file, mock_exists
48+
):
49+
"""
50+
Tests that the mTLS channel configures correctly when a
51+
valid workload config is mocked.
52+
"""
53+
mock_exists.return_value = True
54+
mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data")
55+
56+
mock_context = mock.Mock(spec=ssl.SSLContext)
57+
mock_ssl.return_value = mock_context
58+
59+
mock_creds = mock.Mock(spec=credentials.Credentials)
60+
session = sessions.AsyncAuthorizedSession(mock_creds)
61+
await session.configure_mtls_channel()
62+
63+
assert session._is_mtls is True
64+
assert mock_context.load_cert_chain.called
65+
66+
@pytest.mark.asyncio
67+
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
68+
@mock.patch("os.path.exists")
69+
async def test_configure_mtls_channel_disabled(self, mock_exists):
70+
"""
71+
Tests behavior when the config file does not exist.
72+
"""
73+
mock_exists.return_value = False
74+
mock_creds = mock.Mock(spec=credentials.Credentials)
75+
76+
session = sessions.AsyncAuthorizedSession(mock_creds)
77+
await session.configure_mtls_channel()
78+
79+
# If the file doesn't exist, it shouldn't error; it just won't use mTLS
80+
assert session._is_mtls is False
81+
82+
@pytest.mark.asyncio
83+
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
84+
@mock.patch("os.path.exists")
85+
@mock.patch(
86+
"builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}'
87+
)
88+
async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists):
89+
"""
90+
Verifies that the MutualTLSChannelError is raised for bad formats.
91+
"""
92+
mock_exists.return_value = True
93+
mock_creds = mock.Mock(spec=credentials.Credentials)
94+
95+
session = sessions.AsyncAuthorizedSession(mock_creds)
96+
with pytest.raises(exceptions.MutualTLSChannelError):
97+
await session.configure_mtls_channel()
98+
99+
@pytest.mark.asyncio
100+
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
101+
@mock.patch(
102+
"google.auth.aio.transport.mtls.has_default_client_cert_source",
103+
return_value=True,
104+
)
105+
async def test_configure_mtls_channel_mock_callback(self, mock_has_cert):
106+
"""
107+
Tests mTLS configuration using bytes-returning callback.
108+
"""
109+
110+
def mock_callback():
111+
return (b"fake_cert_bytes", b"fake_key_bytes")
112+
113+
mock_creds = mock.Mock(spec=credentials.Credentials)
114+
115+
with mock.patch("ssl.SSLContext.load_cert_chain"):
116+
session = sessions.AsyncAuthorizedSession(mock_creds)
117+
await session.configure_mtls_channel(client_cert_callback=mock_callback)
118+
119+
assert session._is_mtls is True

0 commit comments

Comments
 (0)