@staticmethod
def _normalize_search_filters(
    search_filters: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[Dict[str, str]]]:
    """Normalize Bailian SearchFilters so that every value is a string.

    The Bailian API requires each dict value in ``search_filters`` to be a
    string. Container values (e.g. a ``tags`` filter) are converted to
    their JSON-serialized form; every other value is passed through
    ``str()``.

    Example: ``{"tags": ["0216"]}`` -> ``{"tags": '["0216"]'}``

    Args:
        search_filters: Raw search filters; each item maps a filter key to
            a value of any type. ``None`` means "no filters".

    Returns:
        A new list of dicts whose values are all strings, or ``None`` when
        ``search_filters`` is ``None``. The input is not mutated.
    """
    if search_filters is None:
        return None

    def _to_text(value: Any) -> str:
        # Tuples are serialized like lists: str() would yield a Python
        # repr such as "('0216',)", which is not a JSON array.
        if isinstance(value, (list, tuple, dict)):
            # ensure_ascii=False keeps non-ASCII text (e.g. Chinese tags)
            # readable instead of \uXXXX escapes.
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    return [
        {key: _to_text(value) for key, value in filter_item.items()}
        for filter_item in search_filters
    ]
@staticmethod
def _normalize_search_filters(
    search_filters: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[Dict[str, str]]]:
    """Normalize Bailian SearchFilters so that every value is a string.

    The Bailian API requires each dict value in ``search_filters`` to be a
    string. Container values (e.g. a ``tags`` filter) are converted to
    their JSON-serialized form; every other value is passed through
    ``str()``.

    Example: ``{"tags": ["0216"]}`` -> ``{"tags": '["0216"]'}``

    Args:
        search_filters: Raw search filters; each item maps a filter key to
            a value of any type. ``None`` means "no filters".

    Returns:
        A new list of dicts whose values are all strings, or ``None`` when
        ``search_filters`` is ``None``. The input is not mutated.
    """
    if search_filters is None:
        return None

    def _to_text(value: Any) -> str:
        # Tuples are serialized like lists: str() would yield a Python
        # repr such as "('0216',)", which is not a JSON array.
        if isinstance(value, (list, tuple, dict)):
            # ensure_ascii=False keeps non-ASCII text (e.g. Chinese tags)
            # readable instead of \uXXXX escapes.
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    return [
        {key: _to_text(value) for key, value in filter_item.items()}
        for filter_item in search_filters
    ]
None: request_params["recall_window"] = ( @@ -692,7 +731,6 @@ def _build_query_content_request( request_params["hybrid_search_args"] = ( self.retrieve_settings.hybrid_search_args ) - if filter is not None: request_params["filter"] = filter @@ -1157,6 +1195,7 @@ def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, + provider_settings: Optional[ Union[ RagFlowProviderSettings,