Skip to content

Commit e0ffe85

Browse files
committed
test: add tests for request_state and header propagation
Add comprehensive tests for request_state merging in ReadonlyContext, header provider creation, state_header_mapping config, credential_key shorthand, RFC 7230 header validation, CRLF injection prevention, and strict mode. Update existing mocks to include request_state.
1 parent 097a99c commit e0ffe85

6 files changed

Lines changed: 914 additions & 4 deletions

File tree

src/google/adk/cli/adk_web_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,9 @@ async def version() -> dict[str, str]:
990990
return {
991991
"version": __version__,
992992
"language": "python",
993-
"language_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
993+
"language_version": (
994+
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
995+
),
994996
}
995997

996998
@app.get("/list-apps")

tests/unittests/agents/test_mcp_instruction_provider.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ async def test_call_success_no_args(self):
6868

6969
mock_invocation_context = MagicMock()
7070
mock_invocation_context.session.state = {}
71+
mock_invocation_context.request_state = {}
7172
context = ReadonlyContext(mock_invocation_context)
7273

7374
# Call
@@ -98,6 +99,7 @@ async def test_call_success_with_args(self):
9899

99100
mock_invocation_context = MagicMock()
100101
mock_invocation_context.session.state = {"arg1": "value1", "arg2": "value2"}
102+
mock_invocation_context.request_state = {}
101103
context = ReadonlyContext(mock_invocation_context)
102104

103105
instruction = await self.provider(context)
@@ -119,6 +121,7 @@ async def test_call_prompt_not_found_in_list_prompts(self):
119121

120122
mock_invocation_context = MagicMock()
121123
mock_invocation_context.session.state = {"arg1": "value1"}
124+
mock_invocation_context.request_state = {}
122125
context = ReadonlyContext(mock_invocation_context)
123126

124127
instruction = await self.provider(context)
@@ -137,6 +140,7 @@ async def test_call_get_prompt_returns_no_messages(self):
137140

138141
mock_invocation_context = MagicMock()
139142
mock_invocation_context.session.state = {}
143+
mock_invocation_context.request_state = {}
140144
context = ReadonlyContext(mock_invocation_context)
141145

142146
# Call and assert
@@ -179,6 +183,7 @@ async def test_call_ignore_non_text_messages(self):
179183

180184
mock_invocation_context = MagicMock()
181185
mock_invocation_context.session.state = {}
186+
mock_invocation_context.request_state = {}
182187
context = ReadonlyContext(mock_invocation_context)
183188

184189
# Call

tests/unittests/agents/test_readonly_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def mock_invocation_context():
2525
mock_context.invocation_id = "test-invocation-id"
2626
mock_context.agent.name = "test-agent-name"
2727
mock_context.session.state = {"key1": "value1", "key2": "value2"}
28+
mock_context.request_state = {}
2829
mock_context.user_id = "test-user-id"
2930
return mock_context
3031

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from collections import ChainMap
2+
import unittest
3+
from unittest.mock import MagicMock
4+
5+
from google.adk.agents.invocation_context import InvocationContext
6+
from google.adk.agents.readonly_context import ReadonlyContext
7+
from google.adk.sessions.session import Session
8+
9+
10+
class TestReadonlyContextState(unittest.TestCase):
11+
12+
def test_state_merging_precedence(self):
13+
# Setup
14+
mock_session = MagicMock(spec=Session)
15+
mock_session.state = {
16+
"persistent_key": "persistent_value",
17+
"conflict_key": "persistent_value",
18+
}
19+
20+
mock_invocation_context = MagicMock(spec=InvocationContext)
21+
mock_invocation_context.session = mock_session
22+
mock_invocation_context.request_state = {
23+
"ephemeral_key": "ephemeral_value",
24+
"conflict_key": "ephemeral_value",
25+
}
26+
27+
readonly_context = ReadonlyContext(mock_invocation_context)
28+
29+
# Verify
30+
state = readonly_context.state
31+
32+
# Check that ephemeral keys are present
33+
self.assertEqual(state["ephemeral_key"], "ephemeral_value")
34+
35+
# Check that persistent keys are present
36+
self.assertEqual(state["persistent_key"], "persistent_value")
37+
38+
# Check that ephemeral keys override persistent keys
39+
self.assertEqual(state["conflict_key"], "ephemeral_value")
40+
41+
# Verify it behaves like a mapping
42+
self.assertIn("ephemeral_key", state)
43+
self.assertIn("persistent_key", state)
44+
self.assertEqual(state.get("ephemeral_key"), "ephemeral_value")
45+
46+
def test_state_merging_empty_request_state(self):
47+
# Setup
48+
mock_session = MagicMock(spec=Session)
49+
mock_session.state = {"persistent_key": "persistent_value"}
50+
51+
mock_invocation_context = MagicMock(spec=InvocationContext)
52+
mock_invocation_context.session = mock_session
53+
mock_invocation_context.request_state = {}
54+
55+
readonly_context = ReadonlyContext(mock_invocation_context)
56+
57+
# Verify
58+
state = readonly_context.state
59+
self.assertEqual(state["persistent_key"], "persistent_value")
60+
self.assertNotIn("ephemeral_key", state)
61+
62+
def test_state_immutability(self):
63+
# Setup
64+
mock_session = MagicMock(spec=Session)
65+
mock_session.state = {"key": "value"}
66+
67+
mock_invocation_context = MagicMock(spec=InvocationContext)
68+
mock_invocation_context.session = mock_session
69+
mock_invocation_context.request_state = {}
70+
71+
readonly_context = ReadonlyContext(mock_invocation_context)
72+
state = readonly_context.state
73+
74+
# Verify it raises TypeError on assignment
75+
with self.assertRaises(TypeError):
76+
state["key"] = "new_value"
77+
78+
with self.assertRaises(TypeError):
79+
state["new_key"] = "value"
80+
81+
82+
if __name__ == "__main__":
83+
unittest.main()

tests/unittests/cli/test_fast_api.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ async def dummy_run_async(
128128
session_id,
129129
new_message,
130130
state_delta=None,
131+
request_state=None,
131132
run_config: Optional[RunConfig] = None,
132133
invocation_id: Optional[str] = None,
133134
):
@@ -1356,9 +1357,18 @@ async def run_async_capture(
13561357
invocation_id: Optional[str] = None,
13571358
new_message: Optional[types.Content] = None,
13581359
state_delta: Optional[dict[str, Any]] = None,
1360+
request_state: Optional[dict[str, Any]] = None,
13591361
run_config: Optional[RunConfig] = None,
13601362
):
1361-
del self, user_id, session_id, new_message, state_delta, run_config
1363+
del (
1364+
self,
1365+
user_id,
1366+
session_id,
1367+
new_message,
1368+
state_delta,
1369+
request_state,
1370+
run_config,
1371+
)
13621372
captured_invocation_id["invocation_id"] = invocation_id
13631373
yield _event_1()
13641374

@@ -1393,9 +1403,18 @@ async def run_async_with_artifact_delta(
13931403
invocation_id: Optional[str] = None,
13941404
new_message: Optional[types.Content] = None,
13951405
state_delta: Optional[dict[str, Any]] = None,
1406+
request_state: Optional[dict[str, Any]] = None,
13961407
run_config: Optional[RunConfig] = None,
13971408
):
1398-
del user_id, session_id, invocation_id, new_message, state_delta, run_config
1409+
del (
1410+
user_id,
1411+
session_id,
1412+
invocation_id,
1413+
new_message,
1414+
state_delta,
1415+
request_state,
1416+
run_config,
1417+
)
13991418
yield Event(
14001419
author="dummy agent",
14011420
invocation_id="invocation_id",
@@ -1449,9 +1468,18 @@ async def run_async_with_artifact_delta(
14491468
invocation_id: Optional[str] = None,
14501469
new_message: Optional[types.Content] = None,
14511470
state_delta: Optional[dict[str, Any]] = None,
1471+
request_state: Optional[dict[str, Any]] = None,
14521472
run_config: Optional[RunConfig] = None,
14531473
):
1454-
del user_id, session_id, invocation_id, new_message, state_delta, run_config
1474+
del (
1475+
user_id,
1476+
session_id,
1477+
invocation_id,
1478+
new_message,
1479+
state_delta,
1480+
request_state,
1481+
run_config,
1482+
)
14551483
yield Event(
14561484
author="dummy agent",
14571485
invocation_id="invocation_id",

0 commit comments

Comments
 (0)