Skip to content

Commit c40af85

Browse files
committed
feat(runner): add metadata parameter to Runner.run_async()
Add support for passing per-request metadata through the agent execution pipeline. This enables use cases like: - Passing user_id, trace_id, or session context to callbacks - Enabling memory injection in before_model_callback - Supporting request-specific context without using ContextVar workarounds Changes: - Add `metadata` field to LlmRequest model - Add `metadata` field to InvocationContext model - Add `metadata` parameter to Runner.run_async() and related methods - Propagate metadata from InvocationContext to LlmRequest in base_llm_flow - Add unit tests for metadata functionality Closes #2978
1 parent 42eeaef commit c40af85

File tree

5 files changed

+181
-3
lines changed

5 files changed

+181
-3
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ class InvocationContext(BaseModel):
206206
canonical_tools_cache: Optional[list[BaseTool]] = None
207207
"""The cache of canonical tools for this invocation."""
208208

209+
metadata: Optional[dict[str, Any]] = None
210+
"""Per-request metadata passed from Runner.run_async().
211+
212+
This field allows passing arbitrary metadata that can be accessed during
213+
the invocation lifecycle, particularly in callbacks like before_model_callback.
214+
Common use cases include passing user_id, trace_id, memory context keys, or
215+
other request-specific context that needs to be available during processing.
216+
"""
217+
209218
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
210219
default_factory=_InvocationCostManager
211220
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async def run_live(
131131
invocation_context: InvocationContext,
132132
) -> AsyncGenerator[Event, None]:
133133
"""Runs the flow using live api."""
134-
llm_request = LlmRequest()
134+
llm_request = LlmRequest(metadata=invocation_context.metadata)
135135
event_id = Event.new_id()
136136

137137
# Preprocess before calling the LLM.
@@ -437,7 +437,7 @@ async def _run_one_step_async(
437437
invocation_context: InvocationContext,
438438
) -> AsyncGenerator[Event, None]:
439439
"""One step means one LLM call."""
440-
llm_request = LlmRequest()
440+
llm_request = LlmRequest(metadata=invocation_context.metadata)
441441

442442
# Preprocess before calling the LLM.
443443
async with Aclosing(

src/google/adk/models/llm_request.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
from typing import Any
1819
from typing import Optional
1920
from typing import Union
2021

@@ -99,6 +100,15 @@ class LlmRequest(BaseModel):
99100
the full history.
100101
"""
101102

103+
metadata: Optional[dict[str, Any]] = None
104+
"""Per-request metadata for callbacks and custom processing.
105+
106+
This field allows passing arbitrary metadata from the Runner.run_async()
107+
call to callbacks like before_model_callback. This is useful for passing
108+
request-specific context such as user_id, trace_id, or memory context keys
109+
that need to be available during model invocation.
110+
"""
111+
102112
def append_instructions(
103113
self, instructions: Union[list[str], types.Content]
104114
) -> list[types.Content]:

src/google/adk/runners.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ async def run_async(
457457
new_message: Optional[types.Content] = None,
458458
state_delta: Optional[dict[str, Any]] = None,
459459
run_config: Optional[RunConfig] = None,
460+
metadata: Optional[dict[str, Any]] = None,
460461
) -> AsyncGenerator[Event, None]:
461462
"""Main entry method to run the agent in this runner.
462463
@@ -474,6 +475,9 @@ async def run_async(
474475
new_message: A new message to append to the session.
475476
state_delta: Optional state changes to apply to the session.
476477
run_config: The run config for the agent.
478+
metadata: Optional per-request metadata that will be passed to callbacks.
479+
This allows passing request-specific context such as user_id, trace_id,
480+
or memory context keys to before_model_callback and other callbacks.
477481
478482
Yields:
479483
The events generated by the agent.
@@ -483,13 +487,16 @@ async def run_async(
483487
new_message are None.
484488
"""
485489
run_config = run_config or RunConfig()
490+
# Create a shallow copy to isolate from caller's modifications
491+
metadata = metadata.copy() if metadata else None
486492

487493
if new_message and not new_message.role:
488494
new_message.role = 'user'
489495

490496
async def _run_with_trace(
491497
new_message: Optional[types.Content] = None,
492498
invocation_id: Optional[str] = None,
499+
metadata: Optional[dict[str, Any]] = None,
493500
) -> AsyncGenerator[Event, None]:
494501
with tracer.start_as_current_span('invocation'):
495502
session = await self._get_or_create_session(
@@ -517,6 +524,7 @@ async def _run_with_trace(
517524
invocation_id=invocation_id,
518525
run_config=run_config,
519526
state_delta=state_delta,
527+
metadata=metadata,
520528
)
521529
if invocation_context.end_of_agents.get(
522530
invocation_context.agent.name
@@ -530,6 +538,7 @@ async def _run_with_trace(
530538
new_message=new_message, # new_message is not None.
531539
run_config=run_config,
532540
state_delta=state_delta,
541+
metadata=metadata,
533542
)
534543

535544
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
@@ -556,7 +565,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
556565
self.app, session, self.session_service
557566
)
558567

559-
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
568+
async with Aclosing(
569+
_run_with_trace(new_message, invocation_id, metadata)
570+
) as agen:
560571
async for event in agen:
561572
yield event
562573

@@ -1212,6 +1223,7 @@ async def _setup_context_for_new_invocation(
12121223
new_message: types.Content,
12131224
run_config: RunConfig,
12141225
state_delta: Optional[dict[str, Any]],
1226+
metadata: Optional[dict[str, Any]] = None,
12151227
) -> InvocationContext:
12161228
"""Sets up the context for a new invocation.
12171229
@@ -1220,6 +1232,7 @@ async def _setup_context_for_new_invocation(
12201232
new_message: The new message to process and append to the session.
12211233
run_config: The run config of the agent.
12221234
state_delta: Optional state changes to apply to the session.
1235+
metadata: Optional per-request metadata to pass to callbacks.
12231236
12241237
Returns:
12251238
The invocation context for the new invocation.
@@ -1229,6 +1242,7 @@ async def _setup_context_for_new_invocation(
12291242
session,
12301243
new_message=new_message,
12311244
run_config=run_config,
1245+
metadata=metadata,
12321246
)
12331247
# Step 2: Handle new message, by running callbacks and appending to
12341248
# session.
@@ -1251,6 +1265,7 @@ async def _setup_context_for_resumed_invocation(
12511265
invocation_id: Optional[str],
12521266
run_config: RunConfig,
12531267
state_delta: Optional[dict[str, Any]],
1268+
metadata: Optional[dict[str, Any]] = None,
12541269
) -> InvocationContext:
12551270
"""Sets up the context for a resumed invocation.
12561271
@@ -1260,6 +1275,7 @@ async def _setup_context_for_resumed_invocation(
12601275
invocation_id: The invocation id to resume.
12611276
run_config: The run config of the agent.
12621277
state_delta: Optional state changes to apply to the session.
1278+
metadata: Optional per-request metadata to pass to callbacks.
12631279
12641280
Returns:
12651281
The invocation context for the resumed invocation.
@@ -1285,6 +1301,7 @@ async def _setup_context_for_resumed_invocation(
12851301
new_message=user_message,
12861302
run_config=run_config,
12871303
invocation_id=invocation_id,
1304+
metadata=metadata,
12881305
)
12891306
# Step 3: Maybe handle new message.
12901307
if new_message:
@@ -1329,6 +1346,7 @@ def _new_invocation_context(
13291346
new_message: Optional[types.Content] = None,
13301347
live_request_queue: Optional[LiveRequestQueue] = None,
13311348
run_config: Optional[RunConfig] = None,
1349+
metadata: Optional[dict[str, Any]] = None,
13321350
) -> InvocationContext:
13331351
"""Creates a new invocation context.
13341352
@@ -1338,6 +1356,7 @@ def _new_invocation_context(
13381356
new_message: The new message for the context.
13391357
live_request_queue: The live request queue for the context.
13401358
run_config: The run config for the context.
1359+
metadata: Optional per-request metadata for the context.
13411360
13421361
Returns:
13431362
The new invocation context.
@@ -1369,6 +1388,7 @@ def _new_invocation_context(
13691388
live_request_queue=live_request_queue,
13701389
run_config=run_config,
13711390
resumability_config=self.resumability_config,
1391+
metadata=metadata,
13721392
)
13731393

13741394
def _new_invocation_context_for_live(

tests/unittests/test_runners.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
3131
from google.adk.cli.utils.agent_loader import AgentLoader
3232
from google.adk.events.event import Event
33+
from google.adk.models.llm_request import LlmRequest
34+
from google.adk.models.llm_response import LlmResponse
3335
from google.adk.plugins.base_plugin import BasePlugin
3436
from google.adk.runners import Runner
3537
from google.adk.sessions.in_memory_session_service import InMemorySessionService
@@ -1237,5 +1239,142 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
12371239
assert "actual_name" in runner._app_name_alignment_hint
12381240

12391241

1242+
class TestRunnerMetadata:
1243+
"""Tests for Runner metadata parameter functionality."""
1244+
1245+
def setup_method(self):
1246+
"""Set up test fixtures."""
1247+
self.session_service = InMemorySessionService()
1248+
self.artifact_service = InMemoryArtifactService()
1249+
self.root_agent = MockLlmAgent("root_agent")
1250+
self.runner = Runner(
1251+
app_name="test_app",
1252+
agent=self.root_agent,
1253+
session_service=self.session_service,
1254+
artifact_service=self.artifact_service,
1255+
)
1256+
1257+
def test_new_invocation_context_with_metadata(self):
1258+
"""Test that _new_invocation_context correctly passes metadata."""
1259+
mock_session = Session(
1260+
id=TEST_SESSION_ID,
1261+
app_name=TEST_APP_ID,
1262+
user_id=TEST_USER_ID,
1263+
events=[],
1264+
)
1265+
1266+
test_metadata = {"user_id": "test123", "trace_id": "trace456"}
1267+
invocation_context = self.runner._new_invocation_context(
1268+
mock_session, metadata=test_metadata
1269+
)
1270+
1271+
assert invocation_context.metadata == test_metadata
1272+
assert invocation_context.metadata["user_id"] == "test123"
1273+
assert invocation_context.metadata["trace_id"] == "trace456"
1274+
1275+
def test_new_invocation_context_without_metadata(self):
1276+
"""Test that _new_invocation_context works without metadata."""
1277+
mock_session = Session(
1278+
id=TEST_SESSION_ID,
1279+
app_name=TEST_APP_ID,
1280+
user_id=TEST_USER_ID,
1281+
events=[],
1282+
)
1283+
1284+
invocation_context = self.runner._new_invocation_context(mock_session)
1285+
1286+
assert invocation_context.metadata is None
1287+
1288+
@pytest.mark.asyncio
1289+
async def test_run_async_passes_metadata_to_invocation_context(self):
1290+
"""Test that run_async correctly passes metadata to before_model_callback."""
1291+
# Capture metadata received in callback
1292+
captured_metadata = None
1293+
1294+
def before_model_callback(callback_context, llm_request):
1295+
nonlocal captured_metadata
1296+
captured_metadata = llm_request.metadata
1297+
# Return a response to skip actual LLM call
1298+
return LlmResponse(
1299+
content=types.Content(
1300+
role="model", parts=[types.Part(text="Test response")]
1301+
)
1302+
)
1303+
1304+
# Create agent with before_model_callback
1305+
agent_with_callback = LlmAgent(
1306+
name="callback_agent",
1307+
model="gemini-2.0-flash",
1308+
before_model_callback=before_model_callback,
1309+
)
1310+
1311+
runner_with_callback = Runner(
1312+
app_name="test_app",
1313+
agent=agent_with_callback,
1314+
session_service=self.session_service,
1315+
artifact_service=self.artifact_service,
1316+
)
1317+
1318+
session = await self.session_service.create_session(
1319+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1320+
)
1321+
1322+
test_metadata = {"experiment_id": "exp-001", "variant": "B"}
1323+
1324+
async for event in runner_with_callback.run_async(
1325+
user_id=TEST_USER_ID,
1326+
session_id=TEST_SESSION_ID,
1327+
new_message=types.Content(
1328+
role="user", parts=[types.Part(text="Hello")]
1329+
),
1330+
metadata=test_metadata,
1331+
):
1332+
pass
1333+
1334+
# Verify metadata was passed to before_model_callback
1335+
assert captured_metadata is not None
1336+
assert captured_metadata == test_metadata
1337+
assert captured_metadata["experiment_id"] == "exp-001"
1338+
assert captured_metadata["variant"] == "B"
1339+
1340+
def test_metadata_field_in_invocation_context(self):
1341+
"""Test that InvocationContext model accepts metadata field."""
1342+
mock_session = Session(
1343+
id=TEST_SESSION_ID,
1344+
app_name=TEST_APP_ID,
1345+
user_id=TEST_USER_ID,
1346+
events=[],
1347+
)
1348+
1349+
test_metadata = {"key1": "value1", "key2": 123}
1350+
1351+
# This should not raise a validation error
1352+
invocation_context = InvocationContext(
1353+
session_service=self.session_service,
1354+
invocation_id="test_inv_id",
1355+
agent=self.root_agent,
1356+
session=mock_session,
1357+
metadata=test_metadata,
1358+
)
1359+
1360+
assert invocation_context.metadata == test_metadata
1361+
1362+
def test_metadata_field_in_llm_request(self):
1363+
"""Test that LlmRequest model accepts metadata field."""
1364+
test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}}
1365+
1366+
llm_request = LlmRequest(metadata=test_metadata)
1367+
1368+
assert llm_request.metadata == test_metadata
1369+
assert llm_request.metadata["context_key"] == "ctx123"
1370+
assert llm_request.metadata["user_info"]["name"] == "test"
1371+
1372+
def test_llm_request_without_metadata(self):
1373+
"""Test that LlmRequest works without metadata."""
1374+
llm_request = LlmRequest()
1375+
1376+
assert llm_request.metadata is None
1377+
1378+
12401379
if __name__ == "__main__":
12411380
pytest.main([__file__])

0 commit comments

Comments
 (0)