Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.
Merged
168 changes: 168 additions & 0 deletions google/auth/aio/transport/mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper functions for mTLS in async for discovery of certs.
"""
Comment thread
daniel-sanche marked this conversation as resolved.

import asyncio
import logging
from os import getenv, path

from google.auth import exceptions
import google.auth.transport._mtls_helper

CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
Comment thread
agrawalradhika-cell marked this conversation as resolved.
Outdated
_LOGGER = logging.getLogger(__name__)


def _check_config_path(config_path):
"""Checks for config file path. If it exists, returns the absolute path with user expansion;
otherwise returns None.

Args:
config_path (str): The config file path for certificate_config.json for example

Returns:
str: absolute path if exists and None otherwise.
"""
config_path = path.expanduser(config_path)
if not path.exists(config_path):
_LOGGER.debug("%s is not found.", config_path)
return None
return config_path


async def _run_in_executor(func, *args):
"""Run a blocking function in an executor to avoid blocking the event loop.

This implements the non-blocking execution strategy for disk I/O operations.
"""
try:
# For python versions 3.9 and newer versions
return await asyncio.to_thread(func, *args)
except AttributeError:
# Fallback for older Python versions
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args)


def has_default_client_cert_source():
Comment thread
agrawalradhika-cell marked this conversation as resolved.
Outdated
"""Check if default client SSL credentials exists on the device.

Returns:
bool: indicating if the default client cert source exists.
"""
if _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) is not None:
return True
cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we use the value from env var if set and fallback to default?

Copy link
Copy Markdown
Contributor Author

@agrawalradhika-cell agrawalradhika-cell Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on AIP-4114 - Users should enable Device Certificate Authentication through ADC instead of manual configuration via client options.
Thus, ~/.config/gcloud/certificate_config.json checks if the certs are present at standard location, and
if the certs are not present at default location, then certs are checked for specific custom config.

Same logic in synchronous operations - see url

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure that's right? That section uses the word should, not must.

If we put the default check above the env var check, that completely removes the ability to manually override, which doesn't seem like the intended approach to me, even if automatic detection is the best practice

But I just skimmed the doc, so let me know if I'm missing something

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the same helper for sync.

I do acknowledge concerns regarding environment variable precedence, the synchronous implementation intentionally uses a 'defaults-first' approach per AIP-4114 to encourage managed device identities via Application Default Credentials (ADC).
Keeping async logic consistent with the synchronous helper avoids subtle behavioral discrepancies.

if cert_config_path and _check_config_path(cert_config_path) is not None:
return True
return False
Comment thread
agrawalradhika-cell marked this conversation as resolved.
Outdated


def default_client_cert_source():
"""Get a callback which returns the default client SSL credentials.

Returns:
Awaitable[Callable[[], [bytes, bytes]]]: A callback which returns the default
client certificate bytes and private key bytes, both in PEM format.

Raises:
google.auth.exceptions.DefaultClientCertSourceError: If the default
client SSL credentials don't exist or are malformed.
"""
if not has_default_client_cert_source():
raise exceptions.MutualTLSChannelError(
"Default client cert source doesn't exist"
)

async def callback():
try:
_, cert_bytes, key_bytes = await get_client_cert_and_key()
Comment thread
agrawalradhika-cell marked this conversation as resolved.
except (OSError, RuntimeError, ValueError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

return cert_bytes, key_bytes

return callback


async def get_client_ssl_credentials(
certificate_config_path=None,
):
"""Returns the client side certificate, private key and passphrase.

We look for certificates and keys with the following order of priority:
1. Certificate and key specified by certificate_config.json.
Currently, only X.509 workload certificates are supported.
Comment thread
agrawalradhika-cell marked this conversation as resolved.

Args:
certificate_config_path (str): The certificate_config.json file path.

Returns:
Tuple[bool, bytes, bytes, bytes]:
A boolean indicating if cert, key and passphrase are obtained, the
cert bytes and key bytes both in PEM format, and passphrase bytes.

Raises:
google.auth.exceptions.ClientCertError: if problems occurs when getting
the cert, key and passphrase.
"""

# Attempt to retrieve X.509 Workload cert and key.
cert, key = await _run_in_executor(
google.auth.transport._mtls_helper._get_workload_cert_and_key,
certificate_config_path,
)
Comment thread
agrawalradhika-cell marked this conversation as resolved.
Comment thread
daniel-sanche marked this conversation as resolved.

if cert and key:
return True, cert, key, None

return False, None, None, None


async def get_client_cert_and_key(client_cert_callback=None):
"""Returns the client side certificate and private key. The function first
Comment thread
daniel-sanche marked this conversation as resolved.
tries to get certificate and key from client_cert_callback; if the callback
is None or doesn't provide certificate and key, the function tries application
default SSL credentials.

Args:
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): An
optional callback which returns client certificate bytes and private
key bytes both in PEM format.

Returns:
Tuple[bool, bytes, bytes]:
A boolean indicating if cert and key are obtained, the cert bytes
and key bytes both in PEM format.

Raises:
google.auth.exceptions.ClientCertError: if problems occurs when getting
the cert and key.
"""
if client_cert_callback:
try:
# If it's awaitable, this works.
cert, key = await client_cert_callback()
except TypeError:
# If it's not awaitable (e.g., a tuple), result is already the data.
cert, key = client_cert_callback()
Comment thread
agrawalradhika-cell marked this conversation as resolved.
Outdated

return True, cert, key

has_cert, cert, key, _ = await get_client_ssl_credentials()
return has_cert, cert, key
207 changes: 207 additions & 0 deletions tests/transport/test_aio_mtls_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from google.auth import exceptions
from google.auth.aio.transport import mtls

CERT_DATA = b"client-cert"
KEY_DATA = b"client-key"


class TestMTLS:
@mock.patch("google.auth.aio.transport.mtls.path.expanduser")
@mock.patch("google.auth.aio.transport.mtls.path.exists")
def test__check_config_path_exists(self, mock_exists, mock_expand):
mock_expand.side_effect = lambda x: x.replace("~", "/home/user")
mock_exists.return_value = True

input_path = "~/config.json"
expected_path = "/home/user/config.json"
result = mtls._check_config_path(input_path)

assert result == expected_path
mock_exists.assert_called_with(expected_path)

@mock.patch("google.auth.aio.transport.mtls.path.exists", return_value=False)
def test__check_config_path_not_found(self, mock_exists):
result = mtls._check_config_path("nonexistent.json")
assert result is None

@mock.patch("google.auth.aio.transport.mtls._check_config_path")
@mock.patch("google.auth.aio.transport.mtls.getenv")
def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check):
custom_path = "/custom/path.json"
mock_check.side_effect = lambda x: custom_path if x == custom_path else None
mock_getenv.return_value = custom_path

assert mtls.has_default_client_cert_source() is True

@mock.patch("google.auth.aio.transport.mtls._check_config_path")
@mock.patch("google.auth.aio.transport.mtls.getenv")
def test_has_default_client_cert_source_check_priority(
self, mock_getenv, mock_check
):
mock_check.return_value = "/default/path.json"

assert mtls.has_default_client_cert_source() is True
mock_getenv.assert_not_called()

@mock.patch(
"google.auth.aio.transport.mtls.has_default_client_cert_source",
return_value=False,
)
def test_default_client_cert_source_none(self, mock_has_default):
with pytest.raises(exceptions.MutualTLSChannelError):
mtls.default_client_cert_source()

@pytest.mark.asyncio
@mock.patch(
"google.auth.aio.transport.mtls.get_client_cert_and_key",
new_callable=mock.AsyncMock,
)
@mock.patch(
"google.auth.aio.transport.mtls.has_default_client_cert_source",
return_value=True,
)
async def test_default_client_cert_source_success(
self, mock_has_default, mock_get_cert_key
):
mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA)

# Note: default_client_cert_source is NOT async, but it returns an async callback
callback = mtls.default_client_cert_source()
assert callable(callback)

cert, key = await callback()
assert cert == CERT_DATA
assert key == KEY_DATA

@pytest.mark.asyncio
@mock.patch(
"google.auth.aio.transport.mtls.has_default_client_cert_source",
return_value=False,
)
async def test_default_client_cert_source_not_found(self, mock_has_default):
with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"):
await mtls.default_client_cert_source()

@pytest.mark.asyncio
@mock.patch(
"google.auth.aio.transport.mtls.get_client_cert_and_key",
new_callable=mock.AsyncMock,
)
@mock.patch(
"google.auth.aio.transport.mtls.has_default_client_cert_source",
return_value=True,
)
async def test_default_client_cert_source_callback_wraps_exception(
self, mock_has, mock_get
):
mock_get.side_effect = ValueError("Format error")

callback = mtls.default_client_cert_source()

with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
await callback()
assert "Format error" in str(excinfo.value)

@pytest.mark.asyncio
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
async def test_get_client_ssl_credentials_success(self, mock_workload):
mock_workload.return_value = (CERT_DATA, KEY_DATA)

success, cert, key, passphrase = await mtls.get_client_ssl_credentials()

assert success is True
assert cert == CERT_DATA
assert key == KEY_DATA
assert passphrase is None

@pytest.mark.asyncio
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl):
mock_get_ssl.return_value = (False, None, None, None)

success, cert, key = await mtls.get_client_cert_and_key(None)

assert success is False
assert cert is None
assert key is None

@pytest.mark.asyncio
async def test_get_client_cert_and_key_callback_async(self):
# Test with an actual coroutine/AsyncMock to satisfy the 'await' in your code
callback = mock.AsyncMock(return_value=(CERT_DATA, KEY_DATA))

success, cert, key = await mtls.get_client_cert_and_key(callback)

assert success is True
assert cert == CERT_DATA
assert key == KEY_DATA
callback.assert_called_once()

@pytest.mark.asyncio
async def test_get_client_cert_and_key_callback_sync(self):
# Test the fallback logic: if it's a sync function, the TypeError is caught
callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA))

success, cert, key = await mtls.get_client_cert_and_key(callback)

assert success is True
assert cert == CERT_DATA
# In your current implementation, this might still show 2 calls if the
# first 'await' attempt triggers a call before failing.
# To strictly avoid 2 calls, the implementation would need to check inspect.iscoroutinefunction.
assert callback.call_count >= 1

@pytest.mark.asyncio
@mock.patch(
"google.auth.aio.transport.mtls.get_client_ssl_credentials",
new_callable=mock.AsyncMock,
)
async def test_get_client_cert_and_key_default(self, mock_get_credentials):
mock_get_credentials.return_value = (True, CERT_DATA, KEY_DATA, None)

success, cert, key = await mtls.get_client_cert_and_key(None)

assert success is True
assert cert == CERT_DATA
assert key == KEY_DATA
mock_get_credentials.assert_called_once()

@pytest.mark.asyncio
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
async def test_get_client_ssl_credentials_error(self, mock_workload):
mock_workload.side_effect = exceptions.ClientCertError(
"Failed to read metadata"
)

with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"):
await mtls.get_client_ssl_credentials()

@pytest.mark.asyncio
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl):
mock_get_ssl.side_effect = exceptions.ClientCertError(
"Underlying credentials failed"
)

with pytest.raises(
exceptions.ClientCertError, match="Underlying credentials failed"
):
await mtls.get_client_cert_and_key(client_cert_callback=None)
Loading