1- import httpx
21from typing import List , Dict , Any
2+ import torch
3+ import os
4+ from dotenv import load_dotenv
5+ from sentence_transformers import CrossEncoder
36from app .core .logger_handler import logger
47
8+ # 加载环境变量
9+ load_dotenv ()
10+
11+
def check_and_download_reranker_model() -> None:
    """Ensure the reranker model is available locally; meant to run at FastAPI startup.

    The target directory is read from the ``RERANKER_MODEL_PATH`` environment
    variable, falling back to a built-in default path. The directory only
    counts as a usable model when it contains ``config.json``, so an empty
    directory (for example one left behind by an interrupted download) triggers
    a new download instead of being reported as present.

    After the download the model is written directly into the target
    directory, because ``ReorderService`` loads that directory itself with
    ``local_files_only=True``.

    Raises:
        RuntimeError: If checking, downloading or saving the model fails. The
            original exception is attached as ``__cause__``.
    """
    local_model_path = os.getenv("RERANKER_MODEL_PATH", r"D:\Hugging_Face\models\Qwen3-Reranker-0.6B")
    hf_model_name = "Qwen/Qwen3-Reranker-0.6B"

    try:
        # A bare directory is not proof of a model: require the config file
        # that a saved model directory contains.
        if os.path.isfile(os.path.join(local_model_path, "config.json")):
            logger.info(f"✅ 检测到本地重排序模型:{local_model_path} ")
            return

        logger.warning(f"⚠️ 本地模型未找到:{local_model_path} ")
        logger.info(f"🔄 开始自动下载模型:{hf_model_name} ")

        # Create the model directory up front.
        os.makedirs(local_model_path, exist_ok=True)

        # Download the model, preferring the GPU when CUDA is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = CrossEncoder(
            hf_model_name,
            max_length=512,
            device=device,
            cache_folder=local_model_path,
        )
        # NOTE(review): cache_folder is expected to use the Hugging Face hub
        # cache layout (nested snapshot folders) - confirm against the
        # installed sentence-transformers version. Saving a flat copy makes
        # the directory loadable by path.
        model.save_pretrained(local_model_path)
        logger.info(f"✅ 模型下载完成,使用设备:{device} ")

    except Exception as e:
        logger.error(f"❌ 模型检查失败: {str(e)} ")
        # Chain the cause so the original traceback is preserved for callers.
        raise RuntimeError(f"重排序模型检查失败: {str(e)} ") from e
41+
542
643class ReorderService :
744 """文档重排序服务"""
845
9- @staticmethod
10- async def reorder_documents (query : str , documents : List [str ]) -> Dict [str , Any ]:
46+ def __init__ (self ):
47+ # 从环境变量读取重排序模型路径
48+ self .LOCAL_MODEL_PATH = os .getenv ("RERANKER_MODEL_PATH" , r"D:\Hugging_Face\models\Qwen3-Reranker-0.6B" )
49+ # Hugging Face模型名称
50+ self .HF_MODEL_NAME = "Qwen/Qwen3-Reranker-0.6B"
51+ # 自动选择设备(优先使用GPU)
52+ self .device = "cuda" if torch .cuda .is_available () else "cpu"
53+ # 模型实例(懒加载)
54+ self ._model = None
55+
56+ def _get_model (self ):
57+ """懒加载模型实例"""
58+ if self ._model is None :
59+ logger .info (f"✅ 加载重排序模型:{ self .LOCAL_MODEL_PATH } " )
60+ self ._model = CrossEncoder (
61+ self .LOCAL_MODEL_PATH ,
62+ max_length = 512 ,
63+ device = self .device ,
64+ local_files_only = True
65+ )
66+ logger .info (f"✅ 模型加载成功,使用设备:{ self .device } " )
67+ return self ._model
68+
69+ @property
70+ def model (self ):
71+ """获取模型实例(懒加载)"""
72+ return self ._get_model ()
73+
74+ async def reorder_documents (self , query : str , documents : List [str ]) -> Dict [str , Any ]:
1175 """
1276 对文档进行重排序
1377 :param query: 查询语句
@@ -16,37 +80,40 @@ async def reorder_documents(query: str, documents: List[str]) -> Dict[str, Any]:
1680 {"success": bool, "documents": List[Dict], "error": str}
1781 """
1882 try :
19- async with httpx .AsyncClient () as client :
20- response = await client .post (
21- "http://localhost:8000/api/reorder" ,
22- json = {
23- "query" : query ,
24- "documents" : documents
25- },
26- timeout = 30.0
27- )
28- response .raise_for_status () # 检查响应状态
29- result = response .json ()
30-
31- if result .get ("code" ) == 200 :
32- sorted_docs = result .get ("data" , {}).get ("documents" , [])
33- logger .info (f"【重排序服务】文档重排序成功,返回 { len (sorted_docs )} 个文档" )
34- return {
35- "success" : True ,
36- "documents" : sorted_docs ,
37- "error" : ""
38- }
39- else :
40- error_msg = result .get ("message" , "未知错误" )
41- logger .warning (f"【重排序服务】重排序失败: { error_msg } " )
42- return {
43- "success" : False ,
44- "documents" : [],
45- "error" : error_msg
46- }
83+ if not documents :
84+ return {
85+ "success" : True ,
86+ "documents" : [],
87+ "error" : ""
88+ }
89+
90+ # 构造查询+文档对
91+ pairs = [(query , doc ) for doc in documents ]
92+
93+ # 使用模型进行批量预测(batch_size=1避免padding令牌报错)
94+ scores = self .model .predict (pairs , batch_size = 1 )
95+
96+ # 构建结果列表
97+ scored_documents = []
98+ for doc , score in zip (documents , scores ):
99+ scored_documents .append ({
100+ "document" : doc ,
101+ "similarity" : float (score )
102+ })
103+ logger .info (f"【重排序服务】文档相似度分数: { score :.4f} " )
104+
105+ # 按相似度分数降序排序
106+ sorted_docs = sorted (scored_documents , key = lambda x : x ["similarity" ], reverse = True )
107+ logger .info (f"【重排序服务】文档重排序成功,返回 { len (sorted_docs )} 个文档" )
108+
109+ return {
110+ "success" : True ,
111+ "documents" : sorted_docs ,
112+ "error" : ""
113+ }
47114 except Exception as e :
48115 error_msg = str (e )
49- logger .error (f"【重排序服务】重排序请求失败 : { error_msg } " )
116+ logger .error (f"【重排序服务】重排序失败 : { error_msg } " )
50117 return {
51118 "success" : False ,
52119 "documents" : [],
0 commit comments