Skip to content

Commit 750648f

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: support passing the custom aiohttp.ClientSession through HttpOptions.aiohttp_client
fixes #1662 PiperOrigin-RevId: 856374597
1 parent b7b1c2e commit 750648f

4 files changed

Lines changed: 58 additions & 9 deletions

File tree

google/genai/_api_client.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -715,19 +715,23 @@ def __init__(
715715
self._async_httpx_client = self._http_options.httpx_async_client
716716
else:
717717
self._async_httpx_client = AsyncHttpxClient(**async_client_args)
718+
719+
# Initialize the aiohttp client session.
720+
self._aiohttp_session: Optional[aiohttp.ClientSession] = None
718721
if self._use_aiohttp():
719722
try:
720723
import aiohttp # pylint: disable=g-import-not-at-top
721-
# Do it once at the genai.Client level. Share among all requests.
722-
self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
723-
self._http_options
724-
)
724+
725+
if self._http_options.aiohttp_client:
726+
self._aiohttp_session = self._http_options.aiohttp_client
727+
else:
728+
# Do it once at the genai.Client level. Share among all requests.
729+
self._async_client_session_request_args = (
730+
self._ensure_aiohttp_ssl_ctx(self._http_options)
731+
)
725732
except ImportError:
726733
pass
727734

728-
# Initialize the aiohttp client session.
729-
self._aiohttp_session: Optional[aiohttp.ClientSession] = None
730-
731735
retry_kwargs = retry_args(self._http_options.retry_options)
732736
self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
733737
self._retry = tenacity.Retrying(**retry_kwargs)
@@ -1892,7 +1896,7 @@ async def aclose(self) -> None:
18921896
# close the client when the object is garbage collected.
18931897
if not self._http_options.httpx_async_client:
18941898
await self._async_httpx_client.aclose()
1895-
if self._aiohttp_session:
1899+
if self._aiohttp_session and not self._http_options.aiohttp_client:
18961900
await self._aiohttp_session.close()
18971901

18981902
def __del__(self) -> None:

google/genai/tests/client/test_custom_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,30 @@ def test_constructor_with_httpx_clients():
7575
assert not mldev_client.models._api_client._use_aiohttp()
7676

7777

78+
# Aiohttp
79+
@requires_aiohttp
80+
@pytest.mark.asyncio
81+
@pytest.mark.skipif(
82+
AIOHTTP_NOT_INSTALLED, reason='aiohttp is not installed, skipping test.'
83+
)
84+
async def test_constructor_with_aiohttp_clients():
85+
api_client.has_aiohttp = True
86+
mldev_http_options = {
87+
'aiohttp_client': aiohttp.ClientSession(trust_env=False),
88+
}
89+
vertexai_http_options = {
90+
'aiohttp_client': aiohttp.ClientSession(trust_env=False),
91+
}
92+
mldev_client = Client(
93+
api_key='google_api_key', http_options=mldev_http_options
94+
)
95+
assert not mldev_client.models._api_client._aiohttp_session.trust_env
96+
97+
vertexai_client = Client(
98+
vertexai=True,
99+
project='fake_project_id',
100+
location='fake-location',
101+
http_options=vertexai_http_options,
102+
)
103+
assert not vertexai_client.models._api_client._aiohttp_session.trust_env
104+

google/genai/tests/client/test_http_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_patch_http_options_with_copies_all_fields():
4040

4141
for key in http_options_keys:
4242
assert hasattr(patched, key)
43-
if key not in ['httpx_client', 'httpx_async_client', 'aiohttp_client_session']:
43+
if key not in ['httpx_client', 'httpx_async_client', 'aiohttp_client']:
4444
assert getattr(patched, key) is not None
4545
assert patched.base_url == 'https://fake-url.com/'
4646
assert patched.api_version == 'v1'

google/genai/types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@
108108
HttpxClient = None
109109
HttpxAsyncClient = None
110110

111+
_is_aiohttp_imported = False
112+
if typing.TYPE_CHECKING:
113+
from aiohttp import ClientSession
114+
115+
_is_aiohttp_imported = True
116+
else:
117+
ClientSession: typing.Type = Any
118+
try:
119+
from aiohttp import ClientSession
120+
121+
_is_aiohttp_imported = True
122+
except ImportError:
123+
ClientSession = None
124+
111125
logger = logging.getLogger('google_genai.types')
112126
_from_json_schema_warning_logged = False
113127
_json_schema_warning_logged = False
@@ -1936,6 +1950,10 @@ class HttpOptions(_common.BaseModel):
19361950
default=None,
19371951
description="""A custom httpx async client to be used for the request.""",
19381952
)
1953+
aiohttp_client: Optional['ClientSession'] = Field(
1954+
default=None,
1955+
description="""A custom aiohttp client session to be used for the request.""",
1956+
)
19391957

19401958

19411959
class HttpOptionsDict(TypedDict, total=False):

0 commit comments

Comments
 (0)