Skip to content

Commit bab3b7c

Browse files
committed
fix(web-search): validate base URLs against endpoint paths and expand Exa search types
- Reject specific API endpoint paths (e.g., /search, /extract) in base URL normalization via the new disallowed_path_suffixes parameter to prevent misconfiguration errors.
- Add deep-lite and deep-reasoning to the valid Exa search types, and normalize the search_type input before validation.
- Add the missing config parameter to the BochaWebSearchTool builtin_tool decorator so provider status checks are properly registered.
1 parent 8e0a723 commit bab3b7c

5 files changed

Lines changed: 178 additions & 5 deletions

File tree

astrbot/core/knowledge_base/parsers/url_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
tavily_base_url,
2929
default="https://api.tavily.com",
3030
provider_name="Tavily",
31+
disallowed_path_suffixes=("search", "extract"),
3132
)
3233

3334
async def _get_tavily_key(self) -> str:

astrbot/core/tools/web_search_tools.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@
4646
"provider_settings.web_search": True,
4747
"provider_settings.websearch_provider": "exa",
4848
}
49+
_EXA_SEARCH_TYPES = (
50+
"auto",
51+
"fast",
52+
"deep",
53+
"deep-lite",
54+
"deep-reasoning",
55+
"instant",
56+
"neural",
57+
)
4958

5059

5160
@std_dataclass
@@ -149,6 +158,7 @@ def _get_tavily_base_url(provider_settings: dict) -> str:
149158
provider_settings.get("websearch_tavily_base_url"),
150159
default="https://api.tavily.com",
151160
provider_name="Tavily",
161+
disallowed_path_suffixes=("search", "extract"),
152162
)
153163

154164

@@ -157,6 +167,7 @@ def _get_exa_base_url(provider_settings: dict) -> str:
157167
provider_settings.get("websearch_exa_base_url"),
158168
default="https://api.exa.ai",
159169
provider_name="Exa",
170+
disallowed_path_suffixes=("search", "contents", "findSimilar"),
160171
)
161172

162173

@@ -645,7 +656,7 @@ class ExaWebSearchTool(FunctionTool[AstrAgentContext]):
645656
},
646657
"search_type": {
647658
"type": "string",
648-
"description": 'Optional. Search type. Must be one of "auto", "neural", "fast", "instant", "deep". Default is "auto".',
659+
"description": 'Optional. Search type. Must be one of "auto", "fast", "deep", "deep-lite", "deep-reasoning", "instant", "neural". Default is "auto".',
649660
},
650661
"category": {
651662
"type": "string",
@@ -665,8 +676,8 @@ async def call(self, context, **kwargs) -> ToolExecResult:
665676
if not provider_settings.get("websearch_exa_key", []):
666677
return "Error: Exa API key is not configured in AstrBot."
667678

668-
search_type = kwargs.get("search_type", "auto")
669-
if search_type not in ("auto", "neural", "fast", "instant", "deep"):
679+
search_type = str(kwargs.get("search_type", "auto")).strip().lower()
680+
if search_type not in _EXA_SEARCH_TYPES:
670681
search_type = "auto"
671682

672683
max_results = max(1, min(int(kwargs.get("max_results", 10)), 100))
@@ -794,7 +805,7 @@ async def call(self, context, **kwargs) -> ToolExecResult:
794805
return _search_result_payload(results)
795806

796807

797-
@builtin_tool
808+
@builtin_tool(config=_BOCHA_WEB_SEARCH_TOOL_CONFIG)
798809
@pydantic_dataclass
799810
class BochaWebSearchTool(FunctionTool[AstrAgentContext]):
800811
name: str = "web_search_bocha"

astrbot/core/utils/web_search_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def normalize_web_search_base_url(
1818
*,
1919
default: str,
2020
provider_name: str,
21+
disallowed_path_suffixes: tuple[str, ...] = (),
2122
) -> str:
2223
normalized = (base_url or "").strip()
2324
if not normalized:
@@ -30,6 +31,18 @@ def normalize_web_search_base_url(
3031
f"Error: {provider_name} API Base URL must start with http:// or "
3132
f"https://. Proxy base paths are allowed. Received: {normalized!r}.",
3233
)
34+
35+
last_path_segment = parsed.path.rstrip("/").rsplit("/", 1)[-1].lower()
36+
invalid_suffixes = {
37+
suffix.strip("/").lower()
38+
for suffix in disallowed_path_suffixes
39+
if suffix and suffix.strip("/")
40+
}
41+
if last_path_segment and last_path_segment in invalid_suffixes:
42+
raise ValueError(
43+
f"Error: {provider_name} API Base URL must be a base URL or proxy "
44+
f"prefix, not a specific endpoint path. Received: {normalized!r}.",
45+
)
3346
return normalized
3447

3548

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from types import SimpleNamespace
2+
3+
import pytest
4+
5+
import astrbot.core.tools.registry as tool_registry
6+
import astrbot.core.tools.web_search_tools as web_search_tools
7+
from astrbot.core.knowledge_base.parsers.url_parser import URLExtractor
8+
from astrbot.core.tools.web_search_tools import ExaWebSearchTool
9+
10+
11+
def _make_tool_context(provider_settings: dict) -> SimpleNamespace:
12+
cfg = {"provider_settings": provider_settings}
13+
return SimpleNamespace(
14+
context=SimpleNamespace(
15+
context=SimpleNamespace(get_config=lambda umo=None: cfg),
16+
event=SimpleNamespace(unified_msg_origin="test:private:session"),
17+
)
18+
)
19+
20+
21+
@pytest.mark.asyncio
22+
@pytest.mark.parametrize(
23+
("search_type", "expected"),
24+
[
25+
("deep-lite", "deep-lite"),
26+
("deep-reasoning", "deep-reasoning"),
27+
("instant", "instant"),
28+
("unsupported", "auto"),
29+
],
30+
)
31+
async def test_exa_web_search_tool_normalizes_search_type(
32+
monkeypatch: pytest.MonkeyPatch,
33+
search_type: str,
34+
expected: str,
35+
):
36+
captured: dict[str, object] = {}
37+
38+
async def fake_exa_search(provider_settings: dict, payload: dict, timeout: int):
39+
captured["provider_settings"] = provider_settings
40+
captured["payload"] = payload
41+
captured["timeout"] = timeout
42+
return []
43+
44+
monkeypatch.setattr(web_search_tools, "_exa_search", fake_exa_search)
45+
46+
tool = ExaWebSearchTool()
47+
result = await tool.call(
48+
_make_tool_context({"websearch_exa_key": ["test-key"]}),
49+
query="AstrBot",
50+
search_type=search_type,
51+
)
52+
53+
assert result == "Error: Exa web searcher does not return any results."
54+
assert captured["payload"]["type"] == expected
55+
56+
57+
def test_get_exa_base_url_rejects_endpoint_path():
58+
with pytest.raises(ValueError) as exc_info:
59+
web_search_tools._get_exa_base_url(
60+
{"websearch_exa_base_url": "https://api.exa.ai/search"}
61+
)
62+
63+
assert str(exc_info.value) == (
64+
"Error: Exa API Base URL must be a base URL or proxy prefix, "
65+
"not a specific endpoint path. Received: 'https://api.exa.ai/search'."
66+
)
67+
68+
69+
def test_url_extractor_rejects_endpoint_base_url():
70+
with pytest.raises(ValueError) as exc_info:
71+
URLExtractor(
72+
["test-key"],
73+
tavily_base_url="https://api.tavily.com/extract",
74+
)
75+
76+
assert str(exc_info.value) == (
77+
"Error: Tavily API Base URL must be a base URL or proxy prefix, "
78+
"not a specific endpoint path. Received: 'https://api.tavily.com/extract'."
79+
)
80+
81+
82+
def test_bocha_builtin_config_statuses_are_registered():
83+
rule = tool_registry._BUILTIN_TOOL_CONFIG_RULES.get("web_search_bocha")
84+
85+
assert rule is not None
86+
statuses = rule.evaluate(
87+
{
88+
"provider_settings": {
89+
"web_search": True,
90+
"websearch_provider": "bocha",
91+
}
92+
}
93+
)
94+
95+
assert statuses == [
96+
{
97+
"key": "provider_settings.web_search",
98+
"operator": "equals",
99+
"expected": True,
100+
"actual": True,
101+
"matched": True,
102+
"message": None,
103+
},
104+
{
105+
"key": "provider_settings.websearch_provider",
106+
"operator": "equals",
107+
"expected": "bocha",
108+
"actual": "bocha",
109+
"matched": True,
110+
"message": None,
111+
}
112+
]

tests/unit/test_web_search_utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def test_normalize_web_search_base_url_reports_invalid_value(
121121
[
122122
(" https://api.exa.ai/ ", "https://api.exa.ai"),
123123
("https://proxy.example.com/exa/", "https://proxy.example.com/exa"),
124-
("https://api.exa.ai/search", "https://api.exa.ai/search"),
125124
],
126125
)
127126
def test_normalize_web_search_base_url_accepts_proxy_paths(
@@ -134,3 +133,40 @@ def test_normalize_web_search_base_url_accepts_proxy_paths(
134133
)
135134

136135
assert normalized == expected
136+
137+
138+
@pytest.mark.parametrize(
139+
("base_url", "provider_name", "disallowed_path_suffixes", "expected_message"),
140+
[
141+
(
142+
"https://api.exa.ai/search",
143+
"Exa",
144+
("search", "contents", "findSimilar"),
145+
"Error: Exa API Base URL must be a base URL or proxy prefix, "
146+
"not a specific endpoint path. Received: 'https://api.exa.ai/search'.",
147+
),
148+
(
149+
"https://api.tavily.com/extract",
150+
"Tavily",
151+
("search", "extract"),
152+
"Error: Tavily API Base URL must be a base URL or proxy prefix, "
153+
"not a specific endpoint path. Received: "
154+
"'https://api.tavily.com/extract'.",
155+
),
156+
],
157+
)
158+
def test_normalize_web_search_base_url_rejects_endpoint_paths(
159+
base_url: str,
160+
provider_name: str,
161+
disallowed_path_suffixes: tuple[str, ...],
162+
expected_message: str,
163+
):
164+
with pytest.raises(ValueError) as exc_info:
165+
normalize_web_search_base_url(
166+
base_url,
167+
default="https://api.exa.ai",
168+
provider_name=provider_name,
169+
disallowed_path_suffixes=disallowed_path_suffixes,
170+
)
171+
172+
assert str(exc_info.value) == expected_message

0 commit comments

Comments
 (0)