Skip to content

Commit aafd97f

Browse files
authored
fix: Support generalized history config injection for Gemini 3.1 Live on Vertex AI (#5999)
1 parent 15eb387 commit aafd97f

4 files changed

Lines changed: 96 additions & 24 deletions

File tree

src/google/adk/agents/run_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ class RunConfig(BaseModel):
247247
session_resumption: Optional[types.SessionResumptionConfig] = None
248248
"""Configures session resumption mechanism. Only support transparent session resumption mode now."""
249249

250+
history_config: Optional[types.HistoryConfig] = None
251+
"""Configures the exchange of history between the client and the server."""
252+
250253
context_window_compression: Optional[types.ContextWindowCompressionConfig] = (
251254
None
252255
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,18 +529,26 @@ async def run_live(
529529
if session_resumption.transparent is None:
530530
session_resumption.transparent = True
531531

532+
# When seeding a fresh connection with prior conversation history, set
533+
# initial_history_in_client_content to True. This tells the Live server
534+
# that the provided history already includes the model's past responses,
535+
# preventing the server from generating duplicate responses for those replayed turns.
532536
if (
533-
isinstance(llm, Gemini)
534-
and llm._api_backend == GoogleLLMVariant.GEMINI_API
535-
and model_name_utils.is_gemini_3_1_flash_live(llm_request.model)
536-
and llm_request.contents
537+
llm_request.contents
537538
and not invocation_context.live_session_resumption_handle
538539
):
539-
if llm_request.live_connect_config is None:
540+
if not llm_request.live_connect_config:
540541
llm_request.live_connect_config = types.LiveConnectConfig()
541-
if llm_request.live_connect_config.history_config is None:
542+
if not llm_request.live_connect_config.history_config:
542543
llm_request.live_connect_config.history_config = (
543-
types.HistoryConfig(initial_history_in_client_content=True)
544+
types.HistoryConfig()
545+
)
546+
if (
547+
llm_request.live_connect_config.history_config.initial_history_in_client_content
548+
is None
549+
):
550+
llm_request.live_connect_config.history_config.initial_history_in_client_content = (
551+
True
544552
)
545553

546554
logger.info(

src/google/adk/flows/llm_flows/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def _build_basic_request(
9595
llm_request.live_connect_config.session_resumption = (
9696
invocation_context.run_config.session_resumption
9797
)
98+
llm_request.live_connect_config.history_config = (
99+
invocation_context.run_config.history_config
100+
)
98101
llm_request.live_connect_config.context_window_compression = (
99102
invocation_context.run_config.context_window_compression
100103
)

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,16 +1272,14 @@ async def mock_receive_2():
12721272

12731273
@pytest.mark.asyncio
12741274
@pytest.mark.parametrize(
1275-
'api_backend,should_have_history_config',
1275+
'api_backend',
12761276
[
1277-
(GoogleLLMVariant.GEMINI_API, True),
1278-
(GoogleLLMVariant.VERTEX_AI, False),
1277+
GoogleLLMVariant.GEMINI_API,
1278+
GoogleLLMVariant.VERTEX_AI,
12791279
],
12801280
)
1281-
async def test_run_live_history_config_gated_by_backend(
1282-
api_backend, should_have_history_config
1283-
):
1284-
"""Test that run_live only sets history_config for Gemini API backend."""
1281+
async def test_run_live_history_config_set_for_all_backends(api_backend):
1282+
"""Test that run_live sets history_config for all backends."""
12851283

12861284
real_model = Gemini(model='gemini-3.1-flash-live-preview')
12871285
mock_connection = mock.AsyncMock()
@@ -1334,13 +1332,73 @@ async def mock_preprocess(ctx, req):
13341332

13351333
assert mock_connect.call_count == 1
13361334
called_req = mock_connect.call_args[0][0]
1337-
if should_have_history_config:
1338-
assert called_req.live_connect_config is not None
1339-
assert called_req.live_connect_config.history_config is not None
1340-
assert (
1341-
called_req.live_connect_config.history_config.initial_history_in_client_content
1342-
is True
1343-
)
1344-
else:
1345-
if called_req.live_connect_config:
1346-
assert called_req.live_connect_config.history_config is None
1335+
assert called_req.live_connect_config is not None
1336+
assert called_req.live_connect_config.history_config is not None
1337+
assert (
1338+
called_req.live_connect_config.history_config.initial_history_in_client_content
1339+
is True
1340+
)
1341+
1342+
1343+
@pytest.mark.asyncio
1344+
async def test_run_live_respects_explicit_initial_history_in_client_content_false():
1345+
"""Test that run_live respects explicit initial_history_in_client_content=False in RunConfig."""
1346+
1347+
real_model = Gemini()
1348+
mock_connection = mock.AsyncMock()
1349+
1350+
agent = Agent(name='test_agent', model=real_model)
1351+
invocation_context = await testing_utils.create_invocation_context(
1352+
agent=agent
1353+
)
1354+
invocation_context.live_request_queue = LiveRequestQueue()
1355+
run_config = RunConfig(
1356+
history_config=types.HistoryConfig(
1357+
initial_history_in_client_content=False
1358+
)
1359+
)
1360+
invocation_context.run_config = run_config
1361+
1362+
flow = BaseLlmFlowForTesting()
1363+
1364+
async def mock_preprocess(ctx, req):
1365+
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1366+
from google.adk.flows.llm_flows.basic import _build_basic_request
1367+
1368+
_build_basic_request(ctx, req)
1369+
yield Event(id=Event.new_id(), author='test')
1370+
1371+
with mock.patch.object(
1372+
flow, '_preprocess_async', side_effect=mock_preprocess
1373+
):
1374+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1375+
1376+
class StopTestError(Exception):
1377+
pass
1378+
1379+
async def mock_receive():
1380+
yield LlmResponse(
1381+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1382+
)
1383+
raise StopTestError('stop')
1384+
1385+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1386+
1387+
with mock.patch(
1388+
'google.adk.models.google_llm.Gemini.connect'
1389+
) as mock_connect:
1390+
mock_connect.return_value.__aenter__.return_value = mock_connection
1391+
1392+
try:
1393+
async for _ in flow.run_live(invocation_context):
1394+
pass
1395+
except StopTestError:
1396+
pass
1397+
1398+
assert mock_connect.call_count == 1
1399+
call_req = mock_connect.call_args[0][0]
1400+
assert call_req.live_connect_config.history_config is not None
1401+
assert (
1402+
call_req.live_connect_config.history_config.initial_history_in_client_content
1403+
is False
1404+
)

0 commit comments

Comments
 (0)