|
19 | 19 | from google.adk.agents.llm_agent import Agent |
20 | 20 | from google.adk.events.event import Event |
21 | 21 | from google.adk.flows.llm_flows.functions import handle_function_calls_async |
| 22 | +from google.adk.flows.llm_flows.functions import handle_function_calls_live |
22 | 23 | from google.adk.plugins.base_plugin import BasePlugin |
23 | 24 | from google.adk.tools.base_tool import BaseTool |
24 | 25 | from google.adk.tools.function_tool import FunctionTool |
@@ -185,5 +186,159 @@ async def test_async_on_tool_error_fallback_to_runner( |
185 | 186 | assert e == mock_error |
186 | 187 |
|
187 | 188 |
|
| 189 | +async def invoke_tool_with_plugin_live( |
| 190 | + mock_tool, mock_plugin |
| 191 | +) -> Optional[Event]: |
| 192 | + """Invokes a tool with a plugin using the live path.""" |
| 193 | + model = testing_utils.MockModel.create(responses=[]) |
| 194 | + agent = Agent( |
| 195 | + name="agent", |
| 196 | + model=model, |
| 197 | + tools=[mock_tool], |
| 198 | + ) |
| 199 | + invocation_context = await testing_utils.create_invocation_context( |
| 200 | + agent=agent, user_content="", plugins=[mock_plugin] |
| 201 | + ) |
| 202 | + # Build function call event |
| 203 | + function_call = types.FunctionCall(name=mock_tool.name, args={}) |
| 204 | + content = types.Content(parts=[types.Part(function_call=function_call)]) |
| 205 | + event = Event( |
| 206 | + invocation_id=invocation_context.invocation_id, |
| 207 | + author=agent.name, |
| 208 | + content=content, |
| 209 | + ) |
| 210 | + tools_dict = {mock_tool.name: mock_tool} |
| 211 | + return await handle_function_calls_live( |
| 212 | + invocation_context, |
| 213 | + event, |
| 214 | + tools_dict, |
| 215 | + ) |
| 216 | + |
| 217 | + |
| 218 | +@pytest.mark.asyncio |
| 219 | +async def test_live_before_tool_callback(mock_tool, mock_plugin): |
| 220 | + mock_plugin.enable_before_tool_callback = True |
| 221 | + |
| 222 | + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) |
| 223 | + |
| 224 | + assert result_event is not None |
| 225 | + part = result_event.content.parts[0] |
| 226 | + assert part.function_response.response == mock_plugin.before_tool_response |
| 227 | + |
| 228 | + |
| 229 | +@pytest.mark.asyncio |
| 230 | +async def test_live_after_tool_callback(mock_tool, mock_plugin): |
| 231 | + mock_plugin.enable_after_tool_callback = True |
| 232 | + |
| 233 | + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) |
| 234 | + |
| 235 | + assert result_event is not None |
| 236 | + part = result_event.content.parts[0] |
| 237 | + assert part.function_response.response == mock_plugin.after_tool_response |
| 238 | + |
| 239 | + |
| 240 | +@pytest.mark.asyncio |
| 241 | +async def test_live_on_tool_error_use_plugin_response( |
| 242 | + mock_error_tool, mock_plugin |
| 243 | +): |
| 244 | + mock_plugin.enable_on_tool_error_callback = True |
| 245 | + |
| 246 | + result_event = await invoke_tool_with_plugin_live( |
| 247 | + mock_error_tool, mock_plugin |
| 248 | + ) |
| 249 | + |
| 250 | + assert result_event is not None |
| 251 | + part = result_event.content.parts[0] |
| 252 | + assert part.function_response.response == mock_plugin.on_tool_error_response |
| 253 | + |
| 254 | + |
| 255 | +@pytest.mark.asyncio |
| 256 | +async def test_live_on_tool_error_fallback_to_runner( |
| 257 | + mock_error_tool, mock_plugin |
| 258 | +): |
| 259 | + mock_plugin.enable_on_tool_error_callback = False |
| 260 | + |
| 261 | + try: |
| 262 | + await invoke_tool_with_plugin_live(mock_error_tool, mock_plugin) |
| 263 | + except Exception as e: |
| 264 | + assert e == mock_error |
| 265 | + |
| 266 | + |
| 267 | +@pytest.mark.asyncio |
| 268 | +async def test_live_plugin_before_tool_callback_takes_priority( |
| 269 | + mock_tool, mock_plugin |
| 270 | +): |
| 271 | + """Plugin before_tool_callback should run before agent canonical callbacks.""" |
| 272 | + mock_plugin.enable_before_tool_callback = True |
| 273 | + |
| 274 | + def agent_before_cb(tool, args, tool_context): |
| 275 | + return {"agent": "should_not_be_called"} |
| 276 | + |
| 277 | + model = testing_utils.MockModel.create(responses=[]) |
| 278 | + agent = Agent( |
| 279 | + name="agent", |
| 280 | + model=model, |
| 281 | + tools=[mock_tool], |
| 282 | + before_tool_callback=agent_before_cb, |
| 283 | + ) |
| 284 | + invocation_context = await testing_utils.create_invocation_context( |
| 285 | + agent=agent, user_content="", plugins=[mock_plugin] |
| 286 | + ) |
| 287 | + function_call = types.FunctionCall(name=mock_tool.name, args={}) |
| 288 | + content = types.Content(parts=[types.Part(function_call=function_call)]) |
| 289 | + event = Event( |
| 290 | + invocation_id=invocation_context.invocation_id, |
| 291 | + author=agent.name, |
| 292 | + content=content, |
| 293 | + ) |
| 294 | + tools_dict = {mock_tool.name: mock_tool} |
| 295 | + result_event = await handle_function_calls_live( |
| 296 | + invocation_context, event, tools_dict |
| 297 | + ) |
| 298 | + |
| 299 | + assert result_event is not None |
| 300 | + part = result_event.content.parts[0] |
| 301 | + # Plugin response should win, not the agent callback |
| 302 | + assert part.function_response.response == mock_plugin.before_tool_response |
| 303 | + |
| 304 | + |
| 305 | +@pytest.mark.asyncio |
| 306 | +async def test_live_plugin_after_tool_callback_takes_priority( |
| 307 | + mock_tool, mock_plugin |
| 308 | +): |
| 309 | + """Plugin after_tool_callback should run before agent canonical callbacks.""" |
| 310 | + mock_plugin.enable_after_tool_callback = True |
| 311 | + |
| 312 | + def agent_after_cb(tool, args, tool_context, tool_response): |
| 313 | + return {"agent": "should_not_be_called"} |
| 314 | + |
| 315 | + model = testing_utils.MockModel.create(responses=[]) |
| 316 | + agent = Agent( |
| 317 | + name="agent", |
| 318 | + model=model, |
| 319 | + tools=[mock_tool], |
| 320 | + after_tool_callback=agent_after_cb, |
| 321 | + ) |
| 322 | + invocation_context = await testing_utils.create_invocation_context( |
| 323 | + agent=agent, user_content="", plugins=[mock_plugin] |
| 324 | + ) |
| 325 | + function_call = types.FunctionCall(name=mock_tool.name, args={}) |
| 326 | + content = types.Content(parts=[types.Part(function_call=function_call)]) |
| 327 | + event = Event( |
| 328 | + invocation_id=invocation_context.invocation_id, |
| 329 | + author=agent.name, |
| 330 | + content=content, |
| 331 | + ) |
| 332 | + tools_dict = {mock_tool.name: mock_tool} |
| 333 | + result_event = await handle_function_calls_live( |
| 334 | + invocation_context, event, tools_dict |
| 335 | + ) |
| 336 | + |
| 337 | + assert result_event is not None |
| 338 | + part = result_event.content.parts[0] |
| 339 | + # Plugin response should win, not the agent callback |
| 340 | + assert part.function_response.response == mock_plugin.after_tool_response |
| 341 | + |
| 342 | + |
188 | 343 | if __name__ == "__main__": |
189 | 344 | pytest.main([__file__]) |
0 commit comments