Skip to content

Commit 268d019

Browse files
committed
feat: add ephemeral request_state for secure token handling
Add a non-persisted request_state dict to InvocationContext that is threaded through Runner.run_async and the AdkWebServer. ReadonlyContext.state now returns a ChainMap merging request_state over session.state, so ephemeral keys (e.g. tokens) take precedence without being persisted.
1 parent 218ea76 commit 268d019

4 files changed

Lines changed: 59 additions & 11 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,14 @@ class InvocationContext(BaseModel):
171171
agent_states: dict[str, dict[str, Any]] = Field(default_factory=dict)
172172
"""The state of the agent for this invocation."""
173173

174+
request_state: dict[str, Any] = Field(default_factory=dict)
175+
"""The ephemeral state of the request.
176+
177+
This state is not persisted to the session and is only available for the
178+
current invocation. It is used to pass sensitive information like tokens
179+
that should not be stored in the session state.
180+
"""
181+
174182
end_of_agents: dict[str, bool] = Field(default_factory=dict)
175183
"""The end of agent status for each agent in this invocation."""
176184

src/google/adk/agents/readonly_context.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from collections import ChainMap
1718
from types import MappingProxyType
1819
from typing import Any
1920
from typing import Optional
@@ -53,8 +54,19 @@ def agent_name(self) -> str:
5354

5455
@property
5556
def state(self) -> MappingProxyType[str, Any]:
56-
"""The state of the current session. READONLY field."""
57-
return MappingProxyType(self._invocation_context.session.state)
57+
"""The state of the current session. READONLY field.
58+
59+
Note: This property returns a merged view of ephemeral request_state and
60+
persistent session.state using ChainMap. Changes to the underlying
61+
request_state or session.state dictionaries will be reflected through
62+
this view, but direct writes through this property are prevented.
63+
"""
64+
return MappingProxyType(
65+
ChainMap(
66+
self._invocation_context.request_state,
67+
self._invocation_context.session.state,
68+
)
69+
)
5870

5971
@property
6072
def session(self) -> Session:

src/google/adk/cli/adk_web_server.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ class RunAgentRequest(common.BaseModel):
369369
state_delta: Optional[dict[str, Any]] = None
370370
# for long-running function resume requests (e.g., OAuth callback)
371371
function_call_event_id: Optional[str] = None
372+
request_state: Optional[dict[str, Any]] = None
372373
# for resume long-running functions
373374
invocation_id: Optional[str] = None
374375

@@ -989,9 +990,7 @@ async def version() -> dict[str, str]:
989990
return {
990991
"version": __version__,
991992
"language": "python",
992-
"language_version": (
993-
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
994-
),
993+
"language_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
995994
}
996995

997996
@app.get("/list-apps")
@@ -1900,6 +1899,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
19001899
new_message=req.new_message,
19011900
state_delta=req.state_delta,
19021901
invocation_id=req.invocation_id,
1902+
request_state=req.request_state,
19031903
)
19041904
) as agen:
19051905
events = [event async for event in agen]
@@ -1944,6 +1944,7 @@ async def event_generator():
19441944
state_delta=req.state_delta,
19451945
run_config=RunConfig(streaming_mode=stream_mode),
19461946
invocation_id=req.invocation_id,
1947+
request_state=req.request_state,
19471948
)
19481949
) as agen:
19491950
try:

src/google/adk/runners.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ async def run_async(
507507
invocation_id: Optional[str] = None,
508508
new_message: Optional[types.Content] = None,
509509
state_delta: Optional[dict[str, Any]] = None,
510+
request_state: Optional[dict[str, Any]] = None,
510511
run_config: Optional[RunConfig] = None,
511512
) -> AsyncGenerator[Event, None]:
512513
"""Main entry method to run the agent in this runner.
@@ -524,6 +525,7 @@ async def run_async(
524525
interrupted invocation.
525526
new_message: A new message to append to the session.
526527
state_delta: Optional state changes to apply to the session.
528+
request_state: Optional ephemeral state for the request.
527529
run_config: The run config for the agent.
528530
529531
Yields:
@@ -559,18 +561,32 @@ async def _run_with_trace(
559561
is_resumable = (
560562
self.resumability_config and self.resumability_config.is_resumable
561563
)
562-
if not is_resumable and not new_message:
563-
raise ValueError(
564-
'Running an agent requires a new_message or a resumable app. '
565-
f'Session: {session_id}, User: {user_id}'
564+
if invocation_id:
565+
if not is_resumable:
566+
raise ValueError(
567+
f'invocation_id: {invocation_id} is provided but the app is not'
568+
' resumable.'
569+
)
570+
invocation_context = await self._setup_context_for_resumed_invocation(
571+
session=session,
572+
new_message=new_message,
573+
invocation_id=invocation_id,
574+
run_config=run_config,
575+
state_delta=state_delta,
576+
request_state=request_state,
566577
)
567-
568-
if not is_resumable:
578+
elif not is_resumable:
579+
if not new_message:
580+
raise ValueError(
581+
'Running an agent requires a new_message or a resumable app. '
582+
f'Session: {session_id}, User: {user_id}'
583+
)
569584
invocation_context = await self._setup_context_for_new_invocation(
570585
session=session,
571586
new_message=new_message,
572587
run_config=run_config,
573588
state_delta=state_delta,
589+
request_state=request_state,
574590
)
575591
else:
576592
invocation_id = self._resolve_invocation_id(
@@ -582,6 +598,7 @@ async def _run_with_trace(
582598
new_message=new_message,
583599
run_config=run_config,
584600
state_delta=state_delta,
601+
request_state=request_state,
585602
)
586603
else:
587604
invocation_context = (
@@ -591,6 +608,7 @@ async def _run_with_trace(
591608
invocation_id=invocation_id,
592609
run_config=run_config,
593610
state_delta=state_delta,
611+
request_state=request_state,
594612
)
595613
)
596614
if invocation_context.end_of_agents.get(
@@ -1334,6 +1352,7 @@ async def _setup_context_for_new_invocation(
13341352
new_message: types.Content,
13351353
run_config: RunConfig,
13361354
state_delta: Optional[dict[str, Any]],
1355+
request_state: Optional[dict[str, Any]] = None,
13371356
) -> InvocationContext:
13381357
"""Sets up the context for a new invocation.
13391358
@@ -1342,6 +1361,7 @@ async def _setup_context_for_new_invocation(
13421361
new_message: The new message to process and append to the session.
13431362
run_config: The run config of the agent.
13441363
state_delta: Optional state changes to apply to the session.
1364+
request_state: Optional ephemeral state for the request.
13451365
13461366
Returns:
13471367
The invocation context for the new invocation.
@@ -1351,6 +1371,7 @@ async def _setup_context_for_new_invocation(
13511371
session,
13521372
new_message=new_message,
13531373
run_config=run_config,
1374+
request_state=request_state,
13541375
)
13551376
# Step 2: Handle new message, by running callbacks and appending to
13561377
# session.
@@ -1373,6 +1394,7 @@ async def _setup_context_for_resumed_invocation(
13731394
invocation_id: Optional[str],
13741395
run_config: RunConfig,
13751396
state_delta: Optional[dict[str, Any]],
1397+
request_state: Optional[dict[str, Any]] = None,
13761398
) -> InvocationContext:
13771399
"""Sets up the context for a resumed invocation.
13781400
@@ -1382,6 +1404,7 @@ async def _setup_context_for_resumed_invocation(
13821404
invocation_id: The invocation id to resume.
13831405
run_config: The run config of the agent.
13841406
state_delta: Optional state changes to apply to the session.
1407+
request_state: Optional ephemeral state for the request.
13851408
13861409
Returns:
13871410
The invocation context for the resumed invocation.
@@ -1407,6 +1430,7 @@ async def _setup_context_for_resumed_invocation(
14071430
new_message=user_message,
14081431
run_config=run_config,
14091432
invocation_id=invocation_id,
1433+
request_state=request_state,
14101434
)
14111435
# Step 3: Maybe handle new message.
14121436
if new_message:
@@ -1455,6 +1479,7 @@ def _new_invocation_context(
14551479
new_message: Optional[types.Content] = None,
14561480
live_request_queue: Optional[LiveRequestQueue] = None,
14571481
run_config: Optional[RunConfig] = None,
1482+
request_state: Optional[dict[str, Any]] = None,
14581483
) -> InvocationContext:
14591484
"""Creates a new invocation context.
14601485
@@ -1464,6 +1489,7 @@ def _new_invocation_context(
14641489
new_message: The new message for the context.
14651490
live_request_queue: The live request queue for the context.
14661491
run_config: The run config for the context.
1492+
request_state: The ephemeral state for the request.
14671493
14681494
Returns:
14691495
The new invocation context.
@@ -1498,6 +1524,7 @@ def _new_invocation_context(
14981524
live_request_queue=live_request_queue,
14991525
run_config=run_config,
15001526
resumability_config=self.resumability_config,
1527+
request_state=request_state if request_state is not None else {},
15011528
)
15021529

15031530
def _new_invocation_context_for_live(

0 commit comments

Comments
 (0)