Skip to content

Commit 7e48a19

Browse files
committed
feat(rag): 实现路由层 RAG 前置管线与实时思考事件推送
- vector_store.py: 新增 compute_route_score 快速相关度评分方法
- chat.py: 实现路由判断逻辑(score > 0.5 走 RAG 管线)
- 通过 asyncio.Queue 实时推送 RAG 思考事件到前端
- RAG 检索结果注入 Agent 的 rag_context 参数
1 parent 50ec5fe commit 7e48a19

2 files changed

Lines changed: 124 additions & 1 deletion

File tree

backend/app/rag/vector_store.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,26 @@ def _get_embed_model():
121121
"""获取嵌入模型(延迟加载包装器,模型在首次调用时解析)"""
122122
return _LazyEmbedding()
123123

124+
async def compute_route_score(self, query: str, user_id: str) -> float:
    """Score how relevant ``query`` is to the user's indexed documents.

    Runs a top-1 ``similarity_search_with_score`` on ``self.vectors_store``
    in a worker thread, restricted to entries whose ``user_id`` metadata
    matches, and maps the returned distance ``d`` to ``1 / (1 + d)``.
    A higher score means a closer match; the routing layer uses it to
    decide whether to run the RAG pre-pipeline.

    NOTE(review): the mapping assumes the store reports a non-negative
    (L2-style) distance -- confirm against the collection's configuration.

    Args:
        query: The raw user query text.
        user_id: Owner id used as the metadata filter.

    Returns:
        The similarity score, or ``0.0`` when the search yields no hit or
        the lookup fails for any reason.
    """
    try:
        results = await asyncio.to_thread(
            self.vectors_store.similarity_search_with_score,
            query,
            k=1,
            filter={"user_id": user_id},
        )
        if not results:
            return 0.0
        distance = results[0][1]
        # float() keeps the annotated return type even if the store hands
        # back a numpy scalar.
        return float(1 / (1 + distance))
    except Exception:
        # Deliberate best-effort: a routing failure must not break the chat
        # request. Record it, though, so an unavailable store is not
        # indistinguishable from "nothing relevant".
        import logging

        logging.getLogger(__name__).warning(
            "compute_route_score failed; falling back to 0.0", exc_info=True
        )
        return 0.0
143+
124144
async def get_bm25_retriever(self, user_id: str = None):
125145
return await self.hybrid_retriever.get_bm25_retriever(user_id)
126146

backend/app/router/chat.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import json
13
import uuid
24

35
from fastapi import Depends
@@ -23,8 +25,109 @@ async def query_stream(
2325
"""查询Agent流式响应"""
2426
session_id = request.session_id or str(uuid.uuid4())
2527

28+
from app.core.logger_handler import logger
29+
from app.rag.vector_store import VectorStoreService
30+
31+
vector_store = VectorStoreService()
32+
33+
# ---- 路由判断(快速,~50ms)----
34+
score = await vector_store.compute_route_score(
35+
request.query, user_id
36+
)
37+
38+
# 查询 Top-1 文档详情,用于日志输出
39+
top1_docs = await asyncio.to_thread(
40+
vector_store.vectors_store.similarity_search_with_score,
41+
request.query, k=1, filter={"user_id": user_id}
42+
)
43+
if top1_docs:
44+
top1_doc, top1_distance = top1_docs[0]
45+
source_type = "笔记库" if top1_doc.metadata.get("source_type") == "note" else "知识库"
46+
source_name = top1_doc.metadata.get("title") or top1_doc.metadata.get("original_filename", "未知")
47+
preview = top1_doc.page_content[:80].replace("\n", " ")
48+
logger.info(
49+
f"【路由决策】查询: 「{request.query}」 | "
50+
f"score: {score:.4f} (距离: {top1_distance:.4f}) | "
51+
f"Top-1来源: {source_type}{source_name}》 | "
52+
f"预览: {preview}... | "
53+
f"决策: {'→ RAG 前置管线' if score > 0.5 else '→ 跳过 RAG'}"
54+
)
55+
else:
56+
logger.info(
57+
f"【路由决策】查询: 「{request.query}」 | "
58+
f"score: {score:.4f} | "
59+
f"Top-1: 无文档 | "
60+
f"决策: → 跳过 RAG"
61+
)
62+
63+
async def stream_with_rag_thinking():
    """Stream RAG "thinking" events first, then the agent's response.

    When the routing ``score`` exceeds 0.5 the RAG pipeline is started as a
    background task. Every dict it passes to ``thinking_callback`` is
    forwarded to the client as an SSE ``data:`` frame as soon as it
    arrives. Once the pipeline finishes, up to three reordered documents
    are joined into ``rag_context`` and handed to the agent, whose stream
    is then forwarded unchanged. With a lower score the agent stream is
    forwarded directly with an empty ``rag_context``.

    Pipeline failures are logged and leave ``rag_context`` empty; they do
    not abort the response.
    """
    rag_context = ""

    if score > 0.5:
        from app.rag.rag_service import RagService

        # Queue shared by the RAG pipeline (producer) and this SSE
        # generator (consumer).
        thinking_queue = asyncio.Queue()
        # Unique end-of-stream marker. The pipeline enqueues it last, so
        # the consumer can block on get() instead of polling with a
        # timeout and re-checking a separate "done" flag.
        pipeline_finished = object()

        async def thinking_callback(data: dict):
            await thinking_queue.put(data)

        async def run_rag_pipeline():
            """Run retrieval and reordering; events flow through the queue."""
            nonlocal rag_context
            try:
                rag_service = RagService(user_id, thinking_callback=thinking_callback)
                documents = await rag_service.retrieve_document(request.query)

                def _format_doc(doc):
                    if doc.metadata.get("source_type") == "note":
                        title = doc.metadata.get("title", "无标题")
                        return f"[来源:笔记《{title}》]\n{doc.page_content}"
                    else:
                        filename = doc.metadata.get("original_filename", "知识库文档")
                        return f"[来源:知识库《{filename}》]\n{doc.page_content}"

                doc_contents = [_format_doc(doc) for doc in documents]
                reordered = await rag_service.reorder_documents(request.query, doc_contents)
                rag_context = "\n\n".join(reordered[:3])
                logger.info(f"【RAG前置】检索到 {len(documents)} 个文档,重排序后取前 {min(3, len(reordered))} 个注入 Agent")
            except Exception as e:
                logger.error(f"【RAG前置】管线执行失败: {e}", exc_info=True)
            finally:
                # The queue is unbounded, so put_nowait cannot raise here;
                # this also runs when the task is cancelled.
                thinking_queue.put_nowait(pipeline_finished)

        # Start the RAG pipeline in the background.
        rag_task = asyncio.create_task(run_rag_pipeline())
        try:
            # Forward thinking events as they are produced, without
            # waiting for the pipeline to finish.
            while True:
                event = await thinking_queue.get()
                if event is pipeline_finished:
                    break
                yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

            await rag_task

            # Flush anything enqueued after the marker (only possible if
            # the callback is invoked from outside the pipeline task).
            while not thinking_queue.empty():
                event = thinking_queue.get_nowait()
                if event is not pipeline_finished:
                    yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
        finally:
            # If the client disconnects and this generator is closed
            # mid-stream, stop the pipeline instead of letting it run on
            # with no consumer.
            if not rag_task.done():
                rag_task.cancel()

    # Forward the agent's streaming response.
    async for chunk in get_agent_stream_response(
        request.query, session_id, user_id, rag_context=rag_context
    ):
        yield chunk
128+
26129
return StreamingResponse(
27-
get_agent_stream_response(request.query, session_id, user_id),
130+
stream_with_rag_thinking(),
28131
media_type="text/event-stream",
29132
headers={
30133
"Cache-Control": "no-cache",

0 commit comments

Comments
 (0)