Skip to content

Commit e7eb5fe

Browse files
wyf7107copybara-github
authored andcommitted
fix: Support generalized history config injection for Gemini 3.1 Live on Vertex AI
Port of GitHub PR: 61a3933 Expose history_config in RunConfig and map it to LLM live connect request configuration. Generalize history connection logic to automatically inject initial_history_in_client_content = True when seeding history on a fresh connection for both Gemini API and Vertex AI backends. Co-authored-by: Yifan Wang <wanyif@google.com> PiperOrigin-RevId: 927465431
1 parent 4eb337e commit e7eb5fe

4 files changed

Lines changed: 167 additions & 0 deletions

File tree

src/google/adk/agents/run_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ class RunConfig(BaseModel):
247247
session_resumption: Optional[types.SessionResumptionConfig] = None
248248
"""Configures session resumption mechanism. Only support transparent session resumption mode now."""
249249

250+
history_config: Optional[types.HistoryConfig] = None
251+
"""Configures the exchange of history between the client and the server."""
252+
250253
context_window_compression: Optional[types.ContextWindowCompressionConfig] = (
251254
None
252255
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,28 @@ async def run_live(
552552
if session_resumption.transparent is None:
553553
session_resumption.transparent = True
554554

555+
# When seeding a fresh connection with prior conversation history, set
556+
# initial_history_in_client_content to True. This tells the Live server
557+
# that the provided history already includes the model's past responses,
558+
# preventing the server from generating duplicate responses for those replayed turns.
559+
if (
560+
llm_request.contents
561+
and not invocation_context.live_session_resumption_handle
562+
):
563+
if not llm_request.live_connect_config:
564+
llm_request.live_connect_config = types.LiveConnectConfig()
565+
if not llm_request.live_connect_config.history_config:
566+
llm_request.live_connect_config.history_config = (
567+
types.HistoryConfig()
568+
)
569+
if (
570+
llm_request.live_connect_config.history_config.initial_history_in_client_content
571+
is None
572+
):
573+
llm_request.live_connect_config.history_config.initial_history_in_client_content = (
574+
True
575+
)
576+
555577
logger.info(
556578
'Establishing live connection for agent: %s',
557579
invocation_context.agent.name,

src/google/adk/flows/llm_flows/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def _build_basic_request(
9191
llm_request.live_connect_config.session_resumption = (
9292
invocation_context.run_config.session_resumption
9393
)
94+
llm_request.live_connect_config.history_config = (
95+
invocation_context.run_config.history_config
96+
)
9497
llm_request.live_connect_config.context_window_compression = (
9598
invocation_context.run_config.context_window_compression
9699
)

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.adk.plugins.base_plugin import BasePlugin
3131
from google.adk.tools.base_toolset import BaseToolset
3232
from google.adk.tools.google_search_tool import GoogleSearchTool
33+
from google.adk.utils.variant_utils import GoogleLLMVariant
3334
from google.genai import types
3435
import pytest
3536
from websockets.exceptions import ConnectionClosed
@@ -1386,3 +1387,141 @@ async def mock_receive_2():
13861387
second_call_req = mock_connect.call_args_list[1][0][0]
13871388
session_resump = second_call_req.live_connect_config.session_resumption
13881389
assert session_resump.transparent
1390+
1391+
1392+
@pytest.mark.asyncio
1393+
@pytest.mark.parametrize(
1394+
'api_backend',
1395+
[
1396+
GoogleLLMVariant.GEMINI_API,
1397+
GoogleLLMVariant.VERTEX_AI,
1398+
],
1399+
)
1400+
async def test_run_live_history_config_set_for_all_backends(api_backend):
1401+
"""Test that run_live sets history_config for all backends."""
1402+
1403+
real_model = Gemini(model='gemini-3.1-flash-live-preview')
1404+
mock_connection = mock.AsyncMock()
1405+
1406+
agent = Agent(name='test_agent', model=real_model)
1407+
invocation_context = await testing_utils.create_invocation_context(
1408+
agent=agent
1409+
)
1410+
invocation_context.live_request_queue = LiveRequestQueue()
1411+
invocation_context.run_config = RunConfig()
1412+
1413+
flow = BaseLlmFlowForTesting()
1414+
1415+
async def mock_preprocess(ctx, req):
1416+
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1417+
from google.adk.flows.llm_flows.basic import _build_basic_request
1418+
1419+
_build_basic_request(ctx, req)
1420+
yield Event(id=Event.new_id(), author='test')
1421+
1422+
with mock.patch.object(
1423+
flow, '_preprocess_async', side_effect=mock_preprocess
1424+
):
1425+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1426+
1427+
class StopTestError(Exception):
1428+
pass
1429+
1430+
async def mock_receive():
1431+
yield LlmResponse(
1432+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1433+
)
1434+
raise StopTestError('stop')
1435+
1436+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1437+
1438+
with mock.patch(
1439+
'google.adk.models.google_llm.Gemini.connect'
1440+
) as mock_connect:
1441+
mock_connect.return_value.__aenter__.return_value = mock_connection
1442+
1443+
# Mock the api_backend property
1444+
with mock.patch.object(
1445+
Gemini,
1446+
'_api_backend',
1447+
new_callable=mock.PropertyMock,
1448+
return_value=api_backend,
1449+
):
1450+
try:
1451+
async for _ in flow.run_live(invocation_context):
1452+
pass
1453+
except StopTestError:
1454+
pass
1455+
1456+
assert mock_connect.call_count == 1
1457+
called_req = mock_connect.call_args[0][0]
1458+
assert called_req.live_connect_config is not None
1459+
assert called_req.live_connect_config.history_config is not None
1460+
assert (
1461+
called_req.live_connect_config.history_config.initial_history_in_client_content
1462+
is True
1463+
)
1464+
1465+
1466+
@pytest.mark.asyncio
1467+
async def test_run_live_respects_explicit_initial_history_in_client_content_false():
1468+
"""Test that run_live respects explicit initial_history_in_client_content=False in RunConfig."""
1469+
1470+
real_model = Gemini()
1471+
mock_connection = mock.AsyncMock()
1472+
1473+
agent = Agent(name='test_agent', model=real_model)
1474+
invocation_context = await testing_utils.create_invocation_context(
1475+
agent=agent
1476+
)
1477+
invocation_context.live_request_queue = LiveRequestQueue()
1478+
run_config = RunConfig(
1479+
history_config=types.HistoryConfig(
1480+
initial_history_in_client_content=False
1481+
)
1482+
)
1483+
invocation_context.run_config = run_config
1484+
1485+
flow = BaseLlmFlowForTesting()
1486+
1487+
async def mock_preprocess(ctx, req):
1488+
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1489+
from google.adk.flows.llm_flows.basic import _build_basic_request
1490+
1491+
_build_basic_request(ctx, req)
1492+
yield Event(id=Event.new_id(), author='test')
1493+
1494+
with mock.patch.object(
1495+
flow, '_preprocess_async', side_effect=mock_preprocess
1496+
):
1497+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1498+
1499+
class StopTestError(Exception):
1500+
pass
1501+
1502+
async def mock_receive():
1503+
yield LlmResponse(
1504+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1505+
)
1506+
raise StopTestError('stop')
1507+
1508+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1509+
1510+
with mock.patch(
1511+
'google.adk.models.google_llm.Gemini.connect'
1512+
) as mock_connect:
1513+
mock_connect.return_value.__aenter__.return_value = mock_connection
1514+
1515+
try:
1516+
async for _ in flow.run_live(invocation_context):
1517+
pass
1518+
except StopTestError:
1519+
pass
1520+
1521+
assert mock_connect.call_count == 1
1522+
call_req = mock_connect.call_args[0][0]
1523+
assert call_req.live_connect_config.history_config is not None
1524+
assert (
1525+
call_req.live_connect_config.history_config.initial_history_in_client_content
1526+
is False
1527+
)

0 commit comments

Comments
 (0)