Skip to content

Commit d714724

Browse files
committed
feat: add QfRerankerCredential and QfBgeReranker classes for document reranking
1 parent 517e5f4 commit d714724

File tree

4 files changed

+138
-3
lines changed

4 files changed

+138
-3
lines changed

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def is_cache_model():
2626

2727
@staticmethod
2828
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
29-
optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
29+
optional_params = {'params': {}}
3030
for key, value in model_kwargs.items():
3131
if key not in ['model_id', 'use_local', 'streaming']:
3232
optional_params['params'][key] = value
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Dict
2+
3+
from langchain_core.documents import Document
4+
5+
from common import forms
6+
from common.exception.app_exception import AppApiException
7+
from common.forms import BaseForm
8+
from models_provider.base_model_provider import BaseModelCredential, ValidCode
9+
from django.utils.translation import gettext_lazy as _
10+
from common.utils.logger import maxkb_logger
11+
from models_provider.impl.wenxin_model_provider.model.reranker import QfBgeReranker
12+
13+
14+
class QfRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form for the reranker model: an API URL plus an API key.

    ``is_valid`` checks that the model type is offered by the provider, that
    both credential keys are present, and finally that a real rerank call
    succeeds with the supplied credential.
    """
    api_url = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Validate ``model_credential`` for ``model_type`` / ``model_name``.

        :param model_type: model type value looked up in the provider's type list
        :param model_name: name handed to ``provider.get_model``
        :param model_credential: mapping expected to contain ``api_url`` and ``api_key``
        :param model_params: accepted for interface compatibility; not read here
        :param provider: object exposing ``get_model_type_list`` and ``get_model``
        :param raise_exception: when true, failures raise instead of returning ``False``
        :return: ``True`` when the probe call succeeds, ``False`` on a failure
                 with ``raise_exception`` disabled
        :raises AppApiException: for an unsupported model type (always), or for a
                 missing key / failed probe when ``raise_exception`` is true
        """
        # An unsupported model type is always an error, regardless of raise_exception.
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Report the first required key that is absent from the credential.
        missing_key = next((name for name in ('api_url', 'api_key') if name not in model_credential), None)
        if missing_key is not None:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=missing_key))
            return False

        try:
            # Probe the service with a one-document rerank request.
            reranker: QfBgeReranker = provider.get_model(model_type, model_name, model_credential)
            probe_text = str(_('Hello'))
            reranker.compress_documents([Document(page_content=probe_text)], probe_text)
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            # Application-level errors are propagated untouched.
            if isinstance(e, AppApiException):
                raise e
            if not raise_exception:
                return False
            raise AppApiException(
                ValidCode.valid_error.value,
                _('Verification failed, please check whether the parameters are correct: {error}').format(
                    error=str(e))
            )

        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` whose ``api_key`` went through ``encryption``."""
        protected_key = super().encryption(model_info.get('api_key', ''))
        return {**model_info, 'api_key': protected_key}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import json
2+
from typing import Sequence, Optional, Dict, Any
3+
4+
import requests
5+
from langchain_core.callbacks import Callbacks
6+
from langchain_core.documents import BaseDocumentCompressor, Document
7+
8+
from models_provider.base_model_provider import MaxKBBaseModel
9+
10+
11+
class QfBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that reranks documents through an HTTP ``/rerank`` endpoint.

    The endpoint is called with a bearer token and a JSON payload containing the
    model name, the query, the document texts and ``top_n``.
    """
    api_key: str
    api_url: str
    model: str
    params: dict
    top_n: int = 3

    def __init__(self, **kwargs):
        """Build the reranker from keyword arguments.

        Reads ``api_key``, ``api_url``, ``model`` and ``params`` from ``kwargs``;
        ``top_n`` is taken from ``params`` and defaults to 3.
        """
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params', {})
        self.api_url = kwargs.get('api_url')
        self.top_n = self.params.get('top_n', 3)

    @staticmethod
    def is_cache_model():
        """Return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a reranker from a credential mapping.

        :param model_type: accepted for interface compatibility; not read here
        :param model_name: forwarded as the ``model`` field
        :param model_credential: mapping providing ``api_key`` and ``api_url``
        :param model_kwargs: stored as ``params`` (``top_n`` is read from it)
        """
        return QfBgeReranker(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
        )

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and return the service's results.

        :param documents: candidate documents; an empty sequence short-circuits to ``[]``
        :param query: text the documents are ranked against
        :param callbacks: accepted for interface compatibility; not used
        :return: one ``Document`` per entry of the response's ``results`` list, with
                 ``relevance_score`` as its only metadata (the input documents'
                 metadata is not carried over)
        :raises RuntimeError: when the endpoint answers with a non-200 status
        """
        if not documents:
            return []

        texts = [doc.page_content for doc in documents]

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        # Never ask for more results than there are documents.
        top_n = min(self.top_n, len(texts))
        payload = {
            "model": self.model,
            "query": query,
            "documents": texts,
            "top_n": top_n
        }

        # Tolerate a base URL configured with a trailing slash ("…/v2/" -> "…/v2/rerank").
        endpoint = f"{self.api_url.rstrip('/')}/rerank"
        # Without a timeout requests waits forever on a stalled server; default to
        # 60 seconds unless the model params supply their own ``timeout``.
        timeout = self.params.get('timeout', 60)
        response = requests.post(endpoint, json=payload, headers=headers, timeout=timeout)

        if response.status_code != 200:
            raise RuntimeError(f"千帆 API 请求失败：{response.text}")

        res = response.json()

        return [
            Document(
                page_content=item.get('document', ''),
                metadata={'relevance_score': item.get('relevance_score')}
            )
            for item in res.get('results', [])
        ]

apps/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
ModelInfoManage
1414
from models_provider.impl.wenxin_model_provider.credential.embedding import QianfanEmbeddingCredential
1515
from models_provider.impl.wenxin_model_provider.credential.llm import WenxinLLMModelCredential
16+
from models_provider.impl.wenxin_model_provider.credential.reranker import QfRerankerCredential
1617
from models_provider.impl.wenxin_model_provider.model.embedding import QianfanEmbeddings
1718
from models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
1819
from maxkb.conf import PROJECT_DIR
1920
from django.utils.translation import gettext as _
2021

22+
from models_provider.impl.wenxin_model_provider.model.reranker import QfBgeReranker
23+
2124
win_xin_llm_model_credential = WenxinLLMModelCredential()
2225
qianfan_embedding_credential = QianfanEmbeddingCredential()
26+
qf_reranker_credential = QfRerankerCredential()
2327
model_info_list = [ModelInfo('ERNIE-Bot-4',
2428
_('ERNIE-Bot-4 is a large language model independently developed by Baidu. It covers massive Chinese data and has stronger capabilities in dialogue Q&A, content creation and generation.'),
2529
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
@@ -46,13 +50,19 @@
4650
ModelInfo('bge-large-zh', '', ModelTypeConst.EMBEDDING, qianfan_embedding_credential,
4751
QianfanEmbeddings)
4852
]
49-
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
53+
# Reranker models offered by this provider. The description is a plain empty
# string (as for 'bge-large-zh') rather than a translation lookup of an empty msgid.
rerank_model_info_list = [
    ModelInfo('bce-reranker-base', '', ModelTypeConst.RERANKER, qf_reranker_credential, QfBgeReranker),
]
57+
model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
5058
ModelInfo('ERNIE-Bot-4',
5159
_('ERNIE-Bot-4 is a large language model independently developed by Baidu. It covers massive Chinese data and has stronger capabilities in dialogue Q&A, content creation and generation.'),
5260
ModelTypeConst.LLM,
5361
win_xin_llm_model_credential,
5462
QianfanChatModel)).append_model_info_list(embedding_model_info_list).append_default_model_info(
55-
embedding_model_info_list[0]).build()
63+
embedding_model_info_list[0]).
64+
append_model_info_list(rerank_model_info_list).append_default_model_info(
65+
rerank_model_info_list[0]).build())
5666

5767

5868
class WenxinModelProvider(IModelProvider):

0 commit comments

Comments
 (0)