Skip to content

Commit 855d022

Browse files
committed
feat: support parallel web search tool
1 parent ccd59d1 commit 855d022

File tree

2 files changed

+123
-3
lines changed

2 files changed

+123
-3
lines changed

veadk/agents/supervise_agent.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,30 @@
1414

1515
from google.adk.models.llm_request import LlmRequest
1616
from jinja2 import Template
17+
from pydantic import BaseModel
1718

1819
from veadk import Agent, Runner
1920
from veadk.utils.logger import get_logger
2021

2122
logger = get_logger(__name__)
2223

24+
25+
class Advice(BaseModel):
26+
advice: str
27+
"""The advice to the worker agent. Should be empty if the history execution is correct."""
28+
29+
reason: str
30+
"""The reason for the advice"""
31+
32+
2333
instruction = Template("""You are a supervisor of an agent system. The system prompt of worker agent is:
2434
2535
```system prompt
2636
{{ system_prompt }}
2737
```
2838
29-
30-
3139
You should guide the agent to finish task and must output a JSON-format string with specific advice and reason:
32-
40+
3341
- If you think the history execution is not correct, you should give your advice to the worker agent: {"advice": "Your advice here", "reason": "Your reason here"}.
3442
- If you think the history execution is correct, you should output an empty string: {"advice": "", "reason": "Your reason here"}.
3543
""")
@@ -41,6 +49,7 @@ def build_supervisor(supervised_agent: Agent) -> Agent:
4149
name="supervisor",
4250
description="A supervisor for agent execution",
4351
instruction=custom_instruction,
52+
model_extra_config={"response_format": Advice},
4453
)
4554

4655
return agent
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
The document of this tool see: https://www.volcengine.com/docs/85508/1650263
17+
"""
18+
19+
import asyncio
20+
import os
21+
22+
from google.adk.tools import ToolContext
23+
24+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
25+
from veadk.utils.logger import get_logger
26+
from veadk.utils.volcengine_sign import ve_request
27+
28+
logger = get_logger(__name__)
29+
30+
31+
def do_search(query: str, ak: str, sk: str, session_token: str) -> list[str]:
32+
response = ve_request(
33+
request_body={
34+
"Query": query,
35+
"SearchType": "web",
36+
"Count": 5,
37+
"NeedSummary": True,
38+
},
39+
action="WebSearch",
40+
ak=ak,
41+
sk=sk,
42+
service="volc_torchlight_api",
43+
version="2025-01-01",
44+
region="cn-beijing",
45+
host="mercury.volcengineapi.com",
46+
header={"X-Security-Token": session_token},
47+
)
48+
49+
try:
50+
results: list = response["Result"]["WebResults"]
51+
final_results = []
52+
for result in results:
53+
final_results.append(result["Summary"].strip())
54+
return final_results
55+
except Exception as e:
56+
logger.error(f"Web search failed {e}, response body: {response}")
57+
return [response]
58+
59+
60+
async def parallel_web_search(
61+
queries: list[str], tool_context: ToolContext | None = None
62+
) -> dict[str, list[str]]:
63+
"""Search queries from websites in parallel.
64+
65+
Args:
66+
queries: The queries to search. Each query will be searched in parallel.
67+
68+
Returns:
69+
A dict of query to result documents.
70+
"""
71+
ak = None
72+
sk = None
73+
# First try to get tool-specific AK/SK
74+
ak = os.getenv("TOOL_WEB_SEARCH_ACCESS_KEY")
75+
sk = os.getenv("TOOL_WEB_SEARCH_SECRET_KEY")
76+
if ak and sk:
77+
logger.debug("Successfully get tool-specific AK/SK.")
78+
elif tool_context:
79+
ak = tool_context.state.get("VOLCENGINE_ACCESS_KEY")
80+
sk = tool_context.state.get("VOLCENGINE_SECRET_KEY")
81+
session_token = ""
82+
83+
if not (ak and sk):
84+
logger.debug("Get AK/SK from tool context failed.")
85+
ak = os.getenv("VOLCENGINE_ACCESS_KEY")
86+
sk = os.getenv("VOLCENGINE_SECRET_KEY")
87+
if not (ak and sk):
88+
logger.debug("Get AK/SK from environment variables failed.")
89+
credential = get_credential_from_vefaas_iam()
90+
ak = credential.access_key_id
91+
sk = credential.secret_access_key
92+
session_token = credential.session_token
93+
else:
94+
logger.debug("Successfully get AK/SK from environment variables.")
95+
else:
96+
logger.debug("Successfully get AK/SK from tool context.")
97+
98+
results = {}
99+
100+
logger.info(f"Start to search {queries} in parallel.")
101+
results_list = await asyncio.gather(
102+
*(
103+
asyncio.to_thread(do_search, query, ak, sk, session_token)
104+
for query in queries
105+
)
106+
)
107+
logger.info(f"Finish to search {queries} in parallel.")
108+
109+
results = dict(zip(queries, results_list))
110+
logger.debug(f"Search results: {results}")
111+
return results

0 commit comments

Comments
 (0)