Skip to content

Commit db1d7ee

Browse files
MarkDaoustcopybara-github
authored andcommitted
feat: Support ephemeral auth tokens as API keys for live connections in Python.
PiperOrigin-RevId: 765372194
1 parent 6719faf commit db1d7ee

File tree

4 files changed

+60
-104
lines changed

4 files changed

+60
-104
lines changed

google/genai/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
from . import version
1919
from .client import Client
20-
from .live import live_ephemeral_connect
2120

2221

2322
__version__ = version.__version__
2423

25-
__all__ = ['Client', 'live_ephemeral_connect']
24+
__all__ = ['Client']

google/genai/_api_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@
6363
INITIAL_RETRY_DELAY = 1 # second
6464
DELAY_MULTIPLIER = 2
6565

66+
67+
class EphemeralTokenAPIKeyError(ValueError):
68+
"""Error raised when the API key is invalid."""
69+
70+
6671
def _append_library_version_headers(headers: dict[str, str]) -> None:
6772
"""Appends the telemetry header to the headers dict."""
6873
library_label = f'google-genai-sdk/{version.__version__}'
@@ -625,6 +630,11 @@ def _build_request(
625630
versioned_path,
626631
)
627632

633+
if self.api_key and self.api_key.startswith('auth_tokens/'):
634+
raise EphemeralTokenAPIKeyError(
635+
'Ephemeral tokens can only be used with the live API.'
636+
)
637+
628638
timeout_in_seconds = _get_timeout_in_seconds(patched_http_options.timeout)
629639

630640
if patched_http_options.headers is None:

google/genai/live.py

Lines changed: 18 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from . import _mcp_utils
3535
from . import _transformers as t
3636
from . import client
37+
from . import errors
3738
from . import types
3839
from ._api_client import BaseApiClient
3940
from ._common import get_value_by_path as getv
@@ -78,10 +79,6 @@
7879
' response of a ToolCall.FunctionalCalls in Google AI.'
7980
)
8081

81-
82-
_DUMMY_KEY = 'dummy_key'
83-
84-
8582
class AsyncSession:
8683
"""[Preview] AsyncSession."""
8784

@@ -912,25 +909,10 @@ async def connect(
912909
Yields:
913910
An AsyncSession object.
914911
"""
915-
async with self._connect(
916-
model=model,
917-
config=config,
918-
) as session:
919-
yield session
920-
921-
@contextlib.asynccontextmanager
922-
async def _connect(
923-
self,
924-
*,
925-
model: Optional[str] = None,
926-
config: Optional[types.LiveConnectConfigOrDict] = None,
927-
uri: Optional[str] = None,
928-
) -> AsyncIterator[AsyncSession]:
929-
930912
# TODO(b/404946570): Support per request http options.
931913
if isinstance(config, dict):
932914
config = types.LiveConnectConfig(**config)
933-
if config and config.http_options and uri is None:
915+
if config and config.http_options:
934916
raise ValueError(
935917
'google.genai.client.aio.live.connect() does not support'
936918
' http_options at request-level in LiveConnectConfig yet. Please use'
@@ -945,10 +927,22 @@ async def _connect(
945927
parameter_model = await _t_live_connect_config(self._api_client, config)
946928

947929
if self._api_client.api_key and not self._api_client.vertexai:
948-
api_key = self._api_client.api_key
949930
version = self._api_client._http_options.api_version
950-
if uri is None:
951-
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
931+
api_key = self._api_client.api_key
932+
method = 'BidiGenerateContent'
933+
key_name = 'key'
934+
if api_key.startswith('auth_tokens/'):
935+
warnings.warn(
936+
message=(
937+
"The SDK's ephemeral token support is experimental, and may"
938+
' change in future versions.'
939+
),
940+
category=errors.ExperimentalWarning,
941+
)
942+
method = 'BidiGenerateContentConstrained'
943+
key_name = 'access_token'
944+
945+
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}?{key_name}={api_key}'
952946
headers = self._api_client._http_options.headers
953947

954948
request_dict = _common.convert_to_dict(
@@ -969,8 +963,7 @@ async def _connect(
969963
# Headers already contains api key for express mode.
970964
api_key = self._api_client.api_key
971965
version = self._api_client._http_options.api_version
972-
if uri is None:
973-
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
966+
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
974967
headers = self._api_client._http_options.headers
975968

976969
request_dict = _common.convert_to_dict(
@@ -1060,83 +1053,6 @@ async def _connect(
10601053
yield AsyncSession(api_client=self._api_client, websocket=ws)
10611054

10621055

1063-
@_common.experimental_warning(
1064-
"The SDK's Live API connection with ephemeral token implementation is"
1065-
' experimental, and may change in future versions.',
1066-
)
1067-
@contextlib.asynccontextmanager
1068-
async def live_ephemeral_connect(
1069-
access_token: str,
1070-
model: Optional[str] = None,
1071-
config: Optional[types.LiveConnectConfigOrDict] = None,
1072-
) -> AsyncIterator[AsyncSession]:
1073-
"""[Experimental] Connect to the live server using ephermeral token (Gemini Developer API only).
1074-
1075-
Note: the live API is currently in experimental.
1076-
1077-
Usage:
1078-
1079-
.. code-block:: python
1080-
from google import genai
1081-
1082-
config = {}
1083-
async with genai.live_ephemeral_connect(
1084-
access_token='auth_tokens/12345',
1085-
model='...',
1086-
config=config,
1087-
http_options=types.HttpOptions(api_version='v1beta'),
1088-
) as session:
1089-
await session.send_client_content(
1090-
turns=types.Content(
1091-
role='user',
1092-
parts=[types.Part(text='hello!')]
1093-
),
1094-
turn_complete=True
1095-
)
1096-
1097-
async for message in session.receive():
1098-
print(message)
1099-
1100-
Args:
1101-
access_token: The access token to use for the Live session. It can be
1102-
generated by the `client.tokens.create` method.
1103-
model: The model to use for the Live session.
1104-
config: The configuration for the Live session.
1105-
1106-
Yields:
1107-
An AsyncSession object.
1108-
"""
1109-
if isinstance(config, dict):
1110-
config = types.LiveConnectConfig(**config)
1111-
1112-
http_options = config.http_options if config else None
1113-
1114-
base_url = (
1115-
http_options.base_url
1116-
if http_options and http_options.base_url
1117-
else 'https://generativelanguage.googleapis.com/'
1118-
)
1119-
api_version = (
1120-
http_options.api_version
1121-
if http_options and http_options.api_version
1122-
else 'v1beta'
1123-
)
1124-
internal_client = client.Client(
1125-
api_key=_DUMMY_KEY, # Can't be None during initialization
1126-
http_options=types.HttpOptions(
1127-
base_url=base_url,
1128-
api_version=api_version,
1129-
),
1130-
)
1131-
websocket_base_url = internal_client._api_client._websocket_base_url()
1132-
uri = f'{websocket_base_url}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.BidiGenerateContentConstrained?access_token={access_token}'
1133-
1134-
async with internal_client.aio.live._connect(
1135-
model=model, config=config, uri=uri
1136-
) as session:
1137-
yield session
1138-
1139-
11401056
async def _t_live_connect_config(
11411057
api_client: BaseApiClient,
11421058
config: Optional[types.LiveConnectConfigOrDict],

google/genai/tests/live/test_live.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,3 +1867,34 @@ async def _test_connect():
18671867
)
18681868

18691869
await _test_connect()
1870+
1871+
1872+
@pytest.mark.parametrize('vertexai', [False])
1873+
@pytest.mark.asyncio
1874+
async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
1875+
api_client_mock = mock_api_client(vertexai=vertexai)
1876+
api_client_mock.api_key = 'auth_tokens/TEST_AUTH_TOKEN'
1877+
result = await get_connect_message(
1878+
api_client_mock,
1879+
model='test_model'
1880+
)
1881+
1882+
mock_ws = AsyncMock()
1883+
mock_ws.send = AsyncMock()
1884+
mock_ws.recv = AsyncMock(return_value=b'some response')
1885+
uri_capture = {} # Capture the uri here
1886+
1887+
@contextlib.asynccontextmanager
1888+
async def mock_connect(uri, additional_headers=None):
1889+
uri_capture['uri'] = uri # Capture the uri
1890+
yield mock_ws
1891+
1892+
with patch.object(live, 'ws_connect', new=mock_connect):
1893+
live_module = live.AsyncLive(api_client_mock)
1894+
async with live_module.connect(
1895+
model='test_model',
1896+
):
1897+
pass
1898+
1899+
assert 'access_token=auth_tokens/TEST_AUTH_TOKEN' in uri_capture['uri']
1900+
assert 'BidiGenerateContentConstrained' in uri_capture['uri']

0 commit comments

Comments
 (0)