Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 9693326

Browse files
committed
support timeout as aiohttpClientTimeout and total_attempts (max retries)
1 parent 9d5c0d8 commit 9693326

2 files changed

Lines changed: 38 additions & 12 deletions

File tree

google/auth/aio/transport/aiohttp.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
import asyncio
1919
import logging
20-
from typing import AsyncGenerator, Mapping, Optional
20+
import typing
21+
from typing import Any, AsyncGenerator, Mapping, Optional, Union
2122

2223
try:
2324
import aiohttp # type: ignore
@@ -26,6 +27,15 @@
2627
"The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport."
2728
) from caught_exc
2829

30+
if typing.TYPE_CHECKING:
31+
from aiohttp import ClientTimeout
32+
else:
33+
ClientTimeout: typing.Type = Any
34+
try:
35+
from aiohttp import ClientTimeout
36+
except ImportError:
37+
ClientTimeout = None
38+
2939
from google.auth import _helpers
3040
from google.auth import exceptions
3141
from google.auth.aio import _helpers as _helpers_async
@@ -123,7 +133,7 @@ async def __call__(
123133
method: str = "GET",
124134
body: Optional[bytes] = None,
125135
headers: Optional[Mapping[str, str]] = None,
126-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
136+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
127137
**kwargs,
128138
) -> transport.Response:
129139
"""
@@ -158,7 +168,10 @@ async def __call__(
158168
if not self._session:
159169
self._session = aiohttp.ClientSession()
160170

161-
client_timeout = aiohttp.ClientTimeout(total=timeout)
171+
if isinstance(timeout, aiohttp.ClientTimeout):
172+
client_timeout = timeout
173+
else:
174+
client_timeout = aiohttp.ClientTimeout(total=timeout)
162175
_helpers.request_log(_LOGGER, method, url, body, headers)
163176
response = await self._session.request(
164177
method,

google/auth/aio/transport/sessions.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from contextlib import asynccontextmanager
1717
import functools
1818
import time
19-
from typing import Mapping, Optional
19+
import typing
20+
from typing import Any, Mapping, Optional, Union
2021

2122
from google.auth import _exponential_backoff, exceptions
2223
from google.auth.aio import transport
@@ -30,7 +31,16 @@
3031
except ImportError: # pragma: NO COVER
3132
AIOHTTP_INSTALLED = False
3233

33-
34+
if typing.TYPE_CHECKING:
35+
from aiohttp import ClientTimeout
36+
else:
37+
ClientTimeout: typing.Type = Any
38+
try:
39+
from aiohttp import ClientTimeout
40+
except ImportError:
41+
ClientTimeout = None
42+
43+
3444
@asynccontextmanager
3545
async def timeout_guard(timeout):
3646
"""
@@ -137,7 +147,8 @@ async def request(
137147
data: Optional[bytes] = None,
138148
headers: Optional[Mapping[str, str]] = None,
139149
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
140-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
150+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
151+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
141152
**kwargs,
142153
) -> transport.Response:
143154
"""
@@ -146,14 +157,16 @@ async def request(
146157
url (str): The URI to be requested.
147158
data (Optional[bytes]): The payload or body in HTTP request.
148159
headers (Optional[Mapping[str, str]]): Request headers.
149-
timeout (float):
160+
timeout (float, aiohttp.ClientTimeout):
150161
The amount of time in seconds to wait for the server response
151162
with each individual request.
152163
max_allowed_time (float):
153164
If the method runs longer than this, a ``Timeout`` exception is
154165
automatically raised. Unlike the ``timeout`` parameter, this
155166
value applies to the total method execution time, even if
156167
multiple requests are made under the hood.
168+
total_attempts (int):
169+
The total number of retry attempts.
157170
158171
Mind that it is not guaranteed that the timeout error is raised
159172
at ``max_allowed_time``. It might take longer, for example, if
@@ -172,7 +185,7 @@ async def request(
172185
"""
173186

174187
retries = _exponential_backoff.AsyncExponentialBackoff(
175-
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
188+
total_attempts=total_attempts,
176189
)
177190
async with timeout_guard(max_allowed_time) as with_timeout:
178191
await with_timeout(
@@ -198,7 +211,7 @@ async def get(
198211
data: Optional[bytes] = None,
199212
headers: Optional[Mapping[str, str]] = None,
200213
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
201-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
214+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
202215
**kwargs,
203216
) -> transport.Response:
204217
return await self.request(
@@ -212,7 +225,7 @@ async def post(
212225
data: Optional[bytes] = None,
213226
headers: Optional[Mapping[str, str]] = None,
214227
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
215-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
228+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
216229
**kwargs,
217230
) -> transport.Response:
218231
return await self.request(
@@ -226,7 +239,7 @@ async def put(
226239
data: Optional[bytes] = None,
227240
headers: Optional[Mapping[str, str]] = None,
228241
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
229-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
242+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
230243
**kwargs,
231244
) -> transport.Response:
232245
return await self.request(
@@ -240,7 +253,7 @@ async def patch(
240253
data: Optional[bytes] = None,
241254
headers: Optional[Mapping[str, str]] = None,
242255
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
243-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
256+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
244257
**kwargs,
245258
) -> transport.Response:
246259
return await self.request(

0 commit comments

Comments
 (0)