Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit fce3c71

Browse files
chore: Update based on reviewer comments, updated helpers so that the sync helpers can be reused
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent 8110a6f commit fce3c71

5 files changed

Lines changed: 55 additions & 142 deletions

File tree

google/auth/aio/transport/mtls.py

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,14 @@
1818

1919
import asyncio
2020
import logging
21-
from os import getenv, path
2221

2322
from google.auth import exceptions
2423
import google.auth.transport._mtls_helper
24+
import google.auth.transport.mtls
2525

26-
CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
2726
_LOGGER = logging.getLogger(__name__)
2827

2928

30-
def _check_config_path(config_path):
31-
"""Checks for config file path. If it exists, returns the absolute path with user expansion;
32-
otherwise returns None.
33-
34-
Args:
35-
config_path (str): The config file path for certificate_config.json for example
36-
37-
Returns:
38-
str: absolute path if exists and None otherwise.
39-
"""
40-
config_path = path.expanduser(config_path)
41-
if not path.exists(config_path):
42-
_LOGGER.debug("%s is not found.", config_path)
43-
return None
44-
return config_path
45-
46-
4729
async def _run_in_executor(func, *args):
4830
"""Run a blocking function in an executor to avoid blocking the event loop.
4931
@@ -58,20 +40,6 @@ async def _run_in_executor(func, *args):
5840
return await loop.run_in_executor(None, func, *args)
5941

6042

61-
def has_default_client_cert_source():
62-
"""Check if default client SSL credentials exists on the device.
63-
64-
Returns:
65-
bool: indicating if the default client cert source exists.
66-
"""
67-
if _check_config_path(CERTIFICATE_CONFIGURATION_DEFAULT_PATH) is not None:
68-
return True
69-
cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG")
70-
if cert_config_path and _check_config_path(cert_config_path) is not None:
71-
return True
72-
return False
73-
74-
7543
def default_client_cert_source():
7644
"""Get a callback which returns the default client SSL credentials.
7745
@@ -83,7 +51,9 @@ def default_client_cert_source():
8351
google.auth.exceptions.DefaultClientCertSourceError: If the default
8452
client SSL credentials don't exist or are malformed.
8553
"""
86-
if not has_default_client_cert_source():
54+
if not google.auth.transport.mtls.has_default_client_cert_source(
55+
include_context_aware=False
56+
):
8757
raise exceptions.MutualTLSChannelError(
8858
"Default client cert source doesn't exist"
8959
)
@@ -126,6 +96,7 @@ async def get_client_ssl_credentials(
12696
cert, key = await _run_in_executor(
12797
google.auth.transport._mtls_helper._get_workload_cert_and_key,
12898
certificate_config_path,
99+
False,
129100
)
130101

131102
if cert and key:
@@ -155,13 +126,11 @@ async def get_client_cert_and_key(client_cert_callback=None):
155126
the cert and key.
156127
"""
157128
if client_cert_callback:
129+
result = client_cert_callback()
158130
try:
159-
# If it's awaitable, this works.
160-
cert, key = await client_cert_callback()
131+
cert, key = await result
161132
except TypeError:
162-
# If it's not awaitable (e.g., a tuple), result is already the data.
163-
cert, key = client_cert_callback()
164-
133+
cert, key = result
165134
return True, cert, key
166135

167136
has_cert, cert, key, _ = await get_client_ssl_credentials()

google/auth/transport/_mtls_helper.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from google.auth import exceptions
2626

2727
CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
28+
29+
# Default gcloud config path, to be used with path.expanduser for cross-platform compatibility.
2830
CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
2931
_CERT_PROVIDER_COMMAND = "cert_provider_command"
3032
_CERT_REGEX = re.compile(
@@ -103,14 +105,18 @@ def _load_json_file(path):
103105
return json_data
104106

105107

106-
def _get_workload_cert_and_key(certificate_config_path=None):
108+
def _get_workload_cert_and_key(
109+
certificate_config_path=None, include_context_aware=True
110+
):
107111
"""Read the workload identity cert and key files specified in the certificate config provided.
108112
If no config path is provided, check the environment variable: "GOOGLE_API_CERTIFICATE_CONFIG"
109113
first, then the well known gcloud location: "~/.config/gcloud/certificate_config.json".
110114
111115
Args:
112116
certificate_config_path (string): The certificate config path. If no path is provided,
113117
the environment variable will be checked first, then the well known gcloud location.
118+
include_context_aware (bool): If context aware metadata path should be checked for the
119+
SecureConnect mTLS configuration.
114120
115121
Returns:
116122
Tuple[Optional[bytes], Optional[bytes]]: client certificate bytes in PEM format and key
@@ -121,15 +127,17 @@ def _get_workload_cert_and_key(certificate_config_path=None):
121127
the certificate or key information.
122128
"""
123129

124-
cert_path, key_path = _get_workload_cert_and_key_paths(certificate_config_path)
130+
cert_path, key_path = _get_workload_cert_and_key_paths(
131+
certificate_config_path, include_context_aware
132+
)
125133

126134
if cert_path is None and key_path is None:
127135
return None, None
128136

129137
return _read_cert_and_key_files(cert_path, key_path)
130138

131139

132-
def _get_cert_config_path(certificate_config_path=None):
140+
def _get_cert_config_path(certificate_config_path=None, include_context_aware=True):
133141
"""Get the certificate configuration path based on the following order:
134142
135143
1: Explicit override, if set
@@ -141,6 +149,8 @@ def _get_cert_config_path(certificate_config_path=None):
141149
Args:
142150
certificate_config_path (string): The certificate config path. If provided, the well known
143151
location and environment variable will be ignored.
152+
include_context_aware (bool): If context aware metadata path should be checked for the
153+
SecureConnect mTLS configuration.
144154
145155
Returns:
146156
The absolute path of the certificate config file, and None if the file does not exist.
@@ -155,7 +165,7 @@ def _get_cert_config_path(certificate_config_path=None):
155165
environment_vars.CLOUDSDK_CONTEXT_AWARE_CERTIFICATE_CONFIG_FILE_PATH,
156166
None,
157167
)
158-
if env_path is not None and env_path != "":
168+
if include_context_aware and env_path is not None and env_path != "":
159169
certificate_config_path = env_path
160170
else:
161171
certificate_config_path = CERTIFICATE_CONFIGURATION_DEFAULT_PATH
@@ -166,8 +176,8 @@ def _get_cert_config_path(certificate_config_path=None):
166176
return certificate_config_path
167177

168178

169-
def _get_workload_cert_and_key_paths(config_path):
170-
absolute_path = _get_cert_config_path(config_path)
179+
def _get_workload_cert_and_key_paths(config_path, include_context_aware=True):
180+
absolute_path = _get_cert_config_path(config_path, include_context_aware)
171181
if absolute_path is None:
172182
return None, None
173183

google/auth/transport/mtls.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,19 @@
2020
from google.auth.transport import _mtls_helper
2121

2222

23-
def has_default_client_cert_source():
23+
def has_default_client_cert_source(include_context_aware):
2424
"""Check if default client SSL credentials exists on the device.
2525
26+
Args:
27+
include_context_aware (bool): include_context_aware indicates if context_aware
28+
path location will be checked or should it be skipped.
29+
2630
Returns:
2731
bool: indicating if the default client cert source exists.
2832
"""
2933
if (
30-
_mtls_helper._check_config_path(_mtls_helper.CONTEXT_AWARE_METADATA_PATH)
34+
include_context_aware
35+
and _mtls_helper._check_config_path(_mtls_helper.CONTEXT_AWARE_METADATA_PATH)
3136
is not None
3237
):
3338
return True
@@ -58,7 +63,7 @@ def default_client_cert_source():
5863
google.auth.exceptions.DefaultClientCertSourceError: If the default
5964
client SSL credentials don't exist or are malformed.
6065
"""
61-
if not has_default_client_cert_source():
66+
if not has_default_client_cert_source(include_context_aware=True):
6267
raise exceptions.MutualTLSChannelError(
6368
"Default client cert source doesn't exist"
6469
)
@@ -94,7 +99,7 @@ def default_client_encrypted_cert_source(cert_path, key_path):
9499
google.auth.exceptions.DefaultClientCertSourceError: If any problem
95100
occurs when loading or saving the client certificate and key.
96101
"""
97-
if not has_default_client_cert_source():
102+
if not has_default_client_cert_source(include_context_aware=True):
98103
raise exceptions.MutualTLSChannelError(
99104
"Default client encrypted cert source doesn't exist"
100105
)

tests/transport/test_aio_mtls_helper.py

Lines changed: 18 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -24,49 +24,13 @@
2424

2525

2626
class TestMTLS:
27-
@mock.patch("google.auth.aio.transport.mtls.path.expanduser")
28-
@mock.patch("google.auth.aio.transport.mtls.path.exists")
29-
def test__check_config_path_exists(self, mock_exists, mock_expand):
30-
mock_expand.side_effect = lambda x: x.replace("~", "/home/user")
31-
mock_exists.return_value = True
32-
33-
input_path = "~/config.json"
34-
expected_path = "/home/user/config.json"
35-
result = mtls._check_config_path(input_path)
36-
37-
assert result == expected_path
38-
mock_exists.assert_called_with(expected_path)
39-
40-
@mock.patch("google.auth.aio.transport.mtls.path.exists", return_value=False)
41-
def test__check_config_path_not_found(self, mock_exists):
42-
result = mtls._check_config_path("nonexistent.json")
43-
assert result is None
44-
45-
@mock.patch("google.auth.aio.transport.mtls._check_config_path")
46-
@mock.patch("google.auth.aio.transport.mtls.getenv")
47-
def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check):
48-
custom_path = "/custom/path.json"
49-
mock_check.side_effect = lambda x: custom_path if x == custom_path else None
50-
mock_getenv.return_value = custom_path
51-
52-
assert mtls.has_default_client_cert_source() is True
53-
54-
@mock.patch("google.auth.aio.transport.mtls._check_config_path")
55-
@mock.patch("google.auth.aio.transport.mtls.getenv")
56-
def test_has_default_client_cert_source_check_priority(
57-
self, mock_getenv, mock_check
58-
):
59-
mock_check.return_value = "/default/path.json"
60-
61-
assert mtls.has_default_client_cert_source() is True
62-
mock_getenv.assert_not_called()
63-
27+
@pytest.mark.asyncio
6428
@mock.patch(
65-
"google.auth.aio.transport.mtls.has_default_client_cert_source",
66-
return_value=False,
29+
"google.auth.transport.mtls.has_default_client_cert_source", return_value=False
6730
)
68-
def test_default_client_cert_source_none(self, mock_has_default):
69-
with pytest.raises(exceptions.MutualTLSChannelError):
31+
async def test_default_client_cert_source_not_found(self, mock_has_default):
32+
"""Tests that a MutualTLSChannelError is raised if no cert source exists."""
33+
with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"):
7034
mtls.default_client_cert_source()
7135

7236
@pytest.mark.asyncio
@@ -75,45 +39,35 @@ def test_default_client_cert_source_none(self, mock_has_default):
7539
new_callable=mock.AsyncMock,
7640
)
7741
@mock.patch(
78-
"google.auth.aio.transport.mtls.has_default_client_cert_source",
79-
return_value=True,
42+
"google.auth.transport.mtls.has_default_client_cert_source", return_value=True
8043
)
8144
async def test_default_client_cert_source_success(
8245
self, mock_has_default, mock_get_cert_key
8346
):
47+
"""Tests the async callback returned by default_client_cert_source."""
8448
mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA)
8549

86-
# Note: default_client_cert_source is NOT async, but it returns an async callback
50+
# default_client_cert_source is a factory that returns an async callback
8751
callback = mtls.default_client_cert_source()
8852
assert callable(callback)
8953

9054
cert, key = await callback()
9155
assert cert == CERT_DATA
9256
assert key == KEY_DATA
9357

94-
@pytest.mark.asyncio
95-
@mock.patch(
96-
"google.auth.aio.transport.mtls.has_default_client_cert_source",
97-
return_value=False,
98-
)
99-
async def test_default_client_cert_source_not_found(self, mock_has_default):
100-
with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"):
101-
await mtls.default_client_cert_source()
102-
10358
@pytest.mark.asyncio
10459
@mock.patch(
10560
"google.auth.aio.transport.mtls.get_client_cert_and_key",
10661
new_callable=mock.AsyncMock,
10762
)
10863
@mock.patch(
109-
"google.auth.aio.transport.mtls.has_default_client_cert_source",
110-
return_value=True,
64+
"google.auth.transport.mtls.has_default_client_cert_source", return_value=True
11165
)
11266
async def test_default_client_cert_source_callback_wraps_exception(
11367
self, mock_has, mock_get
11468
):
69+
"""Tests that the callback wraps underlying errors into MutualTLSChannelError."""
11570
mock_get.side_effect = ValueError("Format error")
116-
11771
callback = mtls.default_client_cert_source()
11872

11973
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
@@ -123,6 +77,7 @@ async def test_default_client_cert_source_callback_wraps_exception(
12377
@pytest.mark.asyncio
12478
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
12579
async def test_get_client_ssl_credentials_success(self, mock_workload):
80+
"""Tests successful retrieval of workload credentials via the executor."""
12681
mock_workload.return_value = (CERT_DATA, KEY_DATA)
12782

12883
success, cert, key, passphrase = await mtls.get_client_ssl_credentials()
@@ -135,6 +90,7 @@ async def test_get_client_ssl_credentials_success(self, mock_workload):
13590
@pytest.mark.asyncio
13691
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
13792
async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl):
93+
"""Tests behavior when no credentials are found at the default location."""
13894
mock_get_ssl.return_value = (False, None, None, None)
13995

14096
success, cert, key = await mtls.get_client_cert_and_key(None)
@@ -145,7 +101,7 @@ async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl):
145101

146102
@pytest.mark.asyncio
147103
async def test_get_client_cert_and_key_callback_async(self):
148-
# Test with an actual coroutine/AsyncMock to satisfy the 'await' in your code
104+
"""Tests that an async callback is correctly awaited."""
149105
callback = mock.AsyncMock(return_value=(CERT_DATA, KEY_DATA))
150106

151107
success, cert, key = await mtls.get_client_cert_and_key(callback)
@@ -157,51 +113,24 @@ async def test_get_client_cert_and_key_callback_async(self):
157113

158114
@pytest.mark.asyncio
159115
async def test_get_client_cert_and_key_callback_sync(self):
160-
# Test the fallback logic: if it's a sync function, the TypeError is caught
116+
"""Tests that a sync callback is handled via the TypeError fallback."""
161117
callback = mock.Mock(return_value=(CERT_DATA, KEY_DATA))
162118

163119
success, cert, key = await mtls.get_client_cert_and_key(callback)
164120

165121
assert success is True
166122
assert cert == CERT_DATA
167-
# In your current implementation, this might still show 2 calls if the
168-
# first 'await' attempt triggers a call before failing.
169-
# To strictly avoid 2 calls, the implementation would need to check inspect.iscoroutinefunction.
170-
assert callback.call_count >= 1
171-
172-
@pytest.mark.asyncio
173-
@mock.patch(
174-
"google.auth.aio.transport.mtls.get_client_ssl_credentials",
175-
new_callable=mock.AsyncMock,
176-
)
177-
async def test_get_client_cert_and_key_default(self, mock_get_credentials):
178-
mock_get_credentials.return_value = (True, CERT_DATA, KEY_DATA, None)
179-
180-
success, cert, key = await mtls.get_client_cert_and_key(None)
181-
182-
assert success is True
183-
assert cert == CERT_DATA
184-
assert key == KEY_DATA
185-
mock_get_credentials.assert_called_once()
123+
# Note: In the source, the first 'await' will call the function.
124+
# When it fails to await, the exception handler uses the result already obtained.
125+
assert callback.call_count == 1
186126

187127
@pytest.mark.asyncio
188128
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
189129
async def test_get_client_ssl_credentials_error(self, mock_workload):
130+
"""Tests exception propagation from the workload helper."""
190131
mock_workload.side_effect = exceptions.ClientCertError(
191132
"Failed to read metadata"
192133
)
193134

194135
with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"):
195136
await mtls.get_client_ssl_credentials()
196-
197-
@pytest.mark.asyncio
198-
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
199-
async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl):
200-
mock_get_ssl.side_effect = exceptions.ClientCertError(
201-
"Underlying credentials failed"
202-
)
203-
204-
with pytest.raises(
205-
exceptions.ClientCertError, match="Underlying credentials failed"
206-
):
207-
await mtls.get_client_cert_and_key(client_cert_callback=None)

0 commit comments

Comments
 (0)