Skip to content

Commit 0c9e634

Browse files
author
灵轮
committed
feat: ADB-PG KB Supports RerankModel config
Change-Id: Iea8bb11027d0ebc59035a3ce4c19f64bdb65eab6 Co-developed-by: Qoder <noreply@qoder.com>
1 parent 72bf7cb commit 0c9e634

10 files changed

Lines changed: 84 additions & 0 deletions

File tree

agentrun/knowledgebase/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .knowledgebase import KnowledgeBase
1414
from .model import (
1515
ADBProviderSettings,
16+
ADBRerankModel,
1617
ADBRetrieveSettings,
1718
BailianProviderSettings,
1819
BailianRetrieveSettings,
@@ -64,6 +65,7 @@
6465
"RetrieveSettings",
6566
"RagFlowRetrieveSettings",
6667
"BailianRetrieveSettings",
68+
"ADBRerankModel",
6769
"ADBRetrieveSettings",
6870
"OTSRetrieveSettings",
6971
"OTSDenseVectorSearchConfig",

agentrun/knowledgebase/__knowledgebase_async_template.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .api.data import get_data_api
1616
from .model import (
1717
ADBProviderSettings,
18+
ADBRerankModel,
1819
ADBRetrieveSettings,
1920
BailianProviderSettings,
2021
BailianRetrieveSettings,
@@ -344,6 +345,14 @@ def _get_data_api(self, config: Optional[Config] = None):
344345
rerank_factor=self.retrieve_settings.get(
345346
"RerankFactor"
346347
),
348+
rerank_model=(
349+
ADBRerankModel(
350+
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
351+
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
352+
)
353+
if self.retrieve_settings.get("RerankModel")
354+
else None
355+
),
347356
recall_window=self.retrieve_settings.get(
348357
"RecallWindow"
349358
),

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ def _build_query_content_request(
405405
"namespace_password": self.provider_settings.namespace_password,
406406
"collection": self.knowledge_base_name,
407407
"region_id": cfg.get_region_id(),
408+
# 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days
409+
"url_expiration": "356d",
408410
}
409411

410412
# 添加可选的提供商设置 / Add optional provider settings
@@ -423,6 +425,15 @@ def _build_query_content_request(
423425
request_params["rerank_factor"] = (
424426
self.retrieve_settings.rerank_factor
425427
)
428+
if self.retrieve_settings.rerank_model is not None:
429+
rerank_model_params: Dict[str, Any] = {
430+
"Name": self.retrieve_settings.rerank_model.name,
431+
}
432+
if self.retrieve_settings.rerank_model.instruct is not None:
433+
rerank_model_params["Instruct"] = (
434+
self.retrieve_settings.rerank_model.instruct
435+
)
436+
request_params["rerank_model"] = rerank_model_params
426437
if self.retrieve_settings.recall_window is not None:
427438
request_params["recall_window"] = (
428439
self.retrieve_settings.recall_window

agentrun/knowledgebase/api/data.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ def _build_query_content_request(
615615
"namespace_password": self.provider_settings.namespace_password,
616616
"collection": self.knowledge_base_name,
617617
"region_id": cfg.get_region_id(),
618+
# 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days
619+
"url_expiration": "356d",
618620
}
619621

620622
# 添加可选的提供商设置 / Add optional provider settings
@@ -633,6 +635,15 @@ def _build_query_content_request(
633635
request_params["rerank_factor"] = (
634636
self.retrieve_settings.rerank_factor
635637
)
638+
if self.retrieve_settings.rerank_model is not None:
639+
rerank_model_params: Dict[str, Any] = {
640+
"Name": self.retrieve_settings.rerank_model.name,
641+
}
642+
if self.retrieve_settings.rerank_model.instruct is not None:
643+
rerank_model_params["Instruct"] = (
644+
self.retrieve_settings.rerank_model.instruct
645+
)
646+
request_params["rerank_model"] = rerank_model_params
636647
if self.retrieve_settings.recall_window is not None:
637648
request_params["recall_window"] = (
638649
self.retrieve_settings.recall_window

agentrun/knowledgebase/knowledgebase.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .api.data import get_data_api
2626
from .model import (
2727
ADBProviderSettings,
28+
ADBRerankModel,
2829
ADBRetrieveSettings,
2930
BailianProviderSettings,
3031
BailianRetrieveSettings,
@@ -526,6 +527,14 @@ def _get_data_api(self, config: Optional[Config] = None):
526527
rerank_factor=self.retrieve_settings.get(
527528
"RerankFactor"
528529
),
530+
rerank_model=(
531+
ADBRerankModel(
532+
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
533+
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
534+
)
535+
if self.retrieve_settings.get("RerankModel")
536+
else None
537+
),
529538
recall_window=self.retrieve_settings.get(
530539
"RecallWindow"
531540
),

agentrun/knowledgebase/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ class ADBProviderSettings(BaseModel):
106106
"""元数据配置,JSON 字符串格式 / Metadata configuration in JSON string format"""
107107

108108

109+
class ADBRerankModel(BaseModel):
110+
"""ADB 重排模型配置 / ADB Rerank Model Configuration
111+
112+
配置重排模型的名称和排序任务类型说明。
113+
Configure the rerank model name and instruct for sorting task type.
114+
"""
115+
116+
name: str
117+
"""重排模型名称,可选值:qwen3-rerank、gte-rerank-v2
118+
Rerank model name, options: qwen3-rerank, gte-rerank-v2"""
119+
instruct: Optional[str] = None
120+
"""排序任务类型说明,仅当 name 为 qwen3-rerank 时可设置,指导模型采用不同的排序策略
121+
Instruct for sorting task type, only available when name is qwen3-rerank,
122+
guides the model to adopt different sorting strategies"""
123+
124+
109125
class ADBRetrieveSettings(BaseModel):
110126
"""ADB 检索设置 / ADB Retrieve Settings
111127
@@ -122,6 +138,9 @@ class ADBRetrieveSettings(BaseModel):
122138
rerank_factor: Optional[float] = None
123139
"""重排序因子,取值范围 1 < RerankFactor <= 5
124140
Re-ranking factor, value range: 1 < RerankFactor <= 5"""
141+
rerank_model: Optional[ADBRerankModel] = None
142+
"""重排模型配置,当启用重排因子时可设置
143+
Rerank model configuration, available when rerank factor is enabled"""
125144
recall_window: Optional[List[int]] = None
126145
"""召回窗口,格式为 [A, B],其中 -10 <= A <= 0,0 <= B <= 10
127146
Recall window, format [A, B] where -10 <= A <= 0, 0 <= B <= 10"""

examples/knowledgebase.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
from agentrun.knowledgebase import (
4747
ADBProviderSettings,
48+
ADBRerankModel,
4849
ADBRetrieveSettings,
4950
BailianProviderSettings,
5051
BailianRetrieveSettings,
@@ -474,6 +475,10 @@ def create_or_get_adb_kb() -> KnowledgeBase:
474475
top_k=10,
475476
use_full_text_retrieval=False, # 仅使用向量检索 / Vector only
476477
rerank_factor=2.0, # 重排序因子 / Rerank factor
478+
rerank_model=ADBRerankModel(
479+
name="qwen3-rerank", # 重排模型名称 / Rerank model name
480+
instruct="按相关性排序", # 排序任务类型说明(仅 qwen3-rerank 支持)/ Instruct (only for qwen3-rerank)
481+
),
477482
),
478483
)
479484
)

tests/unittests/knowledgebase/api/test_data.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from agentrun.knowledgebase.model import (
1616
ADBProviderSettings,
17+
ADBRerankModel,
1718
ADBRetrieveSettings,
1819
BailianProviderSettings,
1920
BailianRetrieveSettings,
@@ -995,6 +996,7 @@ def test_build_query_content_request(self):
995996
assert request.dbinstance_id == "gp-123456"
996997
assert request.namespace == "public"
997998
assert request.collection == "test-kb"
999+
assert request.url_expiration == "356d"
9981000

9991001
@patch.dict(
10001002
os.environ,
@@ -1017,17 +1019,23 @@ def test_build_query_content_request_with_settings(self):
10171019
top_k=10,
10181020
use_full_text_retrieval=True,
10191021
rerank_factor=1.5,
1022+
rerank_model=ADBRerankModel(
1023+
name="qwen3-rerank",
1024+
instruct="按相关性排序",
1025+
),
10201026
recall_window=[-5, 5],
10211027
hybrid_search="RRF",
10221028
hybrid_search_args={"RRF": {"k": 60}},
10231029
),
10241030
)
10251031

10261032
request = api._build_query_content_request("test query")
1033+
assert request.url_expiration == "356d"
10271034
assert request.metrics == "cosine"
10281035
assert request.top_k == 10
10291036
assert request.use_full_text_retrieval is True
10301037
assert request.rerank_factor == 1.5
1038+
assert request.rerank_model is not None
10311039

10321040
@patch.dict(
10331041
os.environ,

tests/unittests/knowledgebase/test_knowledgebase.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self):
702702
"TopK": 10,
703703
"UseFullTextRetrieval": True,
704704
"RerankFactor": 1.5,
705+
"RerankModel": {"Name": "qwen3-rerank", "Instruct": "按相关性排序"},
705706
"RecallWindow": [-5, 5],
706707
"HybridSearch": "RRF",
707708
"HybridSearchArgs": {"RRF": {"k": 60}},

tests/unittests/knowledgebase/test_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from agentrun.knowledgebase.model import (
8+
ADBRerankModel,
89
ADBProviderSettings,
910
ADBRetrieveSettings,
1011
BailianProviderSettings,
@@ -207,13 +208,20 @@ def test_create_adb_retrieve_settings(self):
207208
top_k=10,
208209
use_full_text_retrieval=True,
209210
rerank_factor=1.5,
211+
rerank_model=ADBRerankModel(
212+
name="qwen3-rerank",
213+
instruct="按相关性排序",
214+
),
210215
recall_window=[-5, 5],
211216
hybrid_search="RRF",
212217
hybrid_search_args={"RRF": {"k": 60}},
213218
)
214219
assert settings.top_k == 10
215220
assert settings.use_full_text_retrieval is True
216221
assert settings.rerank_factor == 1.5
222+
assert settings.rerank_model is not None
223+
assert settings.rerank_model.name == "qwen3-rerank"
224+
assert settings.rerank_model.instruct == "按相关性排序"
217225
assert settings.recall_window == [-5, 5]
218226
assert settings.hybrid_search == "RRF"
219227
assert settings.hybrid_search_args == {"RRF": {"k": 60}}
@@ -224,6 +232,7 @@ def test_adb_retrieve_settings_optional(self):
224232
assert settings.top_k is None
225233
assert settings.use_full_text_retrieval is None
226234
assert settings.rerank_factor is None
235+
assert settings.rerank_model is None
227236
assert settings.recall_window is None
228237
assert settings.hybrid_search is None
229238
assert settings.hybrid_search_args is None

0 commit comments

Comments
 (0)