|
21 | 21 | import logging |
22 | 22 | import os |
23 | 23 | import ssl |
| 24 | +import sys |
24 | 25 | from unittest import mock |
25 | 26 |
|
26 | 27 | import certifi |
@@ -1870,3 +1871,59 @@ async def test_get_aiohttp_session(): |
1870 | 1871 | assert initial_session is not None |
1871 | 1872 | session = await client._api_client._get_aiohttp_session() |
1872 | 1873 | assert session is initial_session |
| 1874 | + |
| 1875 | + |
| 1876 | +@requires_aiohttp |
| 1877 | +@pytest.mark.asyncio |
| 1878 | +async def test_async_mtls_uses_refreshable_credentials(monkeypatch): |
| 1879 | + """Tests that _RefreshableAsyncCredentials is used in async mTLS path.""" |
| 1880 | + from google.genai import _api_client |
| 1881 | + |
| 1882 | + # Ensure _use_google_auth_async returns True |
| 1883 | + monkeypatch.setattr(_api_client, "has_aiohttp", True) |
| 1884 | + monkeypatch.setattr(_api_client.mtls, "should_use_client_cert", lambda: True, raising=False) |
| 1885 | + monkeypatch.setattr( |
| 1886 | + _api_client.mtls, "has_default_client_cert_source", lambda: True |
| 1887 | + ) |
| 1888 | + |
| 1889 | + # Mock AsyncAuthorizedSession and google.auth.aio modules |
| 1890 | + mock_session = mock.MagicMock() |
| 1891 | + mock_auth_aio = mock.MagicMock() |
| 1892 | + monkeypatch.setitem(sys.modules, "google.auth.aio", mock_auth_aio) |
| 1893 | + monkeypatch.setitem( |
| 1894 | + sys.modules, "google.auth.aio.credentials", mock_auth_aio.credentials |
| 1895 | + ) |
| 1896 | + monkeypatch.setitem( |
| 1897 | + sys.modules, "google.auth.aio.transport", mock_auth_aio.transport |
| 1898 | + ) |
| 1899 | + monkeypatch.setitem( |
| 1900 | + sys.modules, |
| 1901 | + "google.auth.aio.transport.sessions", |
| 1902 | + mock_auth_aio.transport.sessions, |
| 1903 | + ) |
| 1904 | + mock_auth_aio.transport.sessions.AsyncAuthorizedSession = mock_session |
| 1905 | + mock_auth_aio.credentials.Credentials = mock.MagicMock |
| 1906 | + |
| 1907 | + # Mock credentials |
| 1908 | + mock_creds = mock.MagicMock() |
| 1909 | + mock_creds.expired = False |
| 1910 | + mock_creds.token = "initial_token" |
| 1911 | + monkeypatch.setattr( |
| 1912 | + google.auth, "default", lambda scopes=None: (mock_creds, "fake-project") |
| 1913 | + ) |
| 1914 | + |
| 1915 | + client = Client(vertexai=True, project="fake-project") |
| 1916 | + client._api_client._credentials = mock_creds |
| 1917 | + |
| 1918 | + # Trigger session creation |
| 1919 | + await client._api_client._get_aiohttp_session() |
| 1920 | + |
| 1921 | + # Verify AsyncAuthorizedSession was called with _RefreshableAsyncCredentials |
| 1922 | + assert mock_session.call_count == 1 |
| 1923 | + passed_creds = mock_session.call_args[0][0] |
| 1924 | + assert type(passed_creds).__name__ == "_RefreshableAsyncCredentials" |
| 1925 | + |
| 1926 | + # Verify valid property |
| 1927 | + assert passed_creds.valid == True |
| 1928 | + mock_creds.expired = True |
| 1929 | + assert passed_creds.valid == False |
0 commit comments