Skip to content

Commit 997abdf

Browse files
authored
Merge pull request #118 from Serverless-Devs/kb-search-filter
feat: Support Bailian filter by tags
2 parents 536e1e0 + 94bc7fa commit 997abdf

2 files changed

Lines changed: 89 additions & 12 deletions

File tree

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
88
"""
99

10+
import json
1011
from abc import ABC, abstractmethod
1112
from typing import Any, Dict, List, Optional, Union
1213

@@ -267,6 +268,42 @@ def __init__(
267268
self.provider_settings = provider_settings
268269
self.retrieve_settings = retrieve_settings
269270

271+
@staticmethod
272+
def _normalize_search_filters(
273+
search_filters: Optional[List[Dict[str, Any]]] = None,
274+
) -> Optional[List[Dict[str, str]]]:
275+
"""规范化百炼 SearchFilters 格式 / Normalize Bailian SearchFilters format
276+
277+
百炼 API 要求 search_filters 中每个 dict 的值必须是字符串类型。
278+
对于 list 类型的值(如 tags 过滤),需转换为 JSON 序列化后的字符串。
279+
例如: {"tags": ["0216"]} → {"tags": '["0216"]'}
280+
281+
Bailian API requires each dict value in search_filters to be a string.
282+
For list-typed values (e.g. tags filter), convert to JSON-serialized string.
283+
e.g. {"tags": ["0216"]} → {"tags": '["0216"]'}
284+
285+
Args:
286+
search_filters: 原始 search_filters / Raw search_filters
287+
288+
Returns:
289+
规范化后的 search_filters / Normalized search_filters
290+
"""
291+
if search_filters is None:
292+
return None
293+
294+
normalized: List[Dict[str, str]] = []
295+
for filter_item in search_filters:
296+
normalized_item: Dict[str, str] = {}
297+
for key, value in filter_item.items():
298+
if isinstance(value, (list, dict)):
299+
normalized_item[key] = json.dumps(
300+
value, ensure_ascii=False
301+
)
302+
else:
303+
normalized_item[key] = str(value)
304+
normalized.append(normalized_item)
305+
return normalized
306+
270307
async def retrieve_async(
271308
self,
272309
query: str,
@@ -318,8 +355,9 @@ async def retrieve_async(
318355

319356
# 添加运行时元数据过滤条件 / Add runtime metadata search filters
320357
if search_filters is not None:
321-
request_params["search_filters"] = search_filters
322-
request_params["is_displayed_chunk_content"] = True
358+
request_params["search_filters"] = (
359+
self._normalize_search_filters(search_filters)
360+
)
323361

324362
# 获取百炼客户端 / Get Bailian client
325363
client = self._get_bailian_client(config)

agentrun/knowledgebase/api/data.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
1818
"""
1919

20+
import json
2021
from abc import ABC, abstractmethod
2122
from typing import Any, Dict, List, Optional, Union
2223

@@ -76,7 +77,6 @@ async def retrieve_async(
7677
Args:
7778
query: 查询文本 / Query text
7879
config: 配置 / Configuration
79-
metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters
8080
8181
Returns:
8282
Dict[str, Any]: 检索结果 / Retrieval results
@@ -96,7 +96,6 @@ def retrieve(
9696
Args:
9797
query: 查询文本 / Query text
9898
config: 配置 / Configuration
99-
metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters
10099
101100
Returns:
102101
Dict[str, Any]: 检索结果 / Retrieval results
@@ -393,6 +392,42 @@ def __init__(
393392
self.provider_settings = provider_settings
394393
self.retrieve_settings = retrieve_settings
395394

395+
@staticmethod
396+
def _normalize_search_filters(
397+
search_filters: Optional[List[Dict[str, Any]]] = None,
398+
) -> Optional[List[Dict[str, str]]]:
399+
"""规范化百炼 SearchFilters 格式 / Normalize Bailian SearchFilters format
400+
401+
百炼 API 要求 search_filters 中每个 dict 的值必须是字符串类型。
402+
对于 list 类型的值(如 tags 过滤),需转换为 JSON 序列化后的字符串。
403+
例如: {"tags": ["0216"]} → {"tags": '["0216"]'}
404+
405+
Bailian API requires each dict value in search_filters to be a string.
406+
For list-typed values (e.g. tags filter), convert to JSON-serialized string.
407+
e.g. {"tags": ["0216"]} → {"tags": '["0216"]'}
408+
409+
Args:
410+
search_filters: 原始 search_filters / Raw search_filters
411+
412+
Returns:
413+
规范化后的 search_filters / Normalized search_filters
414+
"""
415+
if search_filters is None:
416+
return None
417+
418+
normalized: List[Dict[str, str]] = []
419+
for filter_item in search_filters:
420+
normalized_item: Dict[str, str] = {}
421+
for key, value in filter_item.items():
422+
if isinstance(value, (list, dict)):
423+
normalized_item[key] = json.dumps(
424+
value, ensure_ascii=False
425+
)
426+
else:
427+
normalized_item[key] = str(value)
428+
normalized.append(normalized_item)
429+
return normalized
430+
396431
async def retrieve_async(
397432
self,
398433
query: str,
@@ -444,8 +479,9 @@ async def retrieve_async(
444479

445480
# 添加运行时元数据过滤条件 / Add runtime metadata search filters
446481
if search_filters is not None:
447-
request_params["search_filters"] = search_filters
448-
request_params["is_displayed_chunk_content"] = True
482+
request_params["search_filters"] = (
483+
self._normalize_search_filters(search_filters)
484+
)
449485

450486
# 获取百炼客户端 / Get Bailian client
451487
client = self._get_bailian_client(config)
@@ -547,8 +583,9 @@ def retrieve(
547583

548584
# 添加运行时元数据过滤条件 / Add runtime metadata search filters
549585
if search_filters is not None:
550-
request_params["search_filters"] = search_filters
551-
request_params["is_displayed_chunk_content"] = True
586+
request_params["search_filters"] = (
587+
self._normalize_search_filters(search_filters)
588+
)
552589

553590
# 获取百炼客户端 / Get Bailian client
554591
client = self._get_bailian_client(config)
@@ -676,9 +713,11 @@ def _build_query_content_request(
676713
self.retrieve_settings.rerank_factor
677714
)
678715
if self.retrieve_settings.rerank_model is not None:
679-
request_params["rerank_model"] = gpdb_models.QueryContentRequestRerankModel(
680-
name=self.retrieve_settings.rerank_model.name,
681-
instruct=self.retrieve_settings.rerank_model.instruct,
716+
request_params["rerank_model"] = (
717+
gpdb_models.QueryContentRequestRerankModel(
718+
name=self.retrieve_settings.rerank_model.name,
719+
instruct=self.retrieve_settings.rerank_model.instruct,
720+
)
682721
)
683722
if self.retrieve_settings.recall_window is not None:
684723
request_params["recall_window"] = (
@@ -692,7 +731,6 @@ def _build_query_content_request(
692731
request_params["hybrid_search_args"] = (
693732
self.retrieve_settings.hybrid_search_args
694733
)
695-
696734
if filter is not None:
697735
request_params["filter"] = filter
698736

@@ -1157,6 +1195,7 @@ def get_data_api(
11571195
provider: KnowledgeBaseProvider,
11581196
knowledge_base_name: str,
11591197
config: Optional[Config] = None,
1198+
11601199
provider_settings: Optional[
11611200
Union[
11621201
RagFlowProviderSettings,

0 commit comments

Comments
 (0)