Skip to content

Commit 4e17a9c

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: avoid caching stale token in async mTLS path
PiperOrigin-RevId: 908074886
1 parent ef26583 commit 4e17a9c

2 files changed

Lines changed: 86 additions & 3 deletions

File tree

google/genai/_api_client.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,7 @@ def _use_google_auth_async(self) -> bool:
865865
return bool(
866866
has_aiohttp
867867
and self.vertexai
868+
and hasattr(mtls, 'should_use_client_cert')
868869
and mtls.should_use_client_cert() # type: ignore[no-untyped-call]
869870
and mtls.has_default_client_cert_source() # type: ignore[no-untyped-call]
870871
and not self._http_options.httpx_async_client
@@ -877,11 +878,36 @@ async def _get_aiohttp_session(
877878

878879
if self._aiohttp_session is None and self._use_google_auth_async():
879880
try:
880-
from google.auth.aio.credentials import StaticCredentials
881+
from google.auth.aio.credentials import Credentials as AsyncCredentials
881882
from google.auth.aio.transport.sessions import AsyncAuthorizedSession
882883

883-
async_creds = StaticCredentials(token=self._access_token()) # type: ignore[no-untyped-call]
884-
self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call,assignment]
884+
class _RefreshableAsyncCredentials(AsyncCredentials): # type: ignore[misc, valid-type]
885+
"""Adapter to use the client's sync credentials in an AsyncAuthorizedSession."""
886+
887+
def __init__(self, client: 'BaseApiClient'):
888+
super().__init__() # type: ignore[no-untyped-call]
889+
self._client = client
890+
891+
async def before_request(
892+
self, request: Any, method: str, url: str, headers: dict[str, str]
893+
) -> None:
894+
token = await self._client._async_access_token()
895+
headers['Authorization'] = f'Bearer {token}'
896+
if (
897+
self._client._credentials
898+
and self._client._credentials.quota_project_id
899+
):
900+
headers['x-goog-user-project'] = (
901+
self._client._credentials.quota_project_id
902+
)
903+
904+
@property
905+
def valid(self) -> bool:
906+
if not self._client._credentials:
907+
return False
908+
return not self._client._credentials.expired
909+
910+
self._aiohttp_session = AsyncAuthorizedSession(_RefreshableAsyncCredentials(self)) # type: ignore[no-untyped-call,assignment]
885911
return self._aiohttp_session # type: ignore[return-value]
886912
except ImportError:
887913
pass

google/genai/tests/client/test_client_initialization.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import os
2323
import ssl
24+
import sys
2425
from unittest import mock
2526

2627
import certifi
@@ -1870,3 +1871,59 @@ async def test_get_aiohttp_session():
18701871
assert initial_session is not None
18711872
session = await client._api_client._get_aiohttp_session()
18721873
assert session is initial_session
1874+
1875+
1876+
@requires_aiohttp
1877+
@pytest.mark.asyncio
1878+
async def test_async_mtls_uses_refreshable_credentials(monkeypatch):
1879+
"""Tests that _RefreshableAsyncCredentials is used in async mTLS path."""
1880+
from google.genai import _api_client
1881+
1882+
# Ensure _use_google_auth_async returns True
1883+
monkeypatch.setattr(_api_client, "has_aiohttp", True)
1884+
monkeypatch.setattr(_api_client.mtls, "should_use_client_cert", lambda: True, raising=False)
1885+
monkeypatch.setattr(
1886+
_api_client.mtls, "has_default_client_cert_source", lambda: True
1887+
)
1888+
1889+
# Mock AsyncAuthorizedSession and google.auth.aio modules
1890+
mock_session = mock.MagicMock()
1891+
mock_auth_aio = mock.MagicMock()
1892+
monkeypatch.setitem(sys.modules, "google.auth.aio", mock_auth_aio)
1893+
monkeypatch.setitem(
1894+
sys.modules, "google.auth.aio.credentials", mock_auth_aio.credentials
1895+
)
1896+
monkeypatch.setitem(
1897+
sys.modules, "google.auth.aio.transport", mock_auth_aio.transport
1898+
)
1899+
monkeypatch.setitem(
1900+
sys.modules,
1901+
"google.auth.aio.transport.sessions",
1902+
mock_auth_aio.transport.sessions,
1903+
)
1904+
mock_auth_aio.transport.sessions.AsyncAuthorizedSession = mock_session
1905+
mock_auth_aio.credentials.Credentials = mock.MagicMock
1906+
1907+
# Mock credentials
1908+
mock_creds = mock.MagicMock()
1909+
mock_creds.expired = False
1910+
mock_creds.token = "initial_token"
1911+
monkeypatch.setattr(
1912+
google.auth, "default", lambda scopes=None: (mock_creds, "fake-project")
1913+
)
1914+
1915+
client = Client(vertexai=True, project="fake-project")
1916+
client._api_client._credentials = mock_creds
1917+
1918+
# Trigger session creation
1919+
await client._api_client._get_aiohttp_session()
1920+
1921+
# Verify AsyncAuthorizedSession was called with _RefreshableAsyncCredentials
1922+
assert mock_session.call_count == 1
1923+
passed_creds = mock_session.call_args[0][0]
1924+
assert type(passed_creds).__name__ == "_RefreshableAsyncCredentials"
1925+
1926+
# Verify valid property
1927+
assert passed_creds.valid == True
1928+
mock_creds.expired = True
1929+
assert passed_creds.valid == False

0 commit comments

Comments
 (0)