|
15 | 15 | from typing import Any |
16 | 16 | from typing import Optional |
17 | 17 |
|
| 18 | +from google.adk.agents.base_agent import BaseAgent |
18 | 19 | from google.adk.agents.callback_context import CallbackContext |
19 | 20 | from google.adk.agents.invocation_context import InvocationContext |
20 | 21 | from google.adk.agents.llm_agent import Agent |
21 | 22 | from google.adk.agents.run_config import RunConfig |
22 | 23 | from google.adk.agents.sequential_agent import SequentialAgent |
23 | 24 | from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService |
| 25 | +from google.adk.events.event import Event |
24 | 26 | from google.adk.features import FeatureName |
25 | 27 | from google.adk.features._feature_registry import temporary_feature_override |
26 | 28 | from google.adk.memory.in_memory_memory_service import InMemoryMemoryService |
@@ -985,6 +987,109 @@ async def test_run_async_handles_none_parts_in_response(): |
985 | 987 | assert tool_result == '' |
986 | 988 |
|
987 | 989 |
|
| 990 | +async def _run_agent_tool_with_parts(parts: list[types.Part]) -> Any: |
| 991 | + """Drives AgentTool with an inner agent whose final event content is `parts`.""" |
| 992 | + |
| 993 | + class _StaticAgent(BaseAgent): |
| 994 | + |
| 995 | + async def _run_async_impl(self, ctx): |
| 996 | + yield Event( |
| 997 | + invocation_id=ctx.invocation_id, |
| 998 | + author=self.name, |
| 999 | + content=types.Content(role='model', parts=parts), |
| 1000 | + ) |
| 1001 | + |
| 1002 | + inner = _StaticAgent(name='inner_agent', description='static') |
| 1003 | + agent_tool = AgentTool(agent=inner) |
| 1004 | + |
| 1005 | + session_service = InMemorySessionService() |
| 1006 | + session = await session_service.create_session( |
| 1007 | + app_name='test_app', user_id='test_user' |
| 1008 | + ) |
| 1009 | + invocation_context = InvocationContext( |
| 1010 | + invocation_id='invocation_id', |
| 1011 | + agent=inner, |
| 1012 | + session=session, |
| 1013 | + session_service=session_service, |
| 1014 | + ) |
| 1015 | + tool_context = ToolContext(invocation_context=invocation_context) |
| 1016 | + |
| 1017 | + return await agent_tool.run_async( |
| 1018 | + args={'request': 'test request'}, tool_context=tool_context |
| 1019 | + ) |
| 1020 | + |
| 1021 | + |
| 1022 | +@mark.asyncio |
| 1023 | +async def test_run_async_extracts_text_only(): |
| 1024 | + """Plain text parts pass through unchanged.""" |
| 1025 | + result = await _run_agent_tool_with_parts([types.Part(text='hello world')]) |
| 1026 | + assert result == 'hello world' |
| 1027 | + |
| 1028 | + |
| 1029 | +@mark.asyncio |
| 1030 | +async def test_run_async_extracts_code_execution_result_only(): |
| 1031 | + """code_execution_result.output and executable_code.code are returned.""" |
| 1032 | + result = await _run_agent_tool_with_parts([ |
| 1033 | + types.Part( |
| 1034 | + executable_code=types.ExecutableCode( |
| 1035 | + language=types.Language.PYTHON, code='print(2 ** 10)' |
| 1036 | + ) |
| 1037 | + ), |
| 1038 | + types.Part( |
| 1039 | + code_execution_result=types.CodeExecutionResult( |
| 1040 | + outcome=types.Outcome.OUTCOME_OK, output='1024\n' |
| 1041 | + ) |
| 1042 | + ), |
| 1043 | + ]) |
| 1044 | + assert result == 'print(2 ** 10)\n1024' |
| 1045 | + |
| 1046 | + |
| 1047 | +@mark.asyncio |
| 1048 | +async def test_run_async_extracts_text_and_code_execution_result(): |
| 1049 | + """Mixed text + code parts are concatenated in order.""" |
| 1050 | + result = await _run_agent_tool_with_parts([ |
| 1051 | + types.Part(text='Here is the answer:'), |
| 1052 | + types.Part( |
| 1053 | + executable_code=types.ExecutableCode( |
| 1054 | + language=types.Language.PYTHON, code='print(2 ** 10)' |
| 1055 | + ) |
| 1056 | + ), |
| 1057 | + types.Part( |
| 1058 | + code_execution_result=types.CodeExecutionResult( |
| 1059 | + outcome=types.Outcome.OUTCOME_OK, output='1024\n' |
| 1060 | + ) |
| 1061 | + ), |
| 1062 | + ]) |
| 1063 | + assert result == 'Here is the answer:\nprint(2 ** 10)\n1024' |
| 1064 | + |
| 1065 | + |
| 1066 | +@mark.asyncio |
| 1067 | +async def test_run_async_extracts_executable_code_only(): |
| 1068 | + """executable_code.code alone is returned when no result part follows.""" |
| 1069 | + result = await _run_agent_tool_with_parts([ |
| 1070 | + types.Part( |
| 1071 | + executable_code=types.ExecutableCode( |
| 1072 | + language=types.Language.PYTHON, code='print("hi")' |
| 1073 | + ) |
| 1074 | + ), |
| 1075 | + ]) |
| 1076 | + assert result == 'print("hi")' |
| 1077 | + |
| 1078 | + |
| 1079 | +@mark.asyncio |
| 1080 | +async def test_run_async_skips_thought_parts(): |
| 1081 | + """Parts marked thought=True are dropped regardless of kind.""" |
| 1082 | + result = await _run_agent_tool_with_parts([ |
| 1083 | + types.Part(text='thinking out loud', thought=True), |
| 1084 | + types.Part( |
| 1085 | + code_execution_result=types.CodeExecutionResult( |
| 1086 | + outcome=types.Outcome.OUTCOME_OK, output='42\n' |
| 1087 | + ) |
| 1088 | + ), |
| 1089 | + ]) |
| 1090 | + assert result == '42' |
| 1091 | + |
| 1092 | + |
988 | 1093 | class TestAgentToolWithCompositeAgents: |
989 | 1094 | """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.).""" |
990 | 1095 |
|
|
0 commit comments