Skip to content

Commit 9523aac

Browse files
wanlin31copybara-github
authored andcommitted
chore: add some internal helpers
PiperOrigin-RevId: 912741979
1 parent 00d76d1 commit 9523aac

4 files changed

Lines changed: 387 additions & 6 deletions

File tree

google/genai/_interactions/_base_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
APIResponseValidationError,
103103
)
104104
from ._utils._json import openapi_dumps
105+
from ._legacy_lyria import LEGACY_LYRIA_SHIM_CTX, maybe_remap_legacy_sse_event
105106

106107
log: logging.Logger = logging.getLogger(__name__)
107108

@@ -655,6 +656,17 @@ def _process_response_data(
655656
if cast_to is object:
656657
return cast(ResponseT, data)
657658

659+
# When the legacy-lyria shim is active for this request (set by the
660+
# `LegacyLyriaInteractionStream` subclass after dynamic detection of
661+
# a legacy event), rename legacy SSE event types and reshape
662+
# `content.start` payloads so the discriminated-union dispatch in
663+
# `construct_type` lands on the modern variants.
664+
# `Interaction._maybe_coerce_outputs` does its own data inspection
665+
# (model field) and doesn't depend on this contextvar for the
666+
# non-streaming paths.
667+
if LEGACY_LYRIA_SHIM_CTX.get() and isinstance(data, dict):
668+
data = maybe_remap_legacy_sse_event(cast("dict[str, object]", data))
669+
658670
try:
659671
if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
660672
return cast(ResponseT, cast_to.build(response=response, data=data))
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
"""Compatibility shim for the legacy vertex+lyria response/event shape.
17+
18+
The vertex `aiplatform.googleapis.com` endpoint returns a different schema for
19+
`lyria-3-pro-preview` than the public `generativelanguage.googleapis.com` API:
20+
- non-streaming responses use `outputs: List[Content]` instead of the modern
21+
`steps: List[Step]`,
22+
- streaming SSE events use `interaction.start`, `content.start/delta/stop`, and
23+
`interaction.complete` instead of the modern `interaction.created`,
24+
`step.start/delta/stop`, and `interaction.completed`.
25+
26+
Two cooperating mechanisms cover the surface:
27+
28+
1. **Data inspection — non-streaming.** `Interaction._maybe_coerce_outputs`
29+
checks whether the response body's `model` field is in `LEGACY_LYRIA_MODELS`
30+
and rewrites `outputs` to `steps` accordingly. The model field is present on
31+
every Interaction body produced by `create()`, `get()`, and any deferred
32+
parse via `with_raw_response.parse()`, including the nested `interaction`
33+
body inside `interaction.created` / `interaction.completed` SSE events.
34+
This helper does not consult any contextvar; data is the only signal.
35+
36+
2. **Stream subclass + contextvar — streaming SSE event renames.** Per-event
37+
`event_type` renames have to happen *before* the discriminated-union
38+
dispatch runs and most events don't carry a model field, so we use a
39+
per-iteration contextvar (`LEGACY_LYRIA_SHIM_CTX`) instead of data
40+
inspection. `_base_client._process_response_data` reads it to gate the
41+
rename helper. Two stream subclasses set the contextvar:
42+
43+
- `LegacyLyriaInteractionStream` / `LegacyLyriaInteractionAsyncStream`:
44+
activate the contextvar unconditionally on entry. Used by `create()`'s
45+
streaming path, where `is_legacy_lyria_request` lets the resource layer
46+
pre-detect the legacy case at request time.
47+
48+
- `LegacyLyriaInteractionDetectingStream` / `LegacyLyriaInteractionDetectingAsyncStream`:
49+
activate the contextvar lazily, only on observing the first legacy
50+
`event_type`. Used by `get()`'s streaming path, where the model is
51+
unknown until the first event arrives.
52+
53+
Both pairs reset the contextvar in `finally:` so activation is scoped to
54+
one iteration.
55+
"""
56+
57+
from __future__ import annotations
58+
59+
from typing import TYPE_CHECKING, Any, Dict, TypeVar, cast
60+
from contextvars import ContextVar
61+
from typing_extensions import override
62+
63+
from ._streaming import Stream, AsyncStream
64+
65+
if TYPE_CHECKING:
66+
from collections.abc import Iterator, AsyncIterator
67+
68+
__all__ = [
69+
"LEGACY_LYRIA_SHIM_CTX",
70+
"LEGACY_LYRIA_MODELS",
71+
"is_legacy_lyria_request",
72+
"is_legacy_lyria_response_body",
73+
"maybe_remap_legacy_sse_event",
74+
"LegacyLyriaInteractionStream",
75+
"LegacyLyriaInteractionAsyncStream",
76+
"LegacyLyriaInteractionDetectingStream",
77+
"LegacyLyriaInteractionDetectingAsyncStream",
78+
]
79+
80+
_T = TypeVar("_T")
81+
82+
# Set by the streaming subclasses below for the lifetime of one iteration. Read
83+
# by `_base_client._process_response_data` to gate the per-SSE-event
84+
# `event_type` rename (which must happen before discriminator-union dispatch).
85+
# Not consulted by `Interaction._maybe_coerce_outputs` — that helper is purely
86+
# data-gated, so a contextvar leak across yields cannot trigger spurious
87+
# Interaction rewrites.
88+
LEGACY_LYRIA_SHIM_CTX: ContextVar[bool] = ContextVar("legacy_lyria_shim", default=False)
89+
90+
# Models known to return the legacy vertex shape. Currently exactly one. Kept
91+
# as a frozenset so additional models can be added without touching call sites.
92+
LEGACY_LYRIA_MODELS = frozenset({"lyria-3-pro-preview"})
93+
94+
# Mapping of legacy SSE event_type values to their modern equivalents in the
95+
# `InteractionSSEEvent` discriminator union. Captured live from the vertex
96+
# endpoint for `lyria-3-pro-preview`.
97+
_LEGACY_EVENT_TYPE_RENAMES: Dict[str, str] = {
98+
"interaction.start": "interaction.created",
99+
"content.start": "step.start",
100+
"content.delta": "step.delta",
101+
"content.stop": "step.stop",
102+
"interaction.complete": "interaction.completed",
103+
}
104+
105+
106+
def is_legacy_lyria_request(*, is_vertex: bool, model: object) -> bool:
107+
"""Return True iff the (client, model) combination needs the shim active.
108+
109+
Used at request issue time (in the resource layer) to decide whether to
110+
pick the `LegacyLyriaInteractionStream` subclass for streaming requests.
111+
"""
112+
return bool(is_vertex) and isinstance(model, str) and model in LEGACY_LYRIA_MODELS
113+
114+
115+
def is_legacy_lyria_response_body(data: object) -> bool:
116+
"""Return True iff a parsed response body identifies itself as a legacy-lyria payload.
117+
118+
Used at parse time (inside `Interaction._maybe_coerce_outputs`) to gate
119+
the `outputs` -> `steps` rewrite. Works for any path that produces an
120+
Interaction body — including `get()` (where the model isn't known until
121+
the response arrives) and `with_raw_response.parse()` (where parsing
122+
happens after the resource-level detection has already returned).
123+
"""
124+
if not isinstance(data, dict):
125+
return False
126+
typed_data: Dict[str, Any] = cast("Dict[str, Any]", data)
127+
model = typed_data.get("model")
128+
return isinstance(model, str) and model in LEGACY_LYRIA_MODELS
129+
130+
131+
def maybe_remap_legacy_sse_event(data: Dict[str, Any]) -> Dict[str, Any]:
132+
"""Translate one legacy SSE event dict to the modern `InteractionSSEEvent` shape.
133+
134+
Returns the input unchanged if the `event_type` is not one of the legacy
135+
ones we know how to map. Only the `content.start` mapping is non-trivial:
136+
the legacy event carries a single `content: <Content>` block, while the
137+
modern `step.start` event expects `step: {type: "model_output", content:
138+
[<Content>]}`.
139+
"""
140+
event_type = data.get("event_type")
141+
if not isinstance(event_type, str) or event_type not in _LEGACY_EVENT_TYPE_RENAMES:
142+
return data
143+
144+
new_data: Dict[str, Any] = {**data, "event_type": _LEGACY_EVENT_TYPE_RENAMES[event_type]}
145+
146+
if event_type == "content.start":
147+
content = new_data.pop("content", None)
148+
new_data["step"] = {
149+
"type": "model_output",
150+
"content": [content] if content is not None else [],
151+
}
152+
153+
return new_data
154+
155+
156+
def _is_legacy_event_dict(data: Any) -> bool:
157+
if not isinstance(data, dict):
158+
return False
159+
typed_data: Dict[str, Any] = cast("Dict[str, Any]", data)
160+
event_type = typed_data.get("event_type")
161+
return isinstance(event_type, str) and event_type in _LEGACY_EVENT_TYPE_RENAMES
162+
163+
164+
class LegacyLyriaInteractionStream(Stream[_T]):
165+
"""Sync stream subclass that activates the legacy-lyria shim eagerly.
166+
167+
Used by `create(stream=True)` where the resource layer pre-detects the
168+
legacy case via `is_legacy_lyria_request(...)`. The contextvar is set on
169+
iteration start and reset in `finally`, so even an unrecognized first
170+
event won't disable the shim — every event runs through the rename helper
171+
(which is a no-op for unrecognized event_types).
172+
"""
173+
174+
@override
175+
def __stream__(self) -> "Iterator[_T]":
176+
token = LEGACY_LYRIA_SHIM_CTX.set(True)
177+
try:
178+
yield from super().__stream__()
179+
finally:
180+
LEGACY_LYRIA_SHIM_CTX.reset(token)
181+
182+
183+
class LegacyLyriaInteractionAsyncStream(AsyncStream[_T]):
184+
"""Async counterpart of `LegacyLyriaInteractionStream`."""
185+
186+
@override
187+
async def __stream__(self) -> "AsyncIterator[_T]":
188+
token = LEGACY_LYRIA_SHIM_CTX.set(True)
189+
try:
190+
async for item in super().__stream__():
191+
yield item
192+
finally:
193+
LEGACY_LYRIA_SHIM_CTX.reset(token)
194+
195+
196+
class LegacyLyriaInteractionDetectingStream(Stream[_T]):
197+
"""Sync stream subclass that activates the shim lazily on first legacy event.
198+
199+
Used by `get(stream=True)` where the model isn't known at request time, so
200+
we can't pre-detect. Replicates `Stream.__stream__` to peek at each raw
201+
event dict before parsing; the first event whose `event_type` matches a
202+
known legacy variant flips `LEGACY_LYRIA_SHIM_CTX` for the rest of the
203+
iteration. Reset in `finally`.
204+
205+
For non-legacy interactions the dynamic detection never activates and the
206+
subclass is a no-op vs. plain `Stream`.
207+
"""
208+
209+
@override
210+
def __stream__(self) -> "Iterator[_T]":
211+
cast_to = cast(Any, self._cast_to)
212+
response = self.response
213+
process_data = self._client._process_response_data
214+
iterator = self._iter_events()
215+
token = None
216+
try:
217+
for sse in iterator:
218+
if sse.data.startswith("[DONE]"):
219+
break
220+
data = sse.json()
221+
if token is None and _is_legacy_event_dict(data):
222+
token = LEGACY_LYRIA_SHIM_CTX.set(True)
223+
yield process_data(data=data, cast_to=cast_to, response=response)
224+
finally:
225+
if token is not None:
226+
LEGACY_LYRIA_SHIM_CTX.reset(token)
227+
response.close()
228+
229+
230+
class LegacyLyriaInteractionDetectingAsyncStream(AsyncStream[_T]):
231+
"""Async counterpart of `LegacyLyriaInteractionDetectingStream`."""
232+
233+
@override
234+
async def __stream__(self) -> "AsyncIterator[_T]":
235+
cast_to = cast(Any, self._cast_to)
236+
response = self.response
237+
process_data = self._client._process_response_data
238+
iterator = self._iter_events()
239+
token = None
240+
try:
241+
async for sse in iterator:
242+
if sse.data.startswith("[DONE]"):
243+
break
244+
data = sse.json()
245+
if token is None and _is_legacy_event_dict(data):
246+
token = LEGACY_LYRIA_SHIM_CTX.set(True)
247+
yield process_data(data=data, cast_to=cast_to, response=response)
248+
finally:
249+
if token is not None:
250+
LEGACY_LYRIA_SHIM_CTX.reset(token)
251+
await response.aclose()

google/genai/_interactions/resources/interactions.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@
3535
)
3636
from .._streaming import Stream, AsyncStream
3737
from .._base_client import make_request_options
38+
from .._legacy_lyria import (
39+
LegacyLyriaInteractionStream,
40+
LegacyLyriaInteractionAsyncStream,
41+
LegacyLyriaInteractionDetectingStream,
42+
LegacyLyriaInteractionDetectingAsyncStream,
43+
is_legacy_lyria_request,
44+
)
3845
from ..types.tool_param import ToolParam
3946
from ..types.interaction import Interaction
4047
from ..types.model_param import ModelParam
@@ -473,6 +480,17 @@ def create(
473480
raise ValueError("Invalid request: specified `model` and `agent_config`. If specifying `model`, use `generation_config`.")
474481
if agent is not omit and generation_config is not omit:
475482
raise ValueError("Invalid request: specified `agent` and `generation_config`. If specifying `agent`, use `agent_config`.")
483+
484+
# For streaming requests against vertex+legacy-lyria, swap in the
485+
# Stream subclass that activates the per-event SSE remap during
486+
# iteration. Non-streaming and `get()` paths don't need any resource-
487+
# layer signal here — `Interaction._maybe_coerce_outputs` looks at the
488+
# response body's `model` field directly.
489+
stream_cls = (
490+
LegacyLyriaInteractionStream[InteractionSSEEvent]
491+
if (stream and is_legacy_lyria_request(is_vertex=self._client._is_vertex, model=model))
492+
else Stream[InteractionSSEEvent]
493+
)
476494
return self._post(
477495
self._client._build_maybe_vertex_path(api_version=api_version, path='interactions'),
478496
body=maybe_transform(
@@ -503,7 +521,7 @@ def create(
503521
),
504522
cast_to=Interaction,
505523
stream=stream or False,
506-
stream_cls=Stream[InteractionSSEEvent],
524+
stream_cls=stream_cls,
507525
)
508526

509527
def delete(
@@ -719,6 +737,17 @@ def get(
719737
raise ValueError(f"Expected a non-empty value for `api_version` but received {api_version!r}")
720738
if not id:
721739
raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
740+
741+
# We don't know the model up front for `get`, so we can't apply the
742+
# same `is_legacy_lyria_request` gate that `create` uses. Instead, on
743+
# vertex we hand the stream off to the detecting subclass, which
744+
# activates the shim only after observing the first legacy event_type.
745+
# For non-legacy interactions the subclass is a no-op vs. plain Stream.
746+
stream_cls = (
747+
LegacyLyriaInteractionDetectingStream[InteractionSSEEvent]
748+
if (stream and self._client._is_vertex)
749+
else Stream[InteractionSSEEvent]
750+
)
722751
return self._get(
723752
self._client._build_maybe_vertex_path(api_version=api_version, path=f'interactions/{id}'),
724753
options=make_request_options(
@@ -737,7 +766,7 @@ def get(
737766
),
738767
cast_to=Interaction,
739768
stream=stream or False,
740-
stream_cls=Stream[InteractionSSEEvent],
769+
stream_cls=stream_cls,
741770
)
742771

743772

@@ -1169,6 +1198,13 @@ async def create(
11691198
raise ValueError("Invalid request: specified `model` and `agent_config`. If specifying `model`, use `generation_config`.")
11701199
if agent is not omit and generation_config is not omit:
11711200
raise ValueError("Invalid request: specified `agent` and `generation_config`. If specifying `agent`, use `agent_config`.")
1201+
1202+
# See sync `create` above for rationale.
1203+
stream_cls = (
1204+
LegacyLyriaInteractionAsyncStream[InteractionSSEEvent]
1205+
if (stream and is_legacy_lyria_request(is_vertex=self._client._is_vertex, model=model))
1206+
else AsyncStream[InteractionSSEEvent]
1207+
)
11721208
return await self._post(
11731209
self._client._build_maybe_vertex_path(api_version=api_version, path='interactions'),
11741210
body=await async_maybe_transform(
@@ -1199,7 +1235,7 @@ async def create(
11991235
),
12001236
cast_to=Interaction,
12011237
stream=stream or False,
1202-
stream_cls=AsyncStream[InteractionSSEEvent],
1238+
stream_cls=stream_cls,
12031239
)
12041240

12051241
async def delete(
@@ -1415,6 +1451,13 @@ async def get(
14151451
raise ValueError(f"Expected a non-empty value for `api_version` but received {api_version!r}")
14161452
if not id:
14171453
raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
1454+
1455+
# See sync `get` above for rationale.
1456+
stream_cls = (
1457+
LegacyLyriaInteractionDetectingAsyncStream[InteractionSSEEvent]
1458+
if (stream and self._client._is_vertex)
1459+
else AsyncStream[InteractionSSEEvent]
1460+
)
14181461
return await self._get(
14191462
self._client._build_maybe_vertex_path(api_version=api_version, path=f'interactions/{id}'),
14201463
options=make_request_options(
@@ -1433,7 +1476,7 @@ async def get(
14331476
),
14341477
cast_to=Interaction,
14351478
stream=stream or False,
1436-
stream_cls=AsyncStream[InteractionSSEEvent],
1479+
stream_cls=stream_cls,
14371480
)
14381481

14391482

0 commit comments

Comments
 (0)