Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 29b9625

Browse files
chore: fix the sequence flow and add docstrings and comment where necessary based on reviewer's comments
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent 6f34fc2 commit 29b9625

File tree

1 file changed

+54
-39
lines changed

1 file changed

+54
-39
lines changed

google/auth/aio/transport/sessions.py

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
except (ImportError, AttributeError):
3737
ClientTimeout = None
3838

39+
# Tracks the internal aiohttp installation and usage
3940
try:
4041
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
4142

@@ -138,6 +139,7 @@ def __init__(
138139
if not _auth_request and AIOHTTP_INSTALLED:
139140
_auth_request = AiohttpRequest()
140141
self._is_mtls = False
142+
self._mtls_init_task = None
141143
self._cached_cert = None
142144
if _auth_request is None:
143145
raise exceptions.TransportError(
@@ -154,7 +156,10 @@ async def configure_mtls_channel(self, client_cert_callback=None):
154156
default SSL credentials), the underlying transport will be reconfigured
155157
to use mTLS.
156158
Note: This function does nothing if the `aiohttp` library is not
157-
installed. Plus, will close any ongoing API requests.
159+
installed.
160+
Important: Calling this method will close any ongoing API requests associated
161+
with the current session. To ensure a smooth transition, it is recommended
162+
to call this during session initialization.
158163
159164
Args:
160165
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
@@ -167,42 +172,53 @@ async def configure_mtls_channel(self, client_cert_callback=None):
167172
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
168173
creation failed for any reason.
169174
"""
170-
# Run the blocking check in an executor
171-
use_client_cert = await mtls._run_in_executor(
172-
google.auth.transport._mtls_helper.check_use_client_cert
173-
)
174-
if not use_client_cert:
175-
self._is_mtls = False
176-
return
175+
if self._mtls_init_task is None:
177176

178-
try:
179-
(
180-
self._is_mtls,
181-
cert,
182-
key,
183-
) = await mtls.get_client_cert_and_key(client_cert_callback)
184-
185-
if self._is_mtls:
186-
self._cached_cert = cert
187-
ssl_context = await mtls._run_in_executor(
188-
mtls.make_client_cert_ssl_context, cert, key
177+
async def _do_configure():
178+
# Run the blocking check in an executor
179+
use_client_cert = await mtls._run_in_executor(
180+
google.auth.transport._mtls_helper.check_use_client_cert
189181
)
190-
191-
# Re-create the auth request with the new SSL context
192-
if AIOHTTP_INSTALLED and isinstance(self._auth_request, AiohttpRequest):
193-
connector = aiohttp.TCPConnector(ssl=ssl_context)
194-
new_session = aiohttp.ClientSession(connector=connector)
195-
old_auth_request = self._auth_request
196-
self._auth_request = AiohttpRequest(session=new_session)
197-
await old_auth_request.close()
198-
199-
except (
200-
exceptions.ClientCertError,
201-
ImportError,
202-
OSError,
203-
) as caught_exc:
204-
new_exc = exceptions.MutualTLSChannelError(caught_exc)
205-
raise new_exc from caught_exc
182+
if not use_client_cert:
183+
self._is_mtls = False
184+
return
185+
186+
try:
187+
(
188+
self._is_mtls,
189+
cert,
190+
key,
191+
) = await mtls.get_client_cert_and_key(client_cert_callback)
192+
193+
if self._is_mtls:
194+
self._cached_cert = cert
195+
ssl_context = await mtls._run_in_executor(
196+
mtls.make_client_cert_ssl_context, cert, key
197+
)
198+
199+
# Re-create the auth request with the new SSL context
200+
if AIOHTTP_INSTALLED and isinstance(
201+
self._auth_request, AiohttpRequest
202+
):
203+
connector = aiohttp.TCPConnector(ssl=ssl_context)
204+
new_session = aiohttp.ClientSession(connector=connector)
205+
206+
old_auth_request = self._auth_request
207+
self._auth_request = AiohttpRequest(session=new_session)
208+
209+
await old_auth_request.close()
210+
211+
except (
212+
exceptions.ClientCertError,
213+
ImportError,
214+
OSError,
215+
) as caught_exc:
216+
new_exc = exceptions.MutualTLSChannelError(caught_exc)
217+
raise new_exc from caught_exc
218+
219+
self._mtls_init_task = asyncio.create_task(_do_configure())
220+
221+
return await self._mtls_init_task
206222

207223
async def request(
208224
self,
@@ -247,7 +263,8 @@ async def request(
247263
the configured `max_allowed_time` or the request exceeds the configured
248264
`timeout`.
249265
"""
250-
266+
if self._mtls_init_task:
267+
await self._mtls_init_task
251268
retries = _exponential_backoff.AsyncExponentialBackoff(
252269
total_attempts=total_attempts,
253270
)
@@ -261,12 +278,10 @@ async def request(
261278
)
262279
)
263280
actual_timeout: float = 0.0
264-
if AIOHTTP_INSTALLED and isinstance(timeout, aiohttp.ClientTimeout):
281+
if isinstance(timeout, aiohttp.ClientTimeout):
265282
actual_timeout = timeout.total if timeout.total is not None else 0.0
266283
elif isinstance(timeout, (int, float)):
267284
actual_timeout = float(timeout)
268-
else:
269-
actual_timeout = 0.0
270285
# Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
271286
# See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
272287
async for _ in retries: # pragma: no branch

0 commit comments

Comments
 (0)