Skip to content

Commit 22e2c8b

Browse files
committed
fix(websearch): 统一前后端网页搜索引用提取逻辑,增加前端 refs 降级获取
- 重构 web_search_utils.py 为分层结构,新增 build_web_search_refs() 和 _extract_ref_indices() 支持从 <ref> 标签提取引用索引 - 简化 chat.py/live_chat.py 中 ref 提取为调用 build_web_search_refs() - MessageList.vue 新增 getMessageRefs() 在后端未返回 refs 时前端自行降级提取 - 修复 chat.py 中消息保存条件判断逻辑
1 parent 479c58e commit 22e2c8b

5 files changed

Lines changed: 282 additions & 103 deletions

File tree

astrbot/core/utils/web_search_utils.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from typing import Any
34
from urllib.parse import urlparse
45

@@ -29,9 +30,9 @@ def normalize_web_search_base_url(
2930
return normalized
3031

3132

32-
def collect_web_search_results(accumulated_parts: list[dict[str, Any]]) -> dict:
33-
web_search_results = {}
34-
33+
def _iter_web_search_result_items(
34+
accumulated_parts: list[dict[str, Any]],
35+
):
3536
for part in accumulated_parts:
3637
if part.get("type") != "tool_call" or not part.get("tool_calls"):
3738
continue
@@ -52,13 +53,78 @@ def collect_web_search_results(accumulated_parts: list[dict[str, Any]]) -> dict:
5253
continue
5354

5455
for item in result_data.get("results", []):
55-
if not isinstance(item, dict):
56-
continue
57-
if idx := item.get("index"):
58-
web_search_results[idx] = {
59-
"url": item.get("url"),
60-
"title": item.get("title"),
61-
"snippet": item.get("snippet"),
62-
}
56+
if isinstance(item, dict):
57+
yield item
58+
59+
60+
def _extract_ref_indices(accumulated_text: str) -> list[str]:
61+
ref_indices: list[str] = []
62+
seen_indices: set[str] = set()
63+
64+
for match in re.finditer(r"<ref>(.*?)</ref>", accumulated_text):
65+
ref_index = match.group(1).strip()
66+
if not ref_index or ref_index in seen_indices:
67+
continue
68+
ref_indices.append(ref_index)
69+
seen_indices.add(ref_index)
70+
71+
return ref_indices
72+
73+
74+
def collect_web_search_ref_items(
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
    """Collect de-duplicated web-search reference entries in result order.

    Each entry carries ``index``, ``url``, ``title`` and ``snippet``; when the
    entry's URL is present in *favicon_cache*, a ``favicon`` field is attached
    as well. Results without an index, and repeated indices, are skipped.

    Args:
        accumulated_parts: Accumulated message parts containing tool calls.
        favicon_cache: Optional mapping of URL -> favicon data.

    Returns:
        Ordered list of reference payload dicts.
    """
    collected: list[dict[str, Any]] = []
    seen: set[str] = set()

    for result in _iter_web_search_result_items(accumulated_parts):
        idx = result.get("index")
        if not idx or idx in seen:
            continue
        seen.add(idx)

        entry: dict[str, Any] = {
            "index": idx,
            "url": result.get("url"),
            "title": result.get("title"),
            "snippet": result.get("snippet"),
        }
        if favicon_cache and entry["url"] in favicon_cache:
            entry["favicon"] = favicon_cache[entry["url"]]
        collected.append(entry)

    return collected
99+
100+
101+
def build_web_search_refs(
    accumulated_text: str,
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> dict:
    """Build the ``{"used": [...]}`` refs payload for a message.

    Prefers the references actually cited via ``<ref>`` tags in the text;
    when the text cites none of the collected results, every collected
    result is used as a fallback.

    Args:
        accumulated_text: Full accumulated message text (may contain refs).
        accumulated_parts: Accumulated message parts containing tool calls.
        favicon_cache: Optional mapping of URL -> favicon data.

    Returns:
        ``{"used": [...]}`` when any search results exist, otherwise ``{}``.
    """
    all_refs = collect_web_search_ref_items(accumulated_parts, favicon_cache)
    if not all_refs:
        return {}

    by_index = {entry["index"]: entry for entry in all_refs}
    cited = [
        by_index[idx]
        for idx in _extract_ref_indices(accumulated_text)
        if idx in by_index
    ]
    # Fall back to the full result list when nothing was explicitly cited.
    return {"used": cited or all_refs}
118+
119+
120+
def collect_web_search_results(accumulated_parts: list[dict[str, Any]]) -> dict:
    """Map each reference index to its ``{url, title, snippet}`` payload.

    Args:
        accumulated_parts: Accumulated message parts containing tool calls.

    Returns:
        Dict keyed by reference index; indices are already de-duplicated by
        ``collect_web_search_ref_items``.
    """
    return {
        entry["index"]: {
            "url": entry.get("url"),
            "title": entry.get("title"),
            "snippet": entry.get("snippet"),
        }
        for entry in collect_web_search_ref_items(accumulated_parts)
    }

astrbot/dashboard/routes/chat.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import json
33
import os
4-
import re
54
import uuid
65
from contextlib import asynccontextmanager
76
from typing import cast
@@ -23,7 +22,7 @@
2322
from astrbot.core.utils.active_event_registry import active_event_registry
2423
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
2524
from astrbot.core.utils.datetime_utils import to_utc_isoformat
26-
from astrbot.core.utils.web_search_utils import collect_web_search_results
25+
from astrbot.core.utils.web_search_utils import build_web_search_refs
2726

2827
from .route import Response, Route, RouteContext
2928

@@ -216,35 +215,13 @@ async def _create_attachment_from_file(
216215
def _extract_web_search_refs(
    self, accumulated_text: str, accumulated_parts: list
) -> dict:
    """Extract web-search references cited in a message.

    Delegates to the shared ``build_web_search_refs`` helper, supplying the
    process-wide favicon cache so cached icons are attached to each ref.
    """
    return build_web_search_refs(
        accumulated_text,
        accumulated_parts,
        sp.temporary_cache.get("_ws_favicon", {}),
    )
248225

249226
async def _save_bot_message(
250227
self,
@@ -446,19 +423,27 @@ async def stream():
446423
accumulated_parts.append(part)
447424

448425
# 消息结束处理
426+
should_save = False
449427
if msg_type == "end":
450-
break
428+
should_save = bool(
429+
accumulated_parts
430+
or accumulated_text
431+
or accumulated_reasoning
432+
or refs
433+
or agent_stats
434+
)
451435
elif (
452436
(streaming and msg_type == "complete") or not streaming
453437
# or msg_type == "break"
454438
):
455-
if (
456-
chain_type == "tool_call"
457-
or chain_type == "tool_call_result"
439+
if chain_type not in (
440+
"tool_call",
441+
"tool_call_result",
442+
"agent_stats",
458443
):
459-
continue
444+
should_save = True
460445

461-
# 提取 web_search_tavily 引用
446+
if should_save:
462447
try:
463448
refs = self._extract_web_search_refs(
464449
accumulated_text,
@@ -499,6 +484,9 @@ async def stream():
499484
# tool_calls = {}
500485
agent_stats = {}
501486
refs = {}
487+
488+
if msg_type == "end":
489+
break
502490
except BaseException as e:
503491
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
504492
finally:

astrbot/dashboard/routes/live_chat.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import json
33
import os
4-
import re
54
import time
65
import uuid
76
import wave
@@ -22,7 +21,7 @@
2221
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
2322
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
2423
from astrbot.core.utils.datetime_utils import to_utc_isoformat
25-
from astrbot.core.utils.web_search_utils import collect_web_search_results
24+
from astrbot.core.utils.web_search_utils import build_web_search_refs
2625

2726
from .route import Route, RouteContext
2827

def _extract_web_search_refs(
    self, accumulated_text: str, accumulated_parts: list
) -> dict:
    """Extract web_search references cited in a message.

    Delegates to the shared ``build_web_search_refs`` helper, supplying the
    process-wide favicon cache so cached icons are attached to each ref.
    """
    return build_web_search_refs(
        accumulated_text,
        accumulated_parts,
        sp.temporary_cache.get("_ws_favicon", {}),
    )
220207

221208
async def _save_bot_message(
222209
self,

dashboard/src/components/chat/MessageList.vue

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@
149149
@click="$emit('replyMessage', msg, index)" :title="tm('actions.reply')" />
150150

151151
<!-- Refs Visualization -->
152-
<ActionRef :refs="msg.content.refs" @open-refs="openRefsSidebar" />
152+
<ActionRef :refs="getMessageRefs(msg.content)" @open-refs="openRefsSidebar" />
153153
</div>
154154
</div>
155155
</div>
@@ -294,47 +294,95 @@ export default {
294294
this.extractWebSearchResults();
295295
},
296296
methods: {
297-
// 从消息中提取 web_search_tavily 的搜索结果
297+
extractRefsFromToolCall(toolCall) {
298+
if (!WEB_SEARCH_REFERENCE_TOOLS.includes(toolCall?.name) || !toolCall.result) {
299+
return [];
300+
}
301+
302+
try {
303+
const resultData = typeof toolCall.result === 'string'
304+
? JSON.parse(toolCall.result)
305+
: toolCall.result;
306+
307+
if (!resultData?.results || !Array.isArray(resultData.results)) {
308+
return [];
309+
}
310+
311+
const refs = [];
312+
const seenIndices = new Set();
313+
314+
resultData.results.forEach(item => {
315+
if (!item?.index || seenIndices.has(item.index)) {
316+
return;
317+
}
318+
319+
refs.push({
320+
index: item.index,
321+
url: item.url,
322+
title: item.title,
323+
snippet: item.snippet
324+
});
325+
seenIndices.add(item.index);
326+
});
327+
328+
return refs;
329+
} catch (e) {
330+
console.error('Failed to parse web search result:', e);
331+
return [];
332+
}
333+
},
334+
335+
collectMessageWebSearchRefs(messageParts) {
336+
if (!Array.isArray(messageParts)) {
337+
return [];
338+
}
339+
340+
const refs = [];
341+
const seenIndices = new Set();
342+
343+
messageParts.forEach(part => {
344+
if (part.type !== 'tool_call' || !Array.isArray(part.tool_calls)) {
345+
return;
346+
}
347+
348+
part.tool_calls.forEach(toolCall => {
349+
this.extractRefsFromToolCall(toolCall).forEach(ref => {
350+
if (seenIndices.has(ref.index)) {
351+
return;
352+
}
353+
refs.push(ref);
354+
seenIndices.add(ref.index);
355+
});
356+
});
357+
});
358+
359+
return refs;
360+
},
361+
362+
getMessageRefs(content) {
363+
if (content?.refs?.used?.length) {
364+
return content.refs;
365+
}
366+
367+
const fallbackRefs = this.collectMessageWebSearchRefs(content?.message);
368+
return fallbackRefs.length ? { used: fallbackRefs } : null;
369+
},
370+
371+
// 从消息中提取网页搜索结果映射
298372
extractWebSearchResults() {
299373
const results = {};
300374
301375
this.messages.forEach(msg => {
302376
if (msg.content.type !== 'bot' || !Array.isArray(msg.content.message)) {
303377
return;
304378
}
305-
306-
msg.content.message.forEach(part => {
307-
if (part.type !== 'tool_call' || !Array.isArray(part.tool_calls)) {
308-
return;
309-
}
310-
311-
part.tool_calls.forEach(toolCall => {
312-
// 检查是否是网页搜索工具调用
313-
if (!WEB_SEARCH_REFERENCE_TOOLS.includes(toolCall.name) || !toolCall.result) {
314-
return;
315-
}
316-
317-
try {
318-
// 解析工具调用结果
319-
const resultData = typeof toolCall.result === 'string'
320-
? JSON.parse(toolCall.result)
321-
: toolCall.result;
322-
323-
if (resultData.results && Array.isArray(resultData.results)) {
324-
resultData.results.forEach(item => {
325-
if (item.index) {
326-
results[item.index] = {
327-
url: item.url,
328-
title: item.title,
329-
snippet: item.snippet
330-
};
331-
}
332-
});
333-
}
334-
} catch (e) {
335-
console.error('Failed to parse web search result:', e);
336-
}
337-
});
379+
380+
this.collectMessageWebSearchRefs(msg.content.message).forEach(ref => {
381+
results[ref.index] = {
382+
url: ref.url,
383+
title: ref.title,
384+
snippet: ref.snippet
385+
};
338386
});
339387
});
340388

0 commit comments

Comments (0)