Skip to content

Commit d34a225

Browse files
committed
feat(runner): add metadata parameter to run(), run_live(), run_debug()
Add metadata support to all run methods for consistency: - run(): sync wrapper, passes metadata to run_async() - run_live(): live mode, passes metadata through invocation context - run_debug(): debug helper, passes metadata to run_async() Also update InvocationContext docstring to reflect all supported entry points.
1 parent e32f498 commit d34a225

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,14 @@ class InvocationContext(BaseModel):
215215
"""The cache of canonical tools for this invocation."""
216216

217217
metadata: Optional[dict[str, Any]] = None
218-
"""Per-request metadata passed from Runner.run_async().
218+
"""Per-request metadata passed from Runner entry points.
219219
220220
This field allows passing arbitrary metadata that can be accessed during
221221
the invocation lifecycle, particularly in callbacks like before_model_callback.
222222
Common use cases include passing user_id, trace_id, memory context keys, or
223223
other request-specific context that needs to be available during processing.
224+
225+
Supported entry points: run(), run_async(), run_live(), run_debug().
224226
"""
225227

226228
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(

src/google/adk/runners.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ def run(
432432
session_id: str,
433433
new_message: types.Content,
434434
run_config: Optional[RunConfig] = None,
435+
metadata: Optional[dict[str, Any]] = None,
435436
) -> Generator[Event, None, None]:
436437
"""Runs the agent.
437438
@@ -449,6 +450,7 @@ def run(
449450
session_id: The session ID of the session.
450451
new_message: A new message to append to the session.
451452
run_config: The run config for the agent.
453+
metadata: Optional per-request metadata that will be passed to callbacks.
452454
453455
Yields:
454456
The events generated by the agent.
@@ -464,6 +466,7 @@ async def _invoke_run_async():
464466
session_id=session_id,
465467
new_message=new_message,
466468
run_config=run_config,
469+
metadata=metadata,
467470
)
468471
) as agen:
469472
async for event in agen:
@@ -1002,6 +1005,7 @@ async def run_live(
10021005
live_request_queue: LiveRequestQueue,
10031006
run_config: Optional[RunConfig] = None,
10041007
session: Optional[Session] = None,
1008+
metadata: Optional[dict[str, Any]] = None,
10051009
) -> AsyncGenerator[Event, None]:
10061010
"""Runs the agent in live mode (experimental feature).
10071011
@@ -1043,6 +1047,7 @@ async def run_live(
10431047
run_config: The run config for the agent.
10441048
session: The session to use. This parameter is deprecated, please use
10451049
`user_id` and `session_id` instead.
1050+
metadata: Optional per-request metadata that will be passed to callbacks.
10461051
10471052
Yields:
10481053
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -1057,6 +1062,7 @@ async def run_live(
10571062
Either `session` or both `user_id` and `session_id` must be provided.
10581063
"""
10591064
run_config = run_config or RunConfig()
1065+
metadata = metadata.copy() if metadata is not None else None
10601066
# Some native audio models requires the modality to be set. So we set it to
10611067
# AUDIO by default.
10621068
if run_config.response_modalities is None:
@@ -1082,6 +1088,7 @@ async def run_live(
10821088
session,
10831089
live_request_queue=live_request_queue,
10841090
run_config=run_config,
1091+
metadata=metadata,
10851092
)
10861093

10871094
root_agent = self.agent
@@ -1188,6 +1195,7 @@ async def run_debug(
11881195
run_config: RunConfig | None = None,
11891196
quiet: bool = False,
11901197
verbose: bool = False,
1198+
metadata: dict[str, Any] | None = None,
11911199
) -> list[Event]:
11921200
"""Debug helper for quick agent experimentation and testing.
11931201
@@ -1211,6 +1219,7 @@ async def run_debug(
12111219
shown).
12121220
verbose: If True, shows detailed tool calls and responses. Defaults to
12131221
False for cleaner output showing only final agent responses.
1222+
metadata: Optional per-request metadata that will be passed to callbacks.
12141223
12151224
Returns:
12161225
list[Event]: All events from all messages.
@@ -1273,6 +1282,7 @@ async def run_debug(
12731282
session_id=session.id,
12741283
new_message=types.UserContent(parts=[types.Part(text=message)]),
12751284
run_config=run_config,
1285+
metadata=metadata,
12761286
):
12771287
if not quiet:
12781288
print_event(event, verbose=verbose)
@@ -1469,6 +1479,7 @@ def _new_invocation_context_for_live(
14691479
*,
14701480
live_request_queue: LiveRequestQueue,
14711481
run_config: Optional[RunConfig] = None,
1482+
metadata: Optional[dict[str, Any]] = None,
14721483
) -> InvocationContext:
14731484
"""Creates a new invocation context for live multi-agent."""
14741485
run_config = run_config or RunConfig()
@@ -1487,6 +1498,7 @@ def _new_invocation_context_for_live(
14871498
session,
14881499
live_request_queue=live_request_queue,
14891500
run_config=run_config,
1501+
metadata=metadata,
14901502
)
14911503

14921504
async def _handle_new_message(

tests/unittests/test_runners.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
25+
from google.adk.agents.live_request_queue import LiveRequestQueue
2526
from google.adk.agents.invocation_context import InvocationContext
2627
from google.adk.agents.llm_agent import LlmAgent
2728
from google.adk.agents.run_config import RunConfig
@@ -36,6 +37,7 @@
3637
from google.adk.plugins.base_plugin import BasePlugin
3738
from google.adk.runners import Runner
3839
from google.adk.sessions.in_memory_session_service import InMemorySessionService
40+
from tests.unittests import testing_utils
3941
from google.adk.sessions.session import Session
4042
from google.genai import types
4143
import pytest
@@ -1484,6 +1486,125 @@ def before_model_callback(callback_context, llm_request):
14841486
# Nested object changes in callback WILL affect original (shallow copy behavior)
14851487
assert original_metadata["nested"]["inner_key"] == "modified_nested"
14861488

1489+
def test_new_invocation_context_for_live_with_metadata(self):
1490+
"""Test that _new_invocation_context_for_live correctly passes metadata."""
1491+
mock_session = Session(
1492+
id=TEST_SESSION_ID,
1493+
app_name=TEST_APP_ID,
1494+
user_id=TEST_USER_ID,
1495+
events=[],
1496+
)
1497+
1498+
test_metadata = {"user_id": "live_user", "trace_id": "live_trace"}
1499+
invocation_context = self.runner._new_invocation_context_for_live(
1500+
mock_session, metadata=test_metadata
1501+
)
1502+
1503+
assert invocation_context.metadata == test_metadata
1504+
assert invocation_context.metadata["user_id"] == "live_user"
1505+
1506+
@pytest.mark.asyncio
1507+
async def test_run_sync_passes_metadata(self):
1508+
"""Test that sync run() correctly passes metadata to run_async()."""
1509+
captured_metadata = None
1510+
1511+
def before_model_callback(callback_context, llm_request):
1512+
nonlocal captured_metadata
1513+
captured_metadata = llm_request.metadata
1514+
return LlmResponse(
1515+
content=types.Content(
1516+
role="model", parts=[types.Part(text="Test response")]
1517+
)
1518+
)
1519+
1520+
agent_with_callback = LlmAgent(
1521+
name="callback_agent",
1522+
model="gemini-2.0-flash",
1523+
before_model_callback=before_model_callback,
1524+
)
1525+
1526+
runner_with_callback = Runner(
1527+
app_name="test_app",
1528+
agent=agent_with_callback,
1529+
session_service=self.session_service,
1530+
artifact_service=self.artifact_service,
1531+
)
1532+
1533+
await self.session_service.create_session(
1534+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1535+
)
1536+
1537+
test_metadata = {"sync_key": "sync_value"}
1538+
1539+
for event in runner_with_callback.run(
1540+
user_id=TEST_USER_ID,
1541+
session_id=TEST_SESSION_ID,
1542+
new_message=types.Content(
1543+
role="user", parts=[types.Part(text="Hello")]
1544+
),
1545+
metadata=test_metadata,
1546+
):
1547+
pass
1548+
1549+
assert captured_metadata is not None
1550+
assert captured_metadata["sync_key"] == "sync_value"
1551+
1552+
@pytest.mark.asyncio
1553+
async def test_run_live_passes_metadata_to_llm_request(self):
1554+
"""Test that run_live() passes metadata through live pipeline to LlmRequest."""
1555+
import asyncio
1556+
1557+
# Create MockModel to capture LlmRequest
1558+
mock_model = testing_utils.MockModel.create(
1559+
responses=[
1560+
LlmResponse(
1561+
content=types.Content(
1562+
role="model", parts=[types.Part(text="Live response")]
1563+
)
1564+
)
1565+
]
1566+
)
1567+
1568+
agent_with_mock = LlmAgent(
1569+
name="live_mock_agent",
1570+
model=mock_model,
1571+
)
1572+
1573+
runner_with_mock = Runner(
1574+
app_name="test_app",
1575+
agent=agent_with_mock,
1576+
session_service=self.session_service,
1577+
artifact_service=self.artifact_service,
1578+
)
1579+
1580+
await self.session_service.create_session(
1581+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1582+
)
1583+
1584+
test_metadata = {"live_key": "live_value", "trace_id": "live_trace_123"}
1585+
live_queue = LiveRequestQueue()
1586+
live_queue.close() # Close immediately to end the live session
1587+
1588+
async def consume_events():
1589+
async for event in runner_with_mock.run_live(
1590+
user_id=TEST_USER_ID,
1591+
session_id=TEST_SESSION_ID,
1592+
live_request_queue=live_queue,
1593+
metadata=test_metadata,
1594+
):
1595+
pass
1596+
1597+
try:
1598+
await asyncio.wait_for(consume_events(), timeout=2)
1599+
except asyncio.TimeoutError:
1600+
pass # Expected - live session may not terminate cleanly
1601+
1602+
# Verify MockModel received LlmRequest with correct metadata
1603+
assert len(mock_model.requests) > 0
1604+
assert mock_model.requests[0].metadata is not None
1605+
assert mock_model.requests[0].metadata["live_key"] == "live_value"
1606+
assert mock_model.requests[0].metadata["trace_id"] == "live_trace_123"
1607+
14871608

14881609
if __name__ == "__main__":
14891610
pytest.main([__file__])

0 commit comments

Comments
 (0)