Skip to content

Commit 8bebbee

Browse files
committed
feat(runner): add metadata parameter to run(), run_live(), run_debug()
Add metadata support to all run methods for consistency: - run(): sync wrapper, passes metadata to run_async() - run_live(): live mode, passes metadata through invocation context - run_debug(): debug helper, passes metadata to run_async() Also update InvocationContext docstring to reflect all supported entry points.
1 parent 02ff34f commit 8bebbee

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,14 @@ class InvocationContext(BaseModel):
207207
"""The cache of canonical tools for this invocation."""
208208

209209
metadata: Optional[dict[str, Any]] = None
210-
"""Per-request metadata passed from Runner.run_async().
210+
"""Per-request metadata passed from Runner entry points.
211211
212212
This field allows passing arbitrary metadata that can be accessed during
213213
the invocation lifecycle, particularly in callbacks like before_model_callback.
214214
Common use cases include passing user_id, trace_id, memory context keys, or
215215
other request-specific context that needs to be available during processing.
216+
217+
Supported entry points: run(), run_async(), run_live(), run_debug().
216218
"""
217219

218220
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(

src/google/adk/runners.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def run(
392392
session_id: str,
393393
new_message: types.Content,
394394
run_config: Optional[RunConfig] = None,
395+
metadata: Optional[dict[str, Any]] = None,
395396
) -> Generator[Event, None, None]:
396397
"""Runs the agent.
397398
@@ -409,6 +410,7 @@ def run(
409410
session_id: The session ID of the session.
410411
new_message: A new message to append to the session.
411412
run_config: The run config for the agent.
413+
metadata: Optional per-request metadata that will be passed to callbacks.
412414
413415
Yields:
414416
The events generated by the agent.
@@ -424,6 +426,7 @@ async def _invoke_run_async():
424426
session_id=session_id,
425427
new_message=new_message,
426428
run_config=run_config,
429+
metadata=metadata,
427430
)
428431
) as agen:
429432
async for event in agen:
@@ -943,6 +946,7 @@ async def run_live(
943946
live_request_queue: LiveRequestQueue,
944947
run_config: Optional[RunConfig] = None,
945948
session: Optional[Session] = None,
949+
metadata: Optional[dict[str, Any]] = None,
946950
) -> AsyncGenerator[Event, None]:
947951
"""Runs the agent in live mode (experimental feature).
948952
@@ -984,6 +988,7 @@ async def run_live(
984988
run_config: The run config for the agent.
985989
session: The session to use. This parameter is deprecated, please use
986990
`user_id` and `session_id` instead.
991+
metadata: Optional per-request metadata that will be passed to callbacks.
987992
988993
Yields:
989994
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -998,6 +1003,7 @@ async def run_live(
9981003
Either `session` or both `user_id` and `session_id` must be provided.
9991004
"""
10001005
run_config = run_config or RunConfig()
1006+
metadata = metadata.copy() if metadata is not None else None
10011007
# Some native audio models requires the modality to be set. So we set it to
10021008
# AUDIO by default.
10031009
if run_config.response_modalities is None:
@@ -1023,6 +1029,7 @@ async def run_live(
10231029
session,
10241030
live_request_queue=live_request_queue,
10251031
run_config=run_config,
1032+
metadata=metadata,
10261033
)
10271034

10281035
root_agent = self.agent
@@ -1175,6 +1182,7 @@ async def run_debug(
11751182
run_config: RunConfig | None = None,
11761183
quiet: bool = False,
11771184
verbose: bool = False,
1185+
metadata: dict[str, Any] | None = None,
11781186
) -> list[Event]:
11791187
"""Debug helper for quick agent experimentation and testing.
11801188
@@ -1198,6 +1206,7 @@ async def run_debug(
11981206
shown).
11991207
verbose: If True, shows detailed tool calls and responses. Defaults to
12001208
False for cleaner output showing only final agent responses.
1209+
metadata: Optional per-request metadata that will be passed to callbacks.
12011210
12021211
Returns:
12031212
list[Event]: All events from all messages.
@@ -1260,6 +1269,7 @@ async def run_debug(
12601269
session_id=session.id,
12611270
new_message=types.UserContent(parts=[types.Part(text=message)]),
12621271
run_config=run_config,
1272+
metadata=metadata,
12631273
):
12641274
if not quiet:
12651275
print_event(event, verbose=verbose)
@@ -1449,6 +1459,7 @@ def _new_invocation_context_for_live(
14491459
*,
14501460
live_request_queue: LiveRequestQueue,
14511461
run_config: Optional[RunConfig] = None,
1462+
metadata: Optional[dict[str, Any]] = None,
14521463
) -> InvocationContext:
14531464
"""Creates a new invocation context for live multi-agent."""
14541465
run_config = run_config or RunConfig()
@@ -1467,6 +1478,7 @@ def _new_invocation_context_for_live(
14671478
session,
14681479
live_request_queue=live_request_queue,
14691480
run_config=run_config,
1481+
metadata=metadata,
14701482
)
14711483

14721484
async def _handle_new_message(

tests/unittests/test_runners.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
25+
from google.adk.agents.live_request_queue import LiveRequestQueue
2526
from google.adk.agents.invocation_context import InvocationContext
2627
from google.adk.agents.live_request_queue import LiveRequestQueue
2728
from google.adk.agents.llm_agent import LlmAgent
@@ -36,6 +37,7 @@
3637
from google.adk.plugins.base_plugin import BasePlugin
3738
from google.adk.runners import Runner
3839
from google.adk.sessions.in_memory_session_service import InMemorySessionService
40+
from tests.unittests import testing_utils
3941
from google.adk.sessions.session import Session
4042
from google.adk.tools.function_tool import FunctionTool
4143
from google.genai import types
@@ -1567,6 +1569,125 @@ def before_model_callback(callback_context, llm_request):
15671569
# Nested object changes in callback WILL affect original (shallow copy behavior)
15681570
assert original_metadata["nested"]["inner_key"] == "modified_nested"
15691571

1572+
def test_new_invocation_context_for_live_with_metadata(self):
1573+
"""Test that _new_invocation_context_for_live correctly passes metadata."""
1574+
mock_session = Session(
1575+
id=TEST_SESSION_ID,
1576+
app_name=TEST_APP_ID,
1577+
user_id=TEST_USER_ID,
1578+
events=[],
1579+
)
1580+
1581+
test_metadata = {"user_id": "live_user", "trace_id": "live_trace"}
1582+
invocation_context = self.runner._new_invocation_context_for_live(
1583+
mock_session, metadata=test_metadata
1584+
)
1585+
1586+
assert invocation_context.metadata == test_metadata
1587+
assert invocation_context.metadata["user_id"] == "live_user"
1588+
1589+
@pytest.mark.asyncio
1590+
async def test_run_sync_passes_metadata(self):
1591+
"""Test that sync run() correctly passes metadata to run_async()."""
1592+
captured_metadata = None
1593+
1594+
def before_model_callback(callback_context, llm_request):
1595+
nonlocal captured_metadata
1596+
captured_metadata = llm_request.metadata
1597+
return LlmResponse(
1598+
content=types.Content(
1599+
role="model", parts=[types.Part(text="Test response")]
1600+
)
1601+
)
1602+
1603+
agent_with_callback = LlmAgent(
1604+
name="callback_agent",
1605+
model="gemini-2.0-flash",
1606+
before_model_callback=before_model_callback,
1607+
)
1608+
1609+
runner_with_callback = Runner(
1610+
app_name="test_app",
1611+
agent=agent_with_callback,
1612+
session_service=self.session_service,
1613+
artifact_service=self.artifact_service,
1614+
)
1615+
1616+
await self.session_service.create_session(
1617+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1618+
)
1619+
1620+
test_metadata = {"sync_key": "sync_value"}
1621+
1622+
for event in runner_with_callback.run(
1623+
user_id=TEST_USER_ID,
1624+
session_id=TEST_SESSION_ID,
1625+
new_message=types.Content(
1626+
role="user", parts=[types.Part(text="Hello")]
1627+
),
1628+
metadata=test_metadata,
1629+
):
1630+
pass
1631+
1632+
assert captured_metadata is not None
1633+
assert captured_metadata["sync_key"] == "sync_value"
1634+
1635+
@pytest.mark.asyncio
1636+
async def test_run_live_passes_metadata_to_llm_request(self):
1637+
"""Test that run_live() passes metadata through live pipeline to LlmRequest."""
1638+
import asyncio
1639+
1640+
# Create MockModel to capture LlmRequest
1641+
mock_model = testing_utils.MockModel.create(
1642+
responses=[
1643+
LlmResponse(
1644+
content=types.Content(
1645+
role="model", parts=[types.Part(text="Live response")]
1646+
)
1647+
)
1648+
]
1649+
)
1650+
1651+
agent_with_mock = LlmAgent(
1652+
name="live_mock_agent",
1653+
model=mock_model,
1654+
)
1655+
1656+
runner_with_mock = Runner(
1657+
app_name="test_app",
1658+
agent=agent_with_mock,
1659+
session_service=self.session_service,
1660+
artifact_service=self.artifact_service,
1661+
)
1662+
1663+
await self.session_service.create_session(
1664+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1665+
)
1666+
1667+
test_metadata = {"live_key": "live_value", "trace_id": "live_trace_123"}
1668+
live_queue = LiveRequestQueue()
1669+
live_queue.close() # Close immediately to end the live session
1670+
1671+
async def consume_events():
1672+
async for event in runner_with_mock.run_live(
1673+
user_id=TEST_USER_ID,
1674+
session_id=TEST_SESSION_ID,
1675+
live_request_queue=live_queue,
1676+
metadata=test_metadata,
1677+
):
1678+
pass
1679+
1680+
try:
1681+
await asyncio.wait_for(consume_events(), timeout=2)
1682+
except asyncio.TimeoutError:
1683+
pass # Expected - live session may not terminate cleanly
1684+
1685+
# Verify MockModel received LlmRequest with correct metadata
1686+
assert len(mock_model.requests) > 0
1687+
assert mock_model.requests[0].metadata is not None
1688+
assert mock_model.requests[0].metadata["live_key"] == "live_value"
1689+
assert mock_model.requests[0].metadata["trace_id"] == "live_trace_123"
1690+
15701691

15711692
if __name__ == "__main__":
15721693
pytest.main([__file__])

0 commit comments

Comments
 (0)