Skip to content

Commit 84c6885

Browse files
committed
feat(runner): add metadata parameter to run(), run_live(), run_debug()
Add metadata support to all run methods for consistency: - run(): sync wrapper, passes metadata to run_async() - run_live(): live mode, passes metadata through invocation context - run_debug(): debug helper, passes metadata to run_async() Also update InvocationContext docstring to reflect all supported entry points.
1 parent 652ce46 commit 84c6885

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,14 @@ class InvocationContext(BaseModel):
207207
"""The cache of canonical tools for this invocation."""
208208

209209
metadata: Optional[dict[str, Any]] = None
210-
"""Per-request metadata passed from Runner.run_async().
210+
"""Per-request metadata passed from Runner entry points.
211211
212212
This field allows passing arbitrary metadata that can be accessed during
213213
the invocation lifecycle, particularly in callbacks like before_model_callback.
214214
Common use cases include passing user_id, trace_id, memory context keys, or
215215
other request-specific context that needs to be available during processing.
216+
217+
Supported entry points: run(), run_async(), run_live(), run_debug().
216218
"""
217219

218220
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(

src/google/adk/runners.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def run(
390390
session_id: str,
391391
new_message: types.Content,
392392
run_config: Optional[RunConfig] = None,
393+
metadata: Optional[dict[str, Any]] = None,
393394
) -> Generator[Event, None, None]:
394395
"""Runs the agent.
395396
@@ -407,6 +408,7 @@ def run(
407408
session_id: The session ID of the session.
408409
new_message: A new message to append to the session.
409410
run_config: The run config for the agent.
411+
metadata: Optional per-request metadata that will be passed to callbacks.
410412
411413
Yields:
412414
The events generated by the agent.
@@ -422,6 +424,7 @@ async def _invoke_run_async():
422424
session_id=session_id,
423425
new_message=new_message,
424426
run_config=run_config,
427+
metadata=metadata,
425428
)
426429
) as agen:
427430
async for event in agen:
@@ -941,6 +944,7 @@ async def run_live(
941944
live_request_queue: LiveRequestQueue,
942945
run_config: Optional[RunConfig] = None,
943946
session: Optional[Session] = None,
947+
metadata: Optional[dict[str, Any]] = None,
944948
) -> AsyncGenerator[Event, None]:
945949
"""Runs the agent in live mode (experimental feature).
946950
@@ -982,6 +986,7 @@ async def run_live(
982986
run_config: The run config for the agent.
983987
session: The session to use. This parameter is deprecated, please use
984988
`user_id` and `session_id` instead.
989+
metadata: Optional per-request metadata that will be passed to callbacks.
985990
986991
Yields:
987992
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -996,6 +1001,7 @@ async def run_live(
9961001
Either `session` or both `user_id` and `session_id` must be provided.
9971002
"""
9981003
run_config = run_config or RunConfig()
1004+
metadata = metadata.copy() if metadata is not None else None
9991005
# Some native audio models requires the modality to be set. So we set it to
10001006
# AUDIO by default.
10011007
if run_config.response_modalities is None:
@@ -1021,6 +1027,7 @@ async def run_live(
10211027
session,
10221028
live_request_queue=live_request_queue,
10231029
run_config=run_config,
1030+
metadata=metadata,
10241031
)
10251032

10261033
root_agent = self.agent
@@ -1127,6 +1134,7 @@ async def run_debug(
11271134
run_config: RunConfig | None = None,
11281135
quiet: bool = False,
11291136
verbose: bool = False,
1137+
metadata: dict[str, Any] | None = None,
11301138
) -> list[Event]:
11311139
"""Debug helper for quick agent experimentation and testing.
11321140
@@ -1150,6 +1158,7 @@ async def run_debug(
11501158
shown).
11511159
verbose: If True, shows detailed tool calls and responses. Defaults to
11521160
False for cleaner output showing only final agent responses.
1161+
metadata: Optional per-request metadata that will be passed to callbacks.
11531162
11541163
Returns:
11551164
list[Event]: All events from all messages.
@@ -1212,6 +1221,7 @@ async def run_debug(
12121221
session_id=session.id,
12131222
new_message=types.UserContent(parts=[types.Part(text=message)]),
12141223
run_config=run_config,
1224+
metadata=metadata,
12151225
):
12161226
if not quiet:
12171227
print_event(event, verbose=verbose)
@@ -1401,6 +1411,7 @@ def _new_invocation_context_for_live(
14011411
*,
14021412
live_request_queue: LiveRequestQueue,
14031413
run_config: Optional[RunConfig] = None,
1414+
metadata: Optional[dict[str, Any]] = None,
14041415
) -> InvocationContext:
14051416
"""Creates a new invocation context for live multi-agent."""
14061417
run_config = run_config or RunConfig()
@@ -1419,6 +1430,7 @@ def _new_invocation_context_for_live(
14191430
session,
14201431
live_request_queue=live_request_queue,
14211432
run_config=run_config,
1433+
metadata=metadata,
14221434
)
14231435

14241436
async def _handle_new_message(

tests/unittests/test_runners.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
25+
from google.adk.agents.live_request_queue import LiveRequestQueue
2526
from google.adk.agents.invocation_context import InvocationContext
2627
from google.adk.agents.llm_agent import LlmAgent
2728
from google.adk.agents.run_config import RunConfig
@@ -35,6 +36,7 @@
3536
from google.adk.plugins.base_plugin import BasePlugin
3637
from google.adk.runners import Runner
3738
from google.adk.sessions.in_memory_session_service import InMemorySessionService
39+
from tests.unittests import testing_utils
3840
from google.adk.sessions.session import Session
3941
from google.genai import types
4042
import pytest
@@ -1483,6 +1485,125 @@ def before_model_callback(callback_context, llm_request):
14831485
# Nested object changes in callback WILL affect original (shallow copy behavior)
14841486
assert original_metadata["nested"]["inner_key"] == "modified_nested"
14851487

1488+
def test_new_invocation_context_for_live_with_metadata(self):
1489+
"""Test that _new_invocation_context_for_live correctly passes metadata."""
1490+
mock_session = Session(
1491+
id=TEST_SESSION_ID,
1492+
app_name=TEST_APP_ID,
1493+
user_id=TEST_USER_ID,
1494+
events=[],
1495+
)
1496+
1497+
test_metadata = {"user_id": "live_user", "trace_id": "live_trace"}
1498+
invocation_context = self.runner._new_invocation_context_for_live(
1499+
mock_session, metadata=test_metadata
1500+
)
1501+
1502+
assert invocation_context.metadata == test_metadata
1503+
assert invocation_context.metadata["user_id"] == "live_user"
1504+
1505+
@pytest.mark.asyncio
1506+
async def test_run_sync_passes_metadata(self):
1507+
"""Test that sync run() correctly passes metadata to run_async()."""
1508+
captured_metadata = None
1509+
1510+
def before_model_callback(callback_context, llm_request):
1511+
nonlocal captured_metadata
1512+
captured_metadata = llm_request.metadata
1513+
return LlmResponse(
1514+
content=types.Content(
1515+
role="model", parts=[types.Part(text="Test response")]
1516+
)
1517+
)
1518+
1519+
agent_with_callback = LlmAgent(
1520+
name="callback_agent",
1521+
model="gemini-2.0-flash",
1522+
before_model_callback=before_model_callback,
1523+
)
1524+
1525+
runner_with_callback = Runner(
1526+
app_name="test_app",
1527+
agent=agent_with_callback,
1528+
session_service=self.session_service,
1529+
artifact_service=self.artifact_service,
1530+
)
1531+
1532+
await self.session_service.create_session(
1533+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1534+
)
1535+
1536+
test_metadata = {"sync_key": "sync_value"}
1537+
1538+
for event in runner_with_callback.run(
1539+
user_id=TEST_USER_ID,
1540+
session_id=TEST_SESSION_ID,
1541+
new_message=types.Content(
1542+
role="user", parts=[types.Part(text="Hello")]
1543+
),
1544+
metadata=test_metadata,
1545+
):
1546+
pass
1547+
1548+
assert captured_metadata is not None
1549+
assert captured_metadata["sync_key"] == "sync_value"
1550+
1551+
@pytest.mark.asyncio
1552+
async def test_run_live_passes_metadata_to_llm_request(self):
1553+
"""Test that run_live() passes metadata through live pipeline to LlmRequest."""
1554+
import asyncio
1555+
1556+
# Create MockModel to capture LlmRequest
1557+
mock_model = testing_utils.MockModel.create(
1558+
responses=[
1559+
LlmResponse(
1560+
content=types.Content(
1561+
role="model", parts=[types.Part(text="Live response")]
1562+
)
1563+
)
1564+
]
1565+
)
1566+
1567+
agent_with_mock = LlmAgent(
1568+
name="live_mock_agent",
1569+
model=mock_model,
1570+
)
1571+
1572+
runner_with_mock = Runner(
1573+
app_name="test_app",
1574+
agent=agent_with_mock,
1575+
session_service=self.session_service,
1576+
artifact_service=self.artifact_service,
1577+
)
1578+
1579+
await self.session_service.create_session(
1580+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1581+
)
1582+
1583+
test_metadata = {"live_key": "live_value", "trace_id": "live_trace_123"}
1584+
live_queue = LiveRequestQueue()
1585+
live_queue.close() # Close immediately to end the live session
1586+
1587+
async def consume_events():
1588+
async for event in runner_with_mock.run_live(
1589+
user_id=TEST_USER_ID,
1590+
session_id=TEST_SESSION_ID,
1591+
live_request_queue=live_queue,
1592+
metadata=test_metadata,
1593+
):
1594+
pass
1595+
1596+
try:
1597+
await asyncio.wait_for(consume_events(), timeout=2)
1598+
except asyncio.TimeoutError:
1599+
pass # Expected - live session may not terminate cleanly
1600+
1601+
# Verify MockModel received LlmRequest with correct metadata
1602+
assert len(mock_model.requests) > 0
1603+
assert mock_model.requests[0].metadata is not None
1604+
assert mock_model.requests[0].metadata["live_key"] == "live_value"
1605+
assert mock_model.requests[0].metadata["trace_id"] == "live_trace_123"
1606+
14861607

14871608
if __name__ == "__main__":
14881609
pytest.main([__file__])

0 commit comments

Comments
 (0)