Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
"""

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -267,6 +268,42 @@ def __init__(
self.provider_settings = provider_settings
self.retrieve_settings = retrieve_settings

@staticmethod
def _normalize_search_filters(
search_filters: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[Dict[str, str]]]:
"""规范化百炼 SearchFilters 格式 / Normalize Bailian SearchFilters format

百炼 API 要求 search_filters 中每个 dict 的值必须是字符串类型。
对于 list 类型的值(如 tags 过滤),需转换为 JSON 序列化后的字符串。
例如: {"tags": ["0216"]} → {"tags": '["0216"]'}

Bailian API requires each dict value in search_filters to be a string.
For list-typed values (e.g. tags filter), convert to JSON-serialized string.
e.g. {"tags": ["0216"]} → {"tags": '["0216"]'}

Args:
search_filters: 原始 search_filters / Raw search_filters

Returns:
规范化后的 search_filters / Normalized search_filters
"""
if search_filters is None:
return None

normalized: List[Dict[str, str]] = []
for filter_item in search_filters:
normalized_item: Dict[str, str] = {}
for key, value in filter_item.items():
if isinstance(value, (list, dict)):
normalized_item[key] = json.dumps(
value, ensure_ascii=False
)
else:
normalized_item[key] = str(value)
normalized.append(normalized_item)
return normalized

async def retrieve_async(
self,
query: str,
Expand Down Expand Up @@ -318,8 +355,9 @@ async def retrieve_async(

# 添加运行时元数据过滤条件 / Add runtime metadata search filters
if search_filters is not None:
request_params["search_filters"] = search_filters
request_params["is_displayed_chunk_content"] = True
request_params["search_filters"] = (
self._normalize_search_filters(search_filters)
)

# 获取百炼客户端 / Get Bailian client
client = self._get_bailian_client(config)
Expand Down
59 changes: 49 additions & 10 deletions agentrun/knowledgebase/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
"""

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -76,7 +77,6 @@ async def retrieve_async(
Args:
query: 查询文本 / Query text
config: 配置 / Configuration
metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand All @@ -96,7 +96,6 @@ def retrieve(
Args:
query: 查询文本 / Query text
config: 配置 / Configuration
metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand Down Expand Up @@ -393,6 +392,42 @@ def __init__(
self.provider_settings = provider_settings
self.retrieve_settings = retrieve_settings

@staticmethod
def _normalize_search_filters(
search_filters: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[Dict[str, str]]]:
"""规范化百炼 SearchFilters 格式 / Normalize Bailian SearchFilters format

百炼 API 要求 search_filters 中每个 dict 的值必须是字符串类型。
对于 list 类型的值(如 tags 过滤),需转换为 JSON 序列化后的字符串。
例如: {"tags": ["0216"]} → {"tags": '["0216"]'}

Bailian API requires each dict value in search_filters to be a string.
For list-typed values (e.g. tags filter), convert to JSON-serialized string.
e.g. {"tags": ["0216"]} → {"tags": '["0216"]'}

Args:
search_filters: 原始 search_filters / Raw search_filters

Returns:
规范化后的 search_filters / Normalized search_filters
"""
if search_filters is None:
return None

normalized: List[Dict[str, str]] = []
for filter_item in search_filters:
normalized_item: Dict[str, str] = {}
for key, value in filter_item.items():
if isinstance(value, (list, dict)):
normalized_item[key] = json.dumps(
value, ensure_ascii=False
)
else:
normalized_item[key] = str(value)
normalized.append(normalized_item)
return normalized

async def retrieve_async(
self,
query: str,
Expand Down Expand Up @@ -444,8 +479,9 @@ async def retrieve_async(

# 添加运行时元数据过滤条件 / Add runtime metadata search filters
if search_filters is not None:
request_params["search_filters"] = search_filters
request_params["is_displayed_chunk_content"] = True
request_params["search_filters"] = (
self._normalize_search_filters(search_filters)
)

# 获取百炼客户端 / Get Bailian client
client = self._get_bailian_client(config)
Expand Down Expand Up @@ -547,8 +583,9 @@ def retrieve(

# 添加运行时元数据过滤条件 / Add runtime metadata search filters
if search_filters is not None:
request_params["search_filters"] = search_filters
request_params["is_displayed_chunk_content"] = True
request_params["search_filters"] = (
self._normalize_search_filters(search_filters)
)

# 获取百炼客户端 / Get Bailian client
client = self._get_bailian_client(config)
Expand Down Expand Up @@ -676,9 +713,11 @@ def _build_query_content_request(
self.retrieve_settings.rerank_factor
)
if self.retrieve_settings.rerank_model is not None:
request_params["rerank_model"] = gpdb_models.QueryContentRequestRerankModel(
name=self.retrieve_settings.rerank_model.name,
instruct=self.retrieve_settings.rerank_model.instruct,
request_params["rerank_model"] = (
gpdb_models.QueryContentRequestRerankModel(
name=self.retrieve_settings.rerank_model.name,
instruct=self.retrieve_settings.rerank_model.instruct,
)
)
if self.retrieve_settings.recall_window is not None:
request_params["recall_window"] = (
Expand All @@ -692,7 +731,6 @@ def _build_query_content_request(
request_params["hybrid_search_args"] = (
self.retrieve_settings.hybrid_search_args
)

if filter is not None:
request_params["filter"] = filter

Expand Down Expand Up @@ -1157,6 +1195,7 @@ def get_data_api(
provider: KnowledgeBaseProvider,
knowledge_base_name: str,
config: Optional[Config] = None,

provider_settings: Optional[
Union[
RagFlowProviderSettings,
Expand Down
Loading