|
| 1 | +""" |
| 2 | +前端 Demo Tool:ADK 侧注册为 FunctionTool,实际执行由前端完成,返回值在同轮 run 内回传。 |
| 3 | +
|
| 4 | +- 本 tool 挂在 demo_frontend_agent 的 tools 里;flow 在需要验证的节点调用该 agent 即可。 |
| 5 | +- 前端收到 function_call 后展示卡片,用户操作后通过 API 调用 set_frontend_tool_result |
| 6 | + 传入 (session_id, invocation_id, function_call_id, result),本侧 await 即返回并继续。 |
| 7 | +""" |
| 8 | + |
| 9 | +import asyncio |
| 10 | +import logging |
| 11 | +from typing import Dict |
| 12 | + |
| 13 | +from google.adk.tools import ToolContext |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | +# (session_id, invocation_id, function_call_id) -> asyncio.Future[dict] |
| 18 | +_pending: Dict[tuple, asyncio.Future] = {} |
| 19 | +_DEFAULT_TIMEOUT = 60.0 |
| 20 | + |
| 21 | + |
| 22 | +def _key(session_id: str, invocation_id: str, function_call_id: str) -> tuple: |
| 23 | + return (session_id, invocation_id, function_call_id) |
| 24 | + |
| 25 | + |
| 26 | +async def wait_for_frontend_tool_result( |
| 27 | + session_id: str, |
| 28 | + invocation_id: str, |
| 29 | + function_call_id: str, |
| 30 | + timeout: float = _DEFAULT_TIMEOUT, |
| 31 | +) -> dict: |
| 32 | + """ |
| 33 | + 在同一轮 run 内阻塞等待前端回传的 tool 结果。 |
| 34 | + 前端通过 API 调用 set_frontend_tool_result(session_id, invocation_id, function_call_id, result) 后,本函数返回 result。 |
| 35 | + """ |
| 36 | + k = _key(session_id, invocation_id, function_call_id) |
| 37 | + if k not in _pending: |
| 38 | + fut: asyncio.Future = asyncio.get_event_loop().create_future() |
| 39 | + _pending[k] = fut |
| 40 | + try: |
| 41 | + return await asyncio.wait_for(_pending[k], timeout=timeout) |
| 42 | + finally: |
| 43 | + _pending.pop(k, None) |
| 44 | + |
| 45 | + |
| 46 | +def set_frontend_tool_result( |
| 47 | + session_id: str, |
| 48 | + invocation_id: str, |
| 49 | + function_call_id: str, |
| 50 | + result: dict, |
| 51 | +) -> None: |
| 52 | + """ |
| 53 | + 由运行时 API 在收到前端提交的 tool 结果时调用,用于解除 wait_for_frontend_tool_result 的阻塞。 |
| 54 | + """ |
| 55 | + k = _key(session_id, invocation_id, function_call_id) |
| 56 | + fut = _pending.get(k) |
| 57 | + if fut is not None and not fut.done(): |
| 58 | + fut.set_result(result) |
| 59 | + logger.info( |
| 60 | + 'demo_frontend_tool result set for session_id=%s invocation_id=%s function_call_id=%s', |
| 61 | + session_id, |
| 62 | + invocation_id, |
| 63 | + function_call_id, |
| 64 | + ) |
| 65 | + else: |
| 66 | + logger.warning( |
| 67 | + 'demo_frontend_tool no pending future for session_id=%s invocation_id=%s function_call_id=%s', |
| 68 | + session_id, |
| 69 | + invocation_id, |
| 70 | + function_call_id, |
| 71 | + ) |
| 72 | + |
| 73 | + |
| 74 | +async def demo_frontend_tool( |
| 75 | + message: str, |
| 76 | + title: str, |
| 77 | + tool_context: ToolContext, |
| 78 | +) -> dict: |
| 79 | + """ |
| 80 | + Demo frontend tool for verifying ADK can call a client-side tool. |
| 81 | + 实际执行由前端完成;本函数在此轮 run 内等待前端回传结果后返回。 |
| 82 | +
|
| 83 | + Args: |
| 84 | + message: 展示给用户的文案,与前端约定一致。 |
| 85 | + title: 卡片标题,与前端约定一致。 |
| 86 | +
|
| 87 | + Returns: |
| 88 | + 前端回传的 dict,例如 {"confirmed": True, "value": "..."}。 |
| 89 | + """ |
| 90 | + session_id = tool_context.session.id |
| 91 | + invocation_id = getattr(tool_context, 'invocation_id', None) or '' |
| 92 | + function_call_id = getattr(tool_context, 'function_call_id', None) or '' |
| 93 | + return await wait_for_frontend_tool_result( |
| 94 | + session_id, invocation_id, function_call_id |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +# 供 ADK 自动从函数签名生成 schema 的 docstring(参数与前端一致) |
| 99 | +demo_frontend_tool.__doc__ = ( |
| 100 | + 'Demo frontend tool for verifying ADK can call a client-side tool. ' |
| 101 | + 'Renders a card on the frontend (message, title); execution and return value are provided by the frontend.' |
| 102 | +) |
0 commit comments