Skip to content

Commit f4f17e6

Browse files
author
灵轮
committed
feat: ADB retrieve settings support filter
Change-Id: I665fa00b9362e58b43295c7167b70ff015ea3ea1 Co-developed-by: Qoder <noreply@qoder.com>
1 parent 99d750e commit f4f17e6

8 files changed

Lines changed: 84 additions & 27 deletions

File tree

agentrun/knowledgebase/__knowledgebase_async_template.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,18 @@ def _get_data_api(self, config: Optional[Config] = None):
276276
converted_provider_settings = None
277277
converted_retrieve_settings = None
278278

279+
# 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"),
280+
# 从 __pydantic_extra__ 提取原始数据作为 dict 使用
281+
# When pydantic matches retrieve_settings to the wrong Union type (due to extra="allow"),
282+
# extract the raw data from __pydantic_extra__ and use it as a dict
283+
if (
284+
self.retrieve_settings is not None
285+
and not isinstance(self.retrieve_settings, dict)
286+
and hasattr(self.retrieve_settings, "__pydantic_extra__")
287+
and self.retrieve_settings.__pydantic_extra__
288+
):
289+
self.retrieve_settings = self.retrieve_settings.__pydantic_extra__
290+
279291
if provider == KnowledgeBaseProvider.BAILIAN:
280292
# 百炼设置 / Bailian settings
281293
if self.provider_settings:
@@ -347,8 +359,12 @@ def _get_data_api(self, config: Optional[Config] = None):
347359
),
348360
rerank_model=(
349361
ADBRerankModel(
350-
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
351-
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
362+
name=self.retrieve_settings.get(
363+
"RerankModel", {}
364+
).get("Name", ""),
365+
instruct=self.retrieve_settings.get(
366+
"RerankModel", {}
367+
).get("Instruct"),
352368
)
353369
if self.retrieve_settings.get("RerankModel")
354370
else None
@@ -362,6 +378,7 @@ def _get_data_api(self, config: Optional[Config] = None):
362378
hybrid_search_args=self.retrieve_settings.get(
363379
"HybridSearchArgs"
364380
),
381+
filter=self.retrieve_settings.get("Filter"),
365382
)
366383

367384
elif provider == KnowledgeBaseProvider.OTS:

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -426,14 +426,12 @@ def _build_query_content_request(
426426
self.retrieve_settings.rerank_factor
427427
)
428428
if self.retrieve_settings.rerank_model is not None:
429-
rerank_model_params: Dict[str, Any] = {
430-
"Name": self.retrieve_settings.rerank_model.name,
431-
}
432-
if self.retrieve_settings.rerank_model.instruct is not None:
433-
rerank_model_params["Instruct"] = (
434-
self.retrieve_settings.rerank_model.instruct
429+
request_params["rerank_model"] = (
430+
gpdb_models.QueryContentRequestRerankModel(
431+
name=self.retrieve_settings.rerank_model.name,
432+
instruct=self.retrieve_settings.rerank_model.instruct,
435433
)
436-
request_params["rerank_model"] = rerank_model_params
434+
)
437435
if self.retrieve_settings.recall_window is not None:
438436
request_params["recall_window"] = (
439437
self.retrieve_settings.recall_window
@@ -446,6 +444,8 @@ def _build_query_content_request(
446444
request_params["hybrid_search_args"] = (
447445
self.retrieve_settings.hybrid_search_args
448446
)
447+
if self.retrieve_settings.filter is not None:
448+
request_params["filter"] = self.retrieve_settings.filter
449449

450450
return gpdb_models.QueryContentRequest(**request_params)
451451

agentrun/knowledgebase/api/data.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ async def retrieve_async(
8181
"""
8282
raise NotImplementedError("Subclasses must implement retrieve_async")
8383

84+
8485
@abstractmethod
8586
def retrieve(
8687
self,
@@ -175,7 +176,9 @@ def _get_api_key(self, config: Optional[Config] = None) -> str:
175176

176177
from agentrun.credential import Credential
177178

178-
credential = Credential.get_by_name(self.credential_name, config=config)
179+
credential = Credential.get_by_name(
180+
self.credential_name, config=config
181+
)
179182
if not credential.credential_secret:
180183
raise ValueError(
181184
f"Credential '{self.credential_name}' has no secret configured"
@@ -282,6 +285,7 @@ async def retrieve_async(
282285
"error": True,
283286
}
284287

288+
285289
def retrieve(
286290
self,
287291
query: str,
@@ -315,7 +319,9 @@ def retrieve(
315319
body = self._build_request_body(query)
316320

317321
# 发送请求 / Send request
318-
with httpx.Client(timeout=self.config.get_timeout()) as client:
322+
with httpx.Client(
323+
timeout=self.config.get_timeout()
324+
) as client:
319325
response = client.post(url, json=body, headers=headers)
320326
response.raise_for_status()
321327
result = response.json()
@@ -467,6 +473,7 @@ async def retrieve_async(
467473
"error": True,
468474
}
469475

476+
470477
def retrieve(
471478
self,
472479
query: str,
@@ -636,14 +643,10 @@ def _build_query_content_request(
636643
self.retrieve_settings.rerank_factor
637644
)
638645
if self.retrieve_settings.rerank_model is not None:
639-
rerank_model_params: Dict[str, Any] = {
640-
"Name": self.retrieve_settings.rerank_model.name,
641-
}
642-
if self.retrieve_settings.rerank_model.instruct is not None:
643-
rerank_model_params["Instruct"] = (
644-
self.retrieve_settings.rerank_model.instruct
645-
)
646-
request_params["rerank_model"] = rerank_model_params
646+
request_params["rerank_model"] = gpdb_models.QueryContentRequestRerankModel(
647+
name=self.retrieve_settings.rerank_model.name,
648+
instruct=self.retrieve_settings.rerank_model.instruct,
649+
)
647650
if self.retrieve_settings.recall_window is not None:
648651
request_params["recall_window"] = (
649652
self.retrieve_settings.recall_window
@@ -656,6 +659,8 @@ def _build_query_content_request(
656659
request_params["hybrid_search_args"] = (
657660
self.retrieve_settings.hybrid_search_args
658661
)
662+
if self.retrieve_settings.filter is not None:
663+
request_params["filter"] = self.retrieve_settings.filter
659664

660665
return gpdb_models.QueryContentRequest(**request_params)
661666

@@ -768,6 +773,7 @@ async def retrieve_async(
768773
"error": True,
769774
}
770775

776+
771777
def retrieve(
772778
self,
773779
query: str,
@@ -1039,6 +1045,7 @@ async def retrieve_async(
10391045
"error": True,
10401046
}
10411047

1048+
10421049
def retrieve(
10431050
self,
10441051
query: str,

agentrun/knowledgebase/knowledgebase.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def create(
109109
Returns:
110110
KnowledgeBase: 创建的知识库对象 / Created knowledge base object
111111
"""
112-
return cls.__get_client(config=config).create(input, config=config)
112+
return cls.__get_client(config=config).create(
113+
input, config=config
114+
)
113115

114116
@classmethod
115117
async def delete_by_name_async(
@@ -357,7 +359,9 @@ def delete(self, config: Optional[Config] = None):
357359
"knowledge_base_name is required to delete a KnowledgeBase"
358360
)
359361

360-
return self.delete_by_name(self.knowledge_base_name, config=config)
362+
return self.delete_by_name(
363+
self.knowledge_base_name, config=config
364+
)
361365

362366
async def get_async(self, config: Optional[Config] = None):
363367
"""刷新知识库信息(异步)/ Refresh knowledge base info asynchronously
@@ -394,7 +398,9 @@ def get(self, config: Optional[Config] = None):
394398
"knowledge_base_name is required to refresh a KnowledgeBase"
395399
)
396400

397-
result = self.get_by_name(self.knowledge_base_name, config=config)
401+
result = self.get_by_name(
402+
self.knowledge_base_name, config=config
403+
)
398404
self.update_self(result)
399405

400406
return self
@@ -458,6 +464,18 @@ def _get_data_api(self, config: Optional[Config] = None):
458464
converted_provider_settings = None
459465
converted_retrieve_settings = None
460466

467+
# 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"),
468+
# 从 __pydantic_extra__ 提取原始数据作为 dict 使用
469+
# When pydantic matches retrieve_settings to the wrong Union type (due to extra="allow"),
470+
# extract the raw data from __pydantic_extra__ and use it as a dict
471+
if (
472+
self.retrieve_settings is not None
473+
and not isinstance(self.retrieve_settings, dict)
474+
and hasattr(self.retrieve_settings, "__pydantic_extra__")
475+
and self.retrieve_settings.__pydantic_extra__
476+
):
477+
self.retrieve_settings = self.retrieve_settings.__pydantic_extra__
478+
461479
if provider == KnowledgeBaseProvider.BAILIAN:
462480
# 百炼设置 / Bailian settings
463481
if self.provider_settings:
@@ -544,6 +562,9 @@ def _get_data_api(self, config: Optional[Config] = None):
544562
hybrid_search_args=self.retrieve_settings.get(
545563
"HybridSearchArgs"
546564
),
565+
filter=self.retrieve_settings.get(
566+
"Filter"
567+
),
547568
)
548569

549570
elif provider == KnowledgeBaseProvider.OTS:
@@ -905,19 +926,21 @@ def multi_retrieve(
905926
"""
906927
# 1. 根据 knowledge_base_names 并发获取各知识库配置(安全方式)
907928
# Fetch all knowledge bases concurrently by name (safely)
908-
knowledge_base_results = [
929+
knowledge_base_results = ([
909930
cls._safe_get_kb(name, config=config)
910931
for name in knowledge_base_names
911-
]
932+
])
912933

913934
# 2. 并发执行各知识库的检索(安全方式)
914935
# Execute retrieval for each knowledge base concurrently (safely)
915-
retrieve_results = [
916-
cls._safe_retrieve_kb(kb_name, kb_or_error, query, config=config)
936+
retrieve_results = ([
937+
cls._safe_retrieve_kb(
938+
kb_name, kb_or_error, query, config=config
939+
)
917940
for kb_name, kb_or_error in zip(
918941
knowledge_base_names, knowledge_base_results
919942
)
920-
]
943+
])
921944

922945
# 3. 合并返回结果,按知识库名称分组
923946
# Merge results, grouped by knowledge base name

agentrun/knowledgebase/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ class ADBRetrieveSettings(BaseModel):
150150
hybrid_search_args: Optional[Dict[str, Any]] = None
151151
"""混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}}
152152
Hybrid search algorithm parameters"""
153+
filter: Optional[str] = None
154+
"""过滤条件,SQL WHERE 格式,如 "category = 'tech' AND score > 0.5"
155+
Filter condition in SQL WHERE format"""
153156

154157

155158
# =============================================================================

tests/unittests/knowledgebase/api/test_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ def test_build_query_content_request_with_settings(self):
10261026
recall_window=[-5, 5],
10271027
hybrid_search="RRF",
10281028
hybrid_search_args={"RRF": {"k": 60}},
1029+
filter="category = 'tech' AND score > 0.5",
10291030
),
10301031
)
10311032

@@ -1036,6 +1037,7 @@ def test_build_query_content_request_with_settings(self):
10361037
assert request.use_full_text_retrieval is True
10371038
assert request.rerank_factor == 1.5
10381039
assert request.rerank_model is not None
1040+
assert request.filter == "category = 'tech' AND score > 0.5"
10391041

10401042
@patch.dict(
10411043
os.environ,

tests/unittests/knowledgebase/test_knowledgebase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,13 +706,15 @@ def test_get_data_api_adb_with_raw_dict_settings(self):
706706
"RecallWindow": [-5, 5],
707707
"HybridSearch": "RRF",
708708
"HybridSearchArgs": {"RRF": {"k": 60}},
709+
"Filter": "category = 'tech'",
709710
},
710711
)
711712

712713
from agentrun.knowledgebase.api.data import ADBDataAPI
713714

714715
data_api = kb._get_data_api()
715716
assert isinstance(data_api, ADBDataAPI)
717+
assert data_api.retrieve_settings.filter == "category = 'tech'"
716718

717719
def test_get_data_api_without_provider(self):
718720
"""测试获取数据链路 API(无提供商)"""

tests/unittests/knowledgebase/test_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def test_create_adb_retrieve_settings(self):
215215
recall_window=[-5, 5],
216216
hybrid_search="RRF",
217217
hybrid_search_args={"RRF": {"k": 60}},
218+
filter="category = 'tech'",
218219
)
219220
assert settings.top_k == 10
220221
assert settings.use_full_text_retrieval is True
@@ -225,6 +226,7 @@ def test_create_adb_retrieve_settings(self):
225226
assert settings.recall_window == [-5, 5]
226227
assert settings.hybrid_search == "RRF"
227228
assert settings.hybrid_search_args == {"RRF": {"k": 60}}
229+
assert settings.filter == "category = 'tech'"
228230

229231
def test_adb_retrieve_settings_optional(self):
230232
"""测试 ADB 检索设置可选字段"""
@@ -236,6 +238,7 @@ def test_adb_retrieve_settings_optional(self):
236238
assert settings.recall_window is None
237239
assert settings.hybrid_search is None
238240
assert settings.hybrid_search_args is None
241+
assert settings.filter is None
239242

240243
def test_adb_retrieve_settings_weight_hybrid(self):
241244
"""测试 ADB 检索设置加权混合检索"""

0 commit comments

Comments
 (0)