diff --git a/agentrun/integration/builtin/knowledgebase.py b/agentrun/integration/builtin/knowledgebase.py index 4fa5fa8..bdd6b62 100644 --- a/agentrun/integration/builtin/knowledgebase.py +++ b/agentrun/integration/builtin/knowledgebase.py @@ -79,7 +79,11 @@ def __init__( " grouped by knowledge base name." ), ) - def search_document(self, query: str) -> Dict[str, Any]: + def search_document( + self, + query: str, + metadata_filters: Optional[Any] = None, + ) -> Dict[str, Any]: """检索文档 / Search documents 根据查询文本从配置的知识库中检索相关文档。 @@ -87,6 +91,7 @@ def search_document(self, query: str) -> Dict[str, Any]: Args: query: 查询文本 / Query text + metadata_filters: 元数据过滤条件 / Metadata filter conditions Returns: Dict[str, Any]: 检索结果,包含各知识库的检索结果 / @@ -96,6 +101,7 @@ def search_document(self, query: str) -> Dict[str, Any]: query=query, knowledge_base_names=self.knowledge_base_names, config=self.config, + metadata_filters=metadata_filters, ) diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index 5277c3f..29db508 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -59,6 +59,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """检索知识库(异步)/ Retrieve from knowledge base (async) @@ -129,11 +130,14 @@ async def _get_api_key_async(self, config: Optional[Config] = None) -> str: ) return credential.credential_secret - def _build_request_body(self, query: str) -> Dict[str, Any]: + def _build_request_body( + self, query: str, metadata_condition: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """构建请求体 / Build request body Args: query: 查询文本 / Query text + metadata_condition: 元数据过滤条件 / Metadata condition filter Returns: Dict[str, Any]: 请求体 / Request body @@ -163,18 +167,24 @@ def _build_request_body(self, query: str) -> Dict[str, Any]: if self.retrieve_settings.cross_languages is not None: body["cross_languages"] = self.retrieve_settings.cross_languages + # 添加运行时元数据过滤条件 / Add runtime metadata condition filter + if metadata_condition is not None: + body["metadata_condition"] = metadata_condition + return body async def retrieve_async( self, query: str, config: Optional[Config] = None, + metadata_condition: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """RagFlow 检索(异步)/ RagFlow retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_condition: 运行时元数据过滤条件 / Runtime metadata condition filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -195,7 +205,7 @@ async def retrieve_async( "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } - body = self._build_request_body(query) + body = self._build_request_body(query, metadata_condition=metadata_condition) # 发送请求 / Send request async with httpx.AsyncClient( @@ -261,12 +271,14 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + search_filters: Optional[List[Dict[str, str]]] = None, ) -> Dict[str, Any]: """百炼检索(异步)/ Bailian retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + search_filters: 运行时元数据过滤条件 / Runtime metadata search filters Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -304,6 +316,11 @@ async def retrieve_async( self.retrieve_settings.rerank_top_n ) + # 添加运行时元数据过滤条件 / Add runtime metadata search filters + if search_filters is not None: + request_params["search_filters"] = search_filters + request_params["is_displayed_chunk_content"] = True + # 获取百炼客户端 / Get Bailian client client = self._get_bailian_client(config) @@ -381,13 +398,17 @@ def __init__( self.retrieve_settings = retrieve_settings def _build_query_content_request( - self, query: str, config: Optional[Config] = None + self, + query: str, + config: Optional[Config] = None, + filter: Optional[str] = None, ) -> gpdb_models.QueryContentRequest: """构建 QueryContent 请求 / Build QueryContent request Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter Returns: QueryContentRequest: GPDB QueryContent 请求对象 @@ -444,8 +465,8 @@ def _build_query_content_request( request_params["hybrid_search_args"] = ( self.retrieve_settings.hybrid_search_args ) - if self.retrieve_settings.filter is not None: - request_params["filter"] = self.retrieve_settings.filter + if filter is not None: + request_params["filter"] = filter return gpdb_models.QueryContentRequest(**request_params) @@ -513,6 +534,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + filter: Optional[str] = None, ) -> Dict[str, Any]: """ADB 检索(异步)/ ADB retrieval asynchronously @@ -522,6 +544,7 @@ async def retrieve_async( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -536,7 +559,7 @@ async def retrieve_async( client = self._get_gpdb_client(config) # 构建请求 / Build request - request = self._build_query_content_request(query, config) + request = self._build_query_content_request(query, config, filter=filter) logger.debug(f"ADB QueryContent request: {request}") # 调用 QueryContent API / Call QueryContent API @@ -611,85 +634,96 @@ def _build_agent_storage_client( ots_instance_name=self.provider_settings.ots_instance_name, ) - def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]: + def _build_retrieval_configuration( + self, filter: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式 Convert OTSRetrieveSettings to tablestore-agent-storage dict format + Args: + filter: 运行时过滤条件 / Runtime filter + Returns: Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict """ - if self.retrieve_settings is None: + if self.retrieve_settings is None and filter is None: return None config: Dict[str, Any] = {} - if self.retrieve_settings.search_type is not None: - config["searchType"] = self.retrieve_settings.search_type - - if self.retrieve_settings.dense_vector_search_configuration is not None: - dvsc = self.retrieve_settings.dense_vector_search_configuration - dv_config: Dict[str, Any] = {} - if dvsc.number_of_results is not None: - dv_config["numberOfResults"] = dvsc.number_of_results - config["denseVectorSearchConfiguration"] = dv_config - - if self.retrieve_settings.full_text_search_configuration is not None: - ftsc = self.retrieve_settings.full_text_search_configuration - ft_config: Dict[str, Any] = {} - if ftsc.number_of_results is not None: - ft_config["numberOfResults"] = ftsc.number_of_results - config["fullTextSearchConfiguration"] = ft_config - - if self.retrieve_settings.reranking_configuration is not None: - rc = self.retrieve_settings.reranking_configuration - rr_config: Dict[str, Any] = {} - - if rc.type is not None: - rr_config["type"] = rc.type - if rc.number_of_results is not None: - rr_config["numberOfResults"] = rc.number_of_results - - if rc.rrf_configuration is not None: - rrf: Dict[str, Any] = {} - if rc.rrf_configuration.dense_vector_search_weight is not None: - rrf["denseVectorSearchWeight"] = ( - rc.rrf_configuration.dense_vector_search_weight - ) - if rc.rrf_configuration.full_text_search_weight is not None: - rrf["fullTextSearchWeight"] = ( - rc.rrf_configuration.full_text_search_weight - ) - if rc.rrf_configuration.k is not None: - rrf["k"] = rc.rrf_configuration.k - rr_config["rrfConfiguration"] = rrf - - if rc.weight_configuration is not None: - wc: Dict[str, Any] = {} - if ( - rc.weight_configuration.dense_vector_search_weight - is not None - ): - wc["denseVectorSearchWeight"] = ( + if self.retrieve_settings is not None: + if self.retrieve_settings.search_type is not None: + config["searchType"] = self.retrieve_settings.search_type + + if self.retrieve_settings.dense_vector_search_configuration is not None: + dvsc = self.retrieve_settings.dense_vector_search_configuration + dv_config: Dict[str, Any] = {} + if dvsc.number_of_results is not None: + dv_config["numberOfResults"] = dvsc.number_of_results + config["denseVectorSearchConfiguration"] = dv_config + + if self.retrieve_settings.full_text_search_configuration is not None: + ftsc = self.retrieve_settings.full_text_search_configuration + ft_config: Dict[str, Any] = {} + if ftsc.number_of_results is not None: + ft_config["numberOfResults"] = ftsc.number_of_results + config["fullTextSearchConfiguration"] = ft_config + + if self.retrieve_settings.reranking_configuration is not None: + rc = self.retrieve_settings.reranking_configuration + rr_config: Dict[str, Any] = {} + + if rc.type is not None: + rr_config["type"] = rc.type + if rc.number_of_results is not None: + rr_config["numberOfResults"] = rc.number_of_results + + if rc.rrf_configuration is not None: + rrf: Dict[str, Any] = {} + if rc.rrf_configuration.dense_vector_search_weight is not None: + rrf["denseVectorSearchWeight"] = ( + rc.rrf_configuration.dense_vector_search_weight + ) + if rc.rrf_configuration.full_text_search_weight is not None: + rrf["fullTextSearchWeight"] = ( + rc.rrf_configuration.full_text_search_weight + ) + if rc.rrf_configuration.k is not None: + rrf["k"] = rc.rrf_configuration.k + rr_config["rrfConfiguration"] = rrf + + if rc.weight_configuration is not None: + wc: Dict[str, Any] = {} + if ( rc.weight_configuration.dense_vector_search_weight - ) - if rc.weight_configuration.full_text_search_weight is not None: - wc["fullTextSearchWeight"] = ( - rc.weight_configuration.full_text_search_weight - ) - rr_config["weightConfiguration"] = wc - - if rc.model_configuration is not None: - mc: Dict[str, Any] = {} - if rc.model_configuration.provider is not None: - mc["provider"] = rc.model_configuration.provider - if rc.model_configuration.model is not None: - mc["model"] = rc.model_configuration.model - rr_config["modelConfiguration"] = mc + is not None + ): + wc["denseVectorSearchWeight"] = ( + rc.weight_configuration.dense_vector_search_weight + ) + if rc.weight_configuration.full_text_search_weight is not None: + wc["fullTextSearchWeight"] = ( + rc.weight_configuration.full_text_search_weight + ) + rr_config["weightConfiguration"] = wc + + if rc.model_configuration is not None: + mc: Dict[str, Any] = {} + if rc.model_configuration.provider is not None: + mc["provider"] = rc.model_configuration.provider + if rc.model_configuration.model is not None: + mc["model"] = rc.model_configuration.model + rr_config["modelConfiguration"] = mc + + config["rerankingConfiguration"] = rr_config - config["rerankingConfiguration"] = rr_config + if self.retrieve_settings.filter is not None: + config["filter"] = self.retrieve_settings.filter - if self.retrieve_settings.filter is not None: - config["filter"] = self.retrieve_settings.filter + # 运行时 filter 优先级高于 retrieve_settings.filter + # Runtime filter takes precedence over retrieve_settings.filter + if filter is not None: + config["filter"] = filter return config if config else None @@ -731,6 +765,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + filter: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """OTS 检索(异步)/ OTS retrieval asynchronously @@ -740,6 +775,7 @@ async def retrieve_async( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时过滤条件 / Runtime filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -752,7 +788,7 @@ async def retrieve_async( client = self._build_agent_storage_client(config) - retrieval_config = self._build_retrieval_configuration() + retrieval_config = self._build_retrieval_configuration(filter=filter) request: Dict[str, Any] = { "knowledgeBaseName": self.knowledge_base_name, @@ -785,6 +821,7 @@ def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, config: Optional[Config] = None, + provider_settings: Optional[ Union[ RagFlowProviderSettings, diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index 74517b5..ac8ad8d 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -69,12 +69,14 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """检索知识库(异步)/ Retrieve from knowledge base (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -87,12 +89,14 @@ def retrieve( self, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """检索知识库(同步)/ Retrieve from knowledge base (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 运行时元数据过滤条件 / Runtime metadata filters Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -185,11 +189,14 @@ def _get_api_key(self, config: Optional[Config] = None) -> str: ) return credential.credential_secret - def _build_request_body(self, query: str) -> Dict[str, Any]: + def _build_request_body( + self, query: str, metadata_condition: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """构建请求体 / Build request body Args: query: 查询文本 / Query text + metadata_condition: 元数据过滤条件 / Metadata condition filter Returns: Dict[str, Any]: 请求体 / Request body @@ -219,18 +226,24 @@ def _build_request_body(self, query: str) -> Dict[str, Any]: if self.retrieve_settings.cross_languages is not None: body["cross_languages"] = self.retrieve_settings.cross_languages + # 添加运行时元数据过滤条件 / Add runtime metadata condition filter + if metadata_condition is not None: + body["metadata_condition"] = metadata_condition + return body async def retrieve_async( self, query: str, config: Optional[Config] = None, + metadata_condition: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """RagFlow 检索(异步)/ RagFlow retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_condition: 运行时元数据过滤条件 / Runtime metadata condition filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -251,7 +264,7 @@ async def retrieve_async( "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } - body = self._build_request_body(query) + body = self._build_request_body(query, metadata_condition=metadata_condition) # 发送请求 / Send request async with httpx.AsyncClient( @@ -290,12 +303,14 @@ def retrieve( self, query: str, config: Optional[Config] = None, + metadata_condition: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """RagFlow 检索(同步)/ RagFlow retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_condition: 运行时元数据过滤条件 / Runtime metadata condition filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -316,7 +331,7 @@ def retrieve( "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } - body = self._build_request_body(query) + body = self._build_request_body(query, metadata_condition=metadata_condition) # 发送请求 / Send request with httpx.Client( @@ -382,12 +397,14 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + search_filters: Optional[List[Dict[str, str]]] = None, ) -> Dict[str, Any]: """百炼检索(异步)/ Bailian retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + search_filters: 运行时元数据过滤条件 / Runtime metadata search filters Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -425,6 +442,11 @@ async def retrieve_async( self.retrieve_settings.rerank_top_n ) + # 添加运行时元数据过滤条件 / Add runtime metadata search filters + if search_filters is not None: + request_params["search_filters"] = search_filters + request_params["is_displayed_chunk_content"] = True + # 获取百炼客户端 / Get Bailian client client = self._get_bailian_client(config) @@ -478,12 +500,14 @@ def retrieve( self, query: str, config: Optional[Config] = None, + search_filters: Optional[List[Dict[str, str]]] = None, ) -> Dict[str, Any]: """百炼检索(同步)/ Bailian retrieval (async) Args: query: 查询文本 / Query text config: 配置 / Configuration + search_filters: 运行时元数据过滤条件 / Runtime metadata search filters Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -521,6 +545,11 @@ def retrieve( self.retrieve_settings.rerank_top_n ) + # 添加运行时元数据过滤条件 / Add runtime metadata search filters + if search_filters is not None: + request_params["search_filters"] = search_filters + request_params["is_displayed_chunk_content"] = True + # 获取百炼客户端 / Get Bailian client client = self._get_bailian_client(config) @@ -598,13 +627,17 @@ def __init__( self.retrieve_settings = retrieve_settings def _build_query_content_request( - self, query: str, config: Optional[Config] = None + self, + query: str, + config: Optional[Config] = None, + filter: Optional[str] = None, ) -> gpdb_models.QueryContentRequest: """构建 QueryContent 请求 / Build QueryContent request Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter Returns: QueryContentRequest: GPDB QueryContent 请求对象 @@ -659,8 +692,9 @@ def _build_query_content_request( request_params["hybrid_search_args"] = ( self.retrieve_settings.hybrid_search_args ) - if self.retrieve_settings.filter is not None: - request_params["filter"] = self.retrieve_settings.filter + + if filter is not None: + request_params["filter"] = filter return gpdb_models.QueryContentRequest(**request_params) @@ -728,6 +762,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + filter: Optional[str] = None, ) -> Dict[str, Any]: """ADB 检索(异步)/ ADB retrieval asynchronously @@ -737,6 +772,7 @@ async def retrieve_async( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -751,7 +787,7 @@ async def retrieve_async( client = self._get_gpdb_client(config) # 构建请求 / Build request - request = self._build_query_content_request(query, config) + request = self._build_query_content_request(query, config, filter=filter) logger.debug(f"ADB QueryContent request: {request}") # 调用 QueryContent API / Call QueryContent API @@ -778,6 +814,7 @@ def retrieve( self, query: str, config: Optional[Config] = None, + filter: Optional[str] = None, ) -> Dict[str, Any]: """ADB 检索(同步)/ ADB retrieval synchronously @@ -787,6 +824,7 @@ def retrieve( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -801,7 +839,7 @@ def retrieve( client = self._get_gpdb_client(config) # 构建请求 / Build request - request = self._build_query_content_request(query, config) + request = self._build_query_content_request(query, config, filter=filter) logger.debug(f"ADB QueryContent request: {request}") # 调用 QueryContent API / Call QueryContent API @@ -876,85 +914,96 @@ def _build_agent_storage_client( ots_instance_name=self.provider_settings.ots_instance_name, ) - def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]: + def _build_retrieval_configuration( + self, filter: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式 Convert OTSRetrieveSettings to tablestore-agent-storage dict format + Args: + filter: 运行时过滤条件 / Runtime filter + Returns: Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict """ - if self.retrieve_settings is None: + if self.retrieve_settings is None and filter is None: return None config: Dict[str, Any] = {} - if self.retrieve_settings.search_type is not None: - config["searchType"] = self.retrieve_settings.search_type - - if self.retrieve_settings.dense_vector_search_configuration is not None: - dvsc = self.retrieve_settings.dense_vector_search_configuration - dv_config: Dict[str, Any] = {} - if dvsc.number_of_results is not None: - dv_config["numberOfResults"] = dvsc.number_of_results - config["denseVectorSearchConfiguration"] = dv_config - - if self.retrieve_settings.full_text_search_configuration is not None: - ftsc = self.retrieve_settings.full_text_search_configuration - ft_config: Dict[str, Any] = {} - if ftsc.number_of_results is not None: - ft_config["numberOfResults"] = ftsc.number_of_results - config["fullTextSearchConfiguration"] = ft_config - - if self.retrieve_settings.reranking_configuration is not None: - rc = self.retrieve_settings.reranking_configuration - rr_config: Dict[str, Any] = {} - - if rc.type is not None: - rr_config["type"] = rc.type - if rc.number_of_results is not None: - rr_config["numberOfResults"] = rc.number_of_results - - if rc.rrf_configuration is not None: - rrf: Dict[str, Any] = {} - if rc.rrf_configuration.dense_vector_search_weight is not None: - rrf["denseVectorSearchWeight"] = ( - rc.rrf_configuration.dense_vector_search_weight - ) - if rc.rrf_configuration.full_text_search_weight is not None: - rrf["fullTextSearchWeight"] = ( - rc.rrf_configuration.full_text_search_weight - ) - if rc.rrf_configuration.k is not None: - rrf["k"] = rc.rrf_configuration.k - rr_config["rrfConfiguration"] = rrf - - if rc.weight_configuration is not None: - wc: Dict[str, Any] = {} - if ( - rc.weight_configuration.dense_vector_search_weight - is not None - ): - wc["denseVectorSearchWeight"] = ( + if self.retrieve_settings is not None: + if self.retrieve_settings.search_type is not None: + config["searchType"] = self.retrieve_settings.search_type + + if self.retrieve_settings.dense_vector_search_configuration is not None: + dvsc = self.retrieve_settings.dense_vector_search_configuration + dv_config: Dict[str, Any] = {} + if dvsc.number_of_results is not None: + dv_config["numberOfResults"] = dvsc.number_of_results + config["denseVectorSearchConfiguration"] = dv_config + + if self.retrieve_settings.full_text_search_configuration is not None: + ftsc = self.retrieve_settings.full_text_search_configuration + ft_config: Dict[str, Any] = {} + if ftsc.number_of_results is not None: + ft_config["numberOfResults"] = ftsc.number_of_results + config["fullTextSearchConfiguration"] = ft_config + + if self.retrieve_settings.reranking_configuration is not None: + rc = self.retrieve_settings.reranking_configuration + rr_config: Dict[str, Any] = {} + + if rc.type is not None: + rr_config["type"] = rc.type + if rc.number_of_results is not None: + rr_config["numberOfResults"] = rc.number_of_results + + if rc.rrf_configuration is not None: + rrf: Dict[str, Any] = {} + if rc.rrf_configuration.dense_vector_search_weight is not None: + rrf["denseVectorSearchWeight"] = ( + rc.rrf_configuration.dense_vector_search_weight + ) + if rc.rrf_configuration.full_text_search_weight is not None: + rrf["fullTextSearchWeight"] = ( + rc.rrf_configuration.full_text_search_weight + ) + if rc.rrf_configuration.k is not None: + rrf["k"] = rc.rrf_configuration.k + rr_config["rrfConfiguration"] = rrf + + if rc.weight_configuration is not None: + wc: Dict[str, Any] = {} + if ( rc.weight_configuration.dense_vector_search_weight - ) - if rc.weight_configuration.full_text_search_weight is not None: - wc["fullTextSearchWeight"] = ( - rc.weight_configuration.full_text_search_weight - ) - rr_config["weightConfiguration"] = wc + is not None + ): + wc["denseVectorSearchWeight"] = ( + rc.weight_configuration.dense_vector_search_weight + ) + if rc.weight_configuration.full_text_search_weight is not None: + wc["fullTextSearchWeight"] = ( + rc.weight_configuration.full_text_search_weight + ) + rr_config["weightConfiguration"] = wc + + if rc.model_configuration is not None: + mc: Dict[str, Any] = {} + if rc.model_configuration.provider is not None: + mc["provider"] = rc.model_configuration.provider + if rc.model_configuration.model is not None: + mc["model"] = rc.model_configuration.model + rr_config["modelConfiguration"] = mc + + config["rerankingConfiguration"] = rr_config - if rc.model_configuration is not None: - mc: Dict[str, Any] = {} - if rc.model_configuration.provider is not None: - mc["provider"] = rc.model_configuration.provider - if rc.model_configuration.model is not None: - mc["model"] = rc.model_configuration.model - rr_config["modelConfiguration"] = mc - - config["rerankingConfiguration"] = rr_config + if self.retrieve_settings.filter is not None: + config["filter"] = self.retrieve_settings.filter - if self.retrieve_settings.filter is not None: - config["filter"] = self.retrieve_settings.filter + # 运行时 filter 优先级高于 retrieve_settings.filter + # Runtime filter takes precedence over retrieve_settings.filter + if filter is not None: + config["filter"] = filter return config if config else None @@ -996,6 +1045,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + filter: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """OTS 检索(异步)/ OTS retrieval asynchronously @@ -1005,6 +1055,7 @@ async def retrieve_async( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时过滤条件 / Runtime filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -1017,7 +1068,7 @@ async def retrieve_async( client = self._build_agent_storage_client(config) - retrieval_config = self._build_retrieval_configuration() + retrieval_config = self._build_retrieval_configuration(filter=filter) request: Dict[str, Any] = { "knowledgeBaseName": self.knowledge_base_name, @@ -1050,6 +1101,7 @@ def retrieve( self, query: str, config: Optional[Config] = None, + filter: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """OTS 检索(同步)/ OTS retrieval synchronously @@ -1059,6 +1111,7 @@ def retrieve( Args: query: 查询文本 / Query text config: 配置 / Configuration + filter: 运行时过滤条件 / Runtime filter Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -1071,7 +1124,7 @@ def retrieve( client = self._build_agent_storage_client(config) - retrieval_config = self._build_retrieval_configuration() + retrieval_config = self._build_retrieval_configuration(filter=filter) request: Dict[str, Any] = { "knowledgeBaseName": self.knowledge_base_name, diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index ad3f6ab..4796d2e 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -562,9 +562,6 @@ def _get_data_api(self, config: Optional[Config] = None): hybrid_search_args=self.retrieve_settings.get( "HybridSearchArgs" ), - filter=self.retrieve_settings.get( - "Filter" - ), ) elif provider == KnowledgeBaseProvider.OTS: @@ -689,6 +686,7 @@ async def retrieve_async( self, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """检索知识库(异步)/ Retrieve from knowledge base asynchronously @@ -698,17 +696,41 @@ async def retrieve_async( Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 元数据过滤条件 / Metadata filter conditions Returns: Dict[str, Any]: 检索结果 / Retrieval results """ data_api = self._get_data_api(config) - return await data_api.retrieve_async(query, config=config) + provider = ( + self.provider + if isinstance(self.provider, KnowledgeBaseProvider) + else KnowledgeBaseProvider(self.provider) + ) + if provider == KnowledgeBaseProvider.BAILIAN: + return await data_api.retrieve_async( + query, config=config, search_filters=metadata_filters + ) + elif provider == KnowledgeBaseProvider.RAGFLOW: + return await data_api.retrieve_async( + query, config=config, metadata_condition=metadata_filters + ) + elif provider == KnowledgeBaseProvider.ADB: + return await data_api.retrieve_async( + query, config=config, filter=metadata_filters + ) + elif provider == KnowledgeBaseProvider.OTS: + return await data_api.retrieve_async( + query, config=config, filter=metadata_filters + ) + else: + return await data_api.retrieve_async(query, config=config) def retrieve( self, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """检索知识库(同步)/ Retrieve from knowledge base synchronously @@ -718,12 +740,35 @@ def retrieve( Args: query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 元数据过滤条件 / Metadata filter conditions Returns: Dict[str, Any]: 检索结果 / Retrieval results """ data_api = self._get_data_api(config) - return data_api.retrieve(query, config=config) + provider = ( + self.provider + if isinstance(self.provider, KnowledgeBaseProvider) + else KnowledgeBaseProvider(self.provider) + ) + if provider == KnowledgeBaseProvider.BAILIAN: + return data_api.retrieve( + query, config=config, search_filters=metadata_filters + ) + elif provider == KnowledgeBaseProvider.RAGFLOW: + return data_api.retrieve( + query, config=config, metadata_condition=metadata_filters + ) + elif provider == KnowledgeBaseProvider.ADB: + return data_api.retrieve( + query, config=config, filter=metadata_filters + ) + elif provider == KnowledgeBaseProvider.OTS: + return data_api.retrieve( + query, config=config, filter=metadata_filters + ) + else: + return data_api.retrieve(query, config=config) @classmethod async def _safe_get_kb_async( @@ -772,6 +817,7 @@ async def _safe_retrieve_kb_async( kb_or_error: Any, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """安全执行知识库检索(异步)/ Safely retrieve from knowledge base asynchronously @@ -780,6 +826,7 @@ async def _safe_retrieve_kb_async( kb_or_error: 知识库对象或异常 / Knowledge base object or exception query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 元数据过滤条件 / Metadata filter conditions Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -795,7 +842,9 @@ async def _safe_retrieve_kb_async( "error": True, } try: - return await kb_or_error.retrieve_async(query, config=config) + return await kb_or_error.retrieve_async( + query, config=config, metadata_filters=metadata_filters + ) except Exception as e: logger.warning( f"Failed to retrieve from knowledge base '{kb_name}': {e}" @@ -814,6 +863,7 @@ def _safe_retrieve_kb( kb_or_error: Any, query: str, config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """安全执行知识库检索(同步)/ Safely retrieve from knowledge base synchronously @@ -822,6 +872,7 @@ def _safe_retrieve_kb( kb_or_error: 知识库对象或异常 / Knowledge base object or exception query: 查询文本 / Query text config: 配置 / Configuration + metadata_filters: 元数据过滤条件 / Metadata filter conditions Returns: Dict[str, Any]: 检索结果 / Retrieval results @@ -837,7 +888,9 @@ def _safe_retrieve_kb( "error": True, } try: - return kb_or_error.retrieve(query, config=config) + return kb_or_error.retrieve( + query, config=config, metadata_filters=metadata_filters + ) except Exception as e: logger.warning( f"Failed to retrieve from knowledge base '{kb_name}': {e}" @@ -855,6 +908,7 @@ async def multi_retrieve_async( query: str, knowledge_base_names: List[str], config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """多知识库检索(异步)/ Multi knowledge base retrieval asynchronously @@ -883,7 +937,8 @@ async def multi_retrieve_async( # Execute retrieval for each knowledge base concurrently (safely) retrieve_results = await asyncio.gather(*[ cls._safe_retrieve_kb_async( - kb_name, kb_or_error, query, config=config + kb_name, kb_or_error, query, config=config, + metadata_filters=metadata_filters, ) for kb_name, kb_or_error in zip( knowledge_base_names, knowledge_base_results @@ -907,6 +962,7 @@ def multi_retrieve( query: str, knowledge_base_names: List[str], config: Optional[Config] = None, + metadata_filters: Optional[Any] = None, ) -> Dict[str, Any]: """多知识库检索(同步)/ Multi knowledge base retrieval synchronously @@ -935,7 +991,8 @@ def multi_retrieve( # Execute retrieval for each knowledge base concurrently (safely) retrieve_results = ([ cls._safe_retrieve_kb( - kb_name, kb_or_error, query, config=config + kb_name, kb_or_error, query, config=config, + metadata_filters=metadata_filters, ) for kb_name, kb_or_error in zip( knowledge_base_names, knowledge_base_results diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index ff0f7a6..295acea 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -151,8 +151,8 @@ class ADBRetrieveSettings(BaseModel): """混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}} Hybrid search algorithm parameters""" filter: Optional[str] = None - """过滤条件,SQL WHERE 格式,如 "category = 'tech' AND score > 0.5" - Filter condition in SQL WHERE format""" + """过滤条件(已弃用,请通过 retrieve 方法的 metadata_filters 参数传入) + Filter condition (deprecated, use metadata_filters parameter in retrieve method instead)""" # ============================================================================= diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py index 4104a0a..6716cdc 100644 --- a/tests/unittests/knowledgebase/api/test_data.py +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -1030,6 +1030,7 @@ def test_build_query_content_request_with_settings(self): ), ) + # filter 字段已弃用,仅验证构造不报错 request = api._build_query_content_request("test query") assert request.url_expiration == "356d" assert request.metrics == "cosine" @@ -1037,7 +1038,6 @@ def test_build_query_content_request_with_settings(self): assert request.use_full_text_retrieval is True assert request.rerank_factor == 1.5 assert request.rerank_model is not None - assert request.filter == "category = 'tech' AND score > 0.5" @patch.dict( os.environ, diff --git a/tests/unittests/knowledgebase/test_knowledgebase.py b/tests/unittests/knowledgebase/test_knowledgebase.py index ffbf883..b267d9a 100644 --- a/tests/unittests/knowledgebase/test_knowledgebase.py +++ b/tests/unittests/knowledgebase/test_knowledgebase.py @@ -13,6 +13,8 @@ KnowledgeBaseCreateInput, KnowledgeBaseProvider, KnowledgeBaseUpdateInput, + OTSProviderSettings, + OTSRetrieveSettings, RagFlowProviderSettings, RagFlowRetrieveSettings, ) @@ -706,7 +708,6 @@ def test_get_data_api_adb_with_raw_dict_settings(self): "RecallWindow": [-5, 5], "HybridSearch": "RRF", "HybridSearchArgs": {"RRF": {"k": 60}}, - "Filter": "category = 'tech'", }, ) @@ -714,7 +715,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self): data_api = kb._get_data_api() assert isinstance(data_api, ADBDataAPI) - assert data_api.retrieve_settings.filter == "category = 'tech'" + assert data_api.retrieve_settings.top_k == 10 def test_get_data_api_without_provider(self): """测试获取数据链路 API(无提供商)""" @@ -1307,3 +1308,162 @@ async def mock_get_async(*args, **kwargs): ) assert "results" in result + +class TestKnowledgeBaseMetadataFilters: + """测试 metadata_filters 运行时参数 / Test metadata_filters runtime parameter""" + + @patch("agentrun.knowledgebase.api.data.BailianDataAPI.retrieve") + def test_retrieve_sync_bailian_with_metadata_filters(self, mock_retrieve): + """测试百炼同步检索传递 metadata_filters -> search_filters""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + kb = KnowledgeBase( + knowledge_base_name="test-bailian-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + credential_name="test-credential", + ) + + filters = [{"key": "vehicle_type", "value": "sedan"}] + kb.retrieve("test query", metadata_filters=filters) + + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("search_filters") == filters + + @patch("agentrun.knowledgebase.api.data.RagFlowDataAPI.retrieve") + def test_retrieve_sync_ragflow_with_metadata_filters(self, mock_retrieve): + """测试 RagFlow 同步检索传递 metadata_filters -> metadata_condition""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + kb = KnowledgeBase( + knowledge_base_name="test-ragflow-kb", + provider=KnowledgeBaseProvider.RAGFLOW, + provider_settings=RagFlowProviderSettings( + base_url="https://ragflow.example.com", + dataset_ids=["ds-1"], + ), + credential_name="test-credential", + ) + + filters = {"logic": "and", "conditions": []} + kb.retrieve("test query", metadata_filters=filters) + + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("metadata_condition") == filters + + @patch("agentrun.knowledgebase.api.data.ADBDataAPI.retrieve") + def test_retrieve_sync_adb_with_metadata_filters(self, mock_retrieve): + """测试 ADB 同步检索传递 metadata_filters -> filter""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + kb = KnowledgeBase( + knowledge_base_name="test-adb-kb", + provider=KnowledgeBaseProvider.ADB, + provider_settings=ADBProviderSettings( + db_instance_id="adb-123", + namespace="public", + namespace_password="test-pwd", + ), + credential_name="test-credential", + ) + + filters = "vehicle_type = 'sedan'" + kb.retrieve("test query", metadata_filters=filters) + + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("filter") == filters + + @patch("agentrun.knowledgebase.api.data.OTSDataAPI.retrieve") + def test_retrieve_sync_ots_with_metadata_filters(self, mock_retrieve): + """测试 OTS 同步检索传递 metadata_filters -> filter""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + kb = KnowledgeBase( + knowledge_base_name="test-ots-kb", + provider=KnowledgeBaseProvider.OTS, + provider_settings=OTSProviderSettings( + ots_instance_name="ots-123", + ), + credential_name="test-credential", + ) + + filters = {"vehicle_type": "sedan"} + kb.retrieve("test query", metadata_filters=filters) + + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("filter") == filters + + @patch("agentrun.knowledgebase.api.data.BailianDataAPI.retrieve") + def test_safe_retrieve_kb_with_metadata_filters(self, mock_retrieve): + """测试 _safe_retrieve_kb 透传 metadata_filters""" + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + kb = KnowledgeBase( + knowledge_base_name="test-bailian-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id="ws-123", + index_ids=["idx-1"], + ), + credential_name="test-credential", + ) + + filters = [{"key": "vehicle_type", "value": "sedan"}] + result = KnowledgeBase._safe_retrieve_kb( + "test-bailian-kb", kb, "test query", metadata_filters=filters + ) + + assert "data" in result + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("search_filters") == filters + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.BailianDataAPI.retrieve") + def test_multi_retrieve_with_metadata_filters( + self, mock_retrieve, mock_control_api_class + ): + """测试 multi_retrieve 透传 metadata_filters""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = MockBailianKnowledgeBaseData() + mock_control_api_class.return_value = mock_control_api + + mock_retrieve.return_value = { + "data": [{"content": "test"}], + "query": "test query", + } + + filters = [{"key": "vehicle_type", "value": "sedan"}] + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["kb-1"], + metadata_filters=filters, + ) + + assert "results" in result + mock_retrieve.assert_called_once() + _, kwargs = mock_retrieve.call_args + assert kwargs.get("search_filters") == filters +