Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Request(transport.Request):
.. automethod:: __call__
"""

def __init__(self, session: aiohttp.ClientSession = None):
def __init__(self, session: Optional[aiohttp.ClientSession] = None):
self._session = session
self._closed = False

Expand Down
226 changes: 226 additions & 0 deletions google/auth/aio/transport/mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper functions for mTLS in async for discovery of certs.
"""

import asyncio
import contextlib
import logging
import os
from os import getenv, path
import ssl
import tempfile
from typing import Optional

from google.auth import exceptions
import google.auth.transport._mtls_helper

CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
_LOGGER = logging.getLogger(__name__)


@contextlib.contextmanager
def _create_temp_file(content: bytes):
"""Creates a temporary file with the given content.

Args:
content (bytes): The content to write to the file.

Yields:
str: The path to the temporary file.
"""
# Create a temporary file that is readable only by the owner.
fd, file_path = tempfile.mkstemp()
try:
with os.fdopen(fd, "wb") as f:
f.write(content)
yield file_path
finally:
# Securely delete the file after use.
if os.path.exists(file_path):
os.remove(file_path)


def make_client_cert_ssl_context(
cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None
) -> ssl.SSLContext:
"""Creates an SSLContext with the given client certificate and key.
This function writes the certificate and key to temporary files so that
ssl.create_default_context can load them, as the ssl module requires
file paths for client certificates.
Args:
cert_bytes (bytes): The client certificate content in PEM format.
key_bytes (bytes): The client private key content in PEM format.
passphrase (Optional[bytes]): The passphrase for the private key, if any.
Returns:
ssl.SSLContext: The configured SSL context with client certificate.

Raises:
google.auth.exceptions.TransportError: If there is an error loading the certificate.
"""
try:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)

# Write cert and key to temp files because ssl.load_cert_chain requires paths
with _create_temp_file(cert_bytes) as cert_path:
with _create_temp_file(key_bytes) as key_path:
context.load_cert_chain(
certfile=cert_path, keyfile=key_path, password=passphrase
)
return context
except (ssl.SSLError, OSError) as exc:
raise exceptions.TransportError(
"Failed to load client certificate and key for mTLS."
) from exc


def _check_config_path(config_path):
"""Checks for config file path. If it exists, returns the absolute path with user expansion;
otherwise returns None.

Args:
config_path (str): The config file path for certificate_config.json for example

Returns:
str: absolute path if exists and None otherwise.
"""
config_path = path.expanduser(config_path)
if not path.exists(config_path):
_LOGGER.debug("%s is not found.", config_path)
return None
return config_path


async def _run_in_executor(func, *args):
"""Run a blocking function in an executor to avoid blocking the event loop.

This implements the non-blocking execution strategy for disk I/O operations.
"""
try:
# For python versions 3.9 and newer versions
return await asyncio.to_thread(func, *args)
except AttributeError:
# Fallback for older Python versions
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args)


def has_default_client_cert_source():
"""Check if default client SSL credentials exists on the device.

Returns:
bool: indicating if the default client cert source exists.
"""
if _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) is not None:
return True
cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG")
if cert_config_path and _check_config_path(cert_config_path) is not None:
return True
return False


def default_client_cert_source():
"""Get a callback which returns the default client SSL credentials.

Returns:
Awaitable[Callable[[], [bytes, bytes]]]: A callback which returns the default
client certificate bytes and private key bytes, both in PEM format.

Raises:
google.auth.exceptions.MutualTLSChannelError: If the default
client SSL credentials don't exist or are malformed.
"""
if not has_default_client_cert_source():
raise exceptions.MutualTLSChannelError(
"Default client cert source doesn't exist"
)

async def callback():
try:
_, cert_bytes, key_bytes = await get_client_cert_and_key()
except (OSError, RuntimeError, ValueError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

return cert_bytes, key_bytes

return callback


async def get_client_ssl_credentials(
certificate_config_path=None,
):
"""Returns the client side certificate, private key and passphrase.

We look for certificates and keys with the following order of priority:
1. Certificate and key specified by certificate_config.json.
Currently, only X.509 workload certificates are supported.

Args:
certificate_config_path (str): The certificate_config.json file path.

Returns:
Tuple[bool, bytes, bytes, bytes]:
A boolean indicating if cert, key and passphrase are obtained, the
cert bytes and key bytes both in PEM format, and passphrase bytes.

Raises:
google.auth.exceptions.ClientCertError: if problems occurs when getting
the cert, key and passphrase.
"""

# Attempt to retrieve X.509 Workload cert and key.
cert, key = await _run_in_executor(
google.auth.transport._mtls_helper._get_workload_cert_and_key,
certificate_config_path,
)

if cert and key:
return True, cert, key, None

return False, None, None, None


async def get_client_cert_and_key(client_cert_callback=None):
"""Returns the client side certificate and private key. The function first
tries to get certificate and key from client_cert_callback; if the callback
is None or doesn't provide certificate and key, the function tries application
default SSL credentials.

Args:
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): An
optional callback which returns client certificate bytes and private
key bytes both in PEM format.

Returns:
Tuple[bool, bytes, bytes]:
A boolean indicating if cert and key are obtained, the cert bytes
and key bytes both in PEM format.

Raises:
google.auth.exceptions.ClientCertError: if problems occurs when getting
the cert and key.
"""
if client_cert_callback:
result = client_cert_callback()
if asyncio.iscoroutine(result):
cert, key = await result
else:
cert, key = result
return True, cert, key

has_cert, cert, key, _ = await get_client_ssl_credentials()
return has_cert, cert, key
77 changes: 76 additions & 1 deletion google/auth/aio/transport/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
from google.auth import _exponential_backoff, exceptions
from google.auth.aio import transport
from google.auth.aio.credentials import Credentials
from google.auth.aio.transport import mtls
from google.auth.exceptions import TimeoutError
import google.auth.transport._mtls_helper

try:
import aiohttp
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest

AIOHTTP_INSTALLED = True
Expand Down Expand Up @@ -60,7 +63,14 @@ def _remaining_time():

async def with_timeout(coro):
try:
remaining = _remaining_time()
try:
remaining = _remaining_time()
except TimeoutError:
# If we timeout before starting the call,
# we must close the coroutine to avoid leaks.
if hasattr(coro, "close"):
coro.close()
raise
response = await asyncio.wait_for(coro, remaining)
return response
except (asyncio.TimeoutError, TimeoutError) as e:
Expand Down Expand Up @@ -124,12 +134,70 @@ def __init__(
_auth_request = auth_request
if not _auth_request and AIOHTTP_INSTALLED:
_auth_request = AiohttpRequest()
self._is_mtls = False
self._cached_cert = None
if _auth_request is None:
raise exceptions.TransportError(
"`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
)
self._auth_request = _auth_request

async def configure_mtls_channel(self, client_cert_callback=None):
"""Configure the client certificate and key for SSL connection.

The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is
explicitly set to `true`. In this case if client certificate and key are
successfully obtained (from the given client_cert_callback or from application
default SSL credentials), the underlying transport will be reconfigured
to use mTLS.

Args:
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
The optional callback returns the client certificate and private
key bytes both in PEM format.
If the callback is None, application default SSL credentials
will be used.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
creation failed for any reason.
"""
# Run the blocking check in an executor
use_client_cert = await mtls._run_in_executor(
google.auth.transport._mtls_helper.check_use_client_cert
)
if not use_client_cert:
self._is_mtls = False
return

try:
(
self._is_mtls,
cert,
key,
) = await mtls.get_client_cert_and_key(client_cert_callback)

if self._is_mtls:
self._cached_cert = cert
ssl_context = await mtls._run_in_executor(
mtls.make_client_cert_ssl_context, cert, key
)

# Re-create the auth request with the new SSL context
if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest):
connector = aiohttp.TCPConnector(ssl=ssl_context)
new_session = aiohttp.ClientSession(connector=connector)
await self._auth_request.close()
self._auth_request = AiohttpRequest(session=new_session)
Comment thread
agrawalradhika-cell marked this conversation as resolved.

except (
exceptions.ClientCertError,
ImportError,
OSError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

async def request(
self,
method: str,
Expand Down Expand Up @@ -174,6 +242,8 @@ async def request(
retries = _exponential_backoff.AsyncExponentialBackoff(
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
)
if headers is None:
headers = {}
async with timeout_guard(max_allowed_time) as with_timeout:
await with_timeout(
# Note: before_request will attempt to refresh credentials if expired.
Expand Down Expand Up @@ -261,6 +331,11 @@ async def delete(
"DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
)

@property
def is_mtls(self):
"""Indicates if mutual TLS is enabled."""
return self._is_mtls

async def close(self) -> None:
"""
Close the underlying auth request session.
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def blacken(session):
@nox.session(python=DEFAULT_PYTHON_VERSION)
def mypy(session):
"""Verify type hints are mypy compatible."""
session.install("-e", ".")
session.install("-e", ".[aiohttp]")
session.install(
"mypy",
"types-certifi",
Expand Down
Loading
Loading