Skip to content

Commit f374963

Browse files
committed
feat:使用 Cross Encoder 进行RAG检索后的重排序
1 parent 38361ae commit f374963

1 file changed

Lines changed: 99 additions & 32 deletions

File tree

backend/app/rag/reorder_service.py

Lines changed: 99 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,77 @@
1-
import httpx
21
from typing import List, Dict, Any
2+
import torch
3+
import os
4+
from dotenv import load_dotenv
5+
from sentence_transformers import CrossEncoder
36
from app.core.logger_handler import logger
47

8+
# 加载环境变量
9+
load_dotenv()
10+
11+
12+
def check_and_download_reranker_model() -> None:
13+
"""检查并重排序模型,在FastAPI启动时执行"""
14+
LOCAL_MODEL_PATH = os.getenv("RERANKER_MODEL_PATH", r"D:\Hugging_Face\models\Qwen3-Reranker-0.6B")
15+
HF_MODEL_NAME = "Qwen/Qwen3-Reranker-0.6B"
16+
17+
try:
18+
# 检查本地模型是否存在
19+
if os.path.exists(LOCAL_MODEL_PATH) and os.path.isdir(LOCAL_MODEL_PATH):
20+
logger.info(f"✅ 检测到本地重排序模型:{LOCAL_MODEL_PATH}")
21+
else:
22+
logger.warning(f"⚠️ 本地模型未找到:{LOCAL_MODEL_PATH}")
23+
logger.info(f"🔄 开始自动下载模型:{HF_MODEL_NAME}")
24+
25+
# 创建模型目录
26+
os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
27+
28+
# 自动下载模型
29+
device = "cuda" if torch.cuda.is_available() else "cpu"
30+
model = CrossEncoder(
31+
HF_MODEL_NAME,
32+
max_length=512,
33+
device=device,
34+
cache_folder=LOCAL_MODEL_PATH
35+
)
36+
logger.info(f"✅ 模型下载完成,使用设备:{device}")
37+
38+
except Exception as e:
39+
logger.error(f"❌ 模型检查失败: {str(e)}")
40+
raise RuntimeError(f"重排序模型检查失败: {str(e)}")
41+
542

643
class ReorderService:
744
"""文档重排序服务"""
845

9-
@staticmethod
10-
async def reorder_documents(query: str, documents: List[str]) -> Dict[str, Any]:
46+
def __init__(self):
47+
# 从环境变量读取重排序模型路径
48+
self.LOCAL_MODEL_PATH = os.getenv("RERANKER_MODEL_PATH", r"D:\Hugging_Face\models\Qwen3-Reranker-0.6B")
49+
# Hugging Face模型名称
50+
self.HF_MODEL_NAME = "Qwen/Qwen3-Reranker-0.6B"
51+
# 自动选择设备(优先使用GPU)
52+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
53+
# 模型实例(懒加载)
54+
self._model = None
55+
56+
def _get_model(self):
57+
"""懒加载模型实例"""
58+
if self._model is None:
59+
logger.info(f"✅ 加载重排序模型:{self.LOCAL_MODEL_PATH}")
60+
self._model = CrossEncoder(
61+
self.LOCAL_MODEL_PATH,
62+
max_length=512,
63+
device=self.device,
64+
local_files_only=True
65+
)
66+
logger.info(f"✅ 模型加载成功,使用设备:{self.device}")
67+
return self._model
68+
69+
@property
70+
def model(self):
71+
"""获取模型实例(懒加载)"""
72+
return self._get_model()
73+
74+
async def reorder_documents(self, query: str, documents: List[str]) -> Dict[str, Any]:
1175
"""
1276
对文档进行重排序
1377
:param query: 查询语句
@@ -16,37 +80,40 @@ async def reorder_documents(query: str, documents: List[str]) -> Dict[str, Any]:
1680
{"success": bool, "documents": List[Dict], "error": str}
1781
"""
1882
try:
19-
async with httpx.AsyncClient() as client:
20-
response = await client.post(
21-
"http://localhost:8000/api/reorder",
22-
json={
23-
"query": query,
24-
"documents": documents
25-
},
26-
timeout=30.0
27-
)
28-
response.raise_for_status() # 检查响应状态
29-
result = response.json()
30-
31-
if result.get("code") == 200:
32-
sorted_docs = result.get("data", {}).get("documents", [])
33-
logger.info(f"【重排序服务】文档重排序成功,返回 {len(sorted_docs)} 个文档")
34-
return {
35-
"success": True,
36-
"documents": sorted_docs,
37-
"error": ""
38-
}
39-
else:
40-
error_msg = result.get("message", "未知错误")
41-
logger.warning(f"【重排序服务】重排序失败: {error_msg}")
42-
return {
43-
"success": False,
44-
"documents": [],
45-
"error": error_msg
46-
}
83+
if not documents:
84+
return {
85+
"success": True,
86+
"documents": [],
87+
"error": ""
88+
}
89+
90+
# 构造查询+文档对
91+
pairs = [(query, doc) for doc in documents]
92+
93+
# 使用模型进行批量预测(batch_size=1避免padding令牌报错)
94+
scores = self.model.predict(pairs, batch_size=1)
95+
96+
# 构建结果列表
97+
scored_documents = []
98+
for doc, score in zip(documents, scores):
99+
scored_documents.append({
100+
"document": doc,
101+
"similarity": float(score)
102+
})
103+
logger.info(f"【重排序服务】文档相似度分数: {score:.4f}")
104+
105+
# 按相似度分数降序排序
106+
sorted_docs = sorted(scored_documents, key=lambda x: x["similarity"], reverse=True)
107+
logger.info(f"【重排序服务】文档重排序成功,返回 {len(sorted_docs)} 个文档")
108+
109+
return {
110+
"success": True,
111+
"documents": sorted_docs,
112+
"error": ""
113+
}
47114
except Exception as e:
48115
error_msg = str(e)
49-
logger.error(f"【重排序服务】重排序请求失败: {error_msg}")
116+
logger.error(f"【重排序服务】重排序失败: {error_msg}")
50117
return {
51118
"success": False,
52119
"documents": [],

0 commit comments

Comments
 (0)