Skip to content

Commit 5fdf9e6

Browse files
authored
Merge pull request #117 from Serverless-Devs/kb-search-filter
feat: kb retrieve api supports filter option
2 parents bc76f73 + 10991b3 commit 5fdf9e6

7 files changed

Lines changed: 477 additions & 164 deletions

File tree

agentrun/integration/builtin/knowledgebase.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,19 @@ def __init__(
7979
" grouped by knowledge base name."
8080
),
8181
)
82-
def search_document(self, query: str) -> Dict[str, Any]:
82+
def search_document(
83+
self,
84+
query: str,
85+
metadata_filters: Optional[Any] = None,
86+
) -> Dict[str, Any]:
8387
"""检索文档 / Search documents
8488
8589
根据查询文本从配置的知识库中检索相关文档。
8690
Retrieves relevant documents from configured knowledge bases based on query text.
8791
8892
Args:
8993
query: 查询文本 / Query text
94+
metadata_filters: 元数据过滤条件 / Metadata filter conditions
9095
9196
Returns:
9297
Dict[str, Any]: 检索结果,包含各知识库的检索结果 /
@@ -96,6 +101,7 @@ def search_document(self, query: str) -> Dict[str, Any]:
96101
query=query,
97102
knowledge_base_names=self.knowledge_base_names,
98103
config=self.config,
104+
metadata_filters=metadata_filters,
99105
)
100106

101107

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 110 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def retrieve_async(
5959
self,
6060
query: str,
6161
config: Optional[Config] = None,
62+
metadata_filters: Optional[Any] = None,
6263
) -> Dict[str, Any]:
6364
"""检索知识库(异步)/ Retrieve from knowledge base (async)
6465
@@ -129,11 +130,14 @@ async def _get_api_key_async(self, config: Optional[Config] = None) -> str:
129130
)
130131
return credential.credential_secret
131132

132-
def _build_request_body(self, query: str) -> Dict[str, Any]:
133+
def _build_request_body(
134+
self, query: str, metadata_condition: Optional[Dict[str, Any]] = None
135+
) -> Dict[str, Any]:
133136
"""构建请求体 / Build request body
134137
135138
Args:
136139
query: 查询文本 / Query text
140+
metadata_condition: 元数据过滤条件 / Metadata condition filter
137141
138142
Returns:
139143
Dict[str, Any]: 请求体 / Request body
@@ -163,18 +167,24 @@ def _build_request_body(self, query: str) -> Dict[str, Any]:
163167
if self.retrieve_settings.cross_languages is not None:
164168
body["cross_languages"] = self.retrieve_settings.cross_languages
165169

170+
# 添加运行时元数据过滤条件 / Add runtime metadata condition filter
171+
if metadata_condition is not None:
172+
body["metadata_condition"] = metadata_condition
173+
166174
return body
167175

168176
async def retrieve_async(
169177
self,
170178
query: str,
171179
config: Optional[Config] = None,
180+
metadata_condition: Optional[Dict[str, Any]] = None,
172181
) -> Dict[str, Any]:
173182
"""RagFlow 检索(异步)/ RagFlow retrieval (async)
174183
175184
Args:
176185
query: 查询文本 / Query text
177186
config: 配置 / Configuration
187+
metadata_condition: 运行时元数据过滤条件 / Runtime metadata condition filter
178188
179189
Returns:
180190
Dict[str, Any]: 检索结果 / Retrieval results
@@ -195,7 +205,7 @@ async def retrieve_async(
195205
"Content-Type": "application/json",
196206
"Authorization": f"Bearer {api_key}",
197207
}
198-
body = self._build_request_body(query)
208+
body = self._build_request_body(query, metadata_condition=metadata_condition)
199209

200210
# 发送请求 / Send request
201211
async with httpx.AsyncClient(
@@ -261,12 +271,14 @@ async def retrieve_async(
261271
self,
262272
query: str,
263273
config: Optional[Config] = None,
274+
search_filters: Optional[List[Dict[str, str]]] = None,
264275
) -> Dict[str, Any]:
265276
"""百炼检索(异步)/ Bailian retrieval (async)
266277
267278
Args:
268279
query: 查询文本 / Query text
269280
config: 配置 / Configuration
281+
search_filters: 运行时元数据过滤条件 / Runtime metadata search filters
270282
271283
Returns:
272284
Dict[str, Any]: 检索结果 / Retrieval results
@@ -304,6 +316,11 @@ async def retrieve_async(
304316
self.retrieve_settings.rerank_top_n
305317
)
306318

319+
# 添加运行时元数据过滤条件 / Add runtime metadata search filters
320+
if search_filters is not None:
321+
request_params["search_filters"] = search_filters
322+
request_params["is_displayed_chunk_content"] = True
323+
307324
# 获取百炼客户端 / Get Bailian client
308325
client = self._get_bailian_client(config)
309326

@@ -381,13 +398,17 @@ def __init__(
381398
self.retrieve_settings = retrieve_settings
382399

383400
def _build_query_content_request(
384-
self, query: str, config: Optional[Config] = None
401+
self,
402+
query: str,
403+
config: Optional[Config] = None,
404+
filter: Optional[str] = None,
385405
) -> gpdb_models.QueryContentRequest:
386406
"""构建 QueryContent 请求 / Build QueryContent request
387407
388408
Args:
389409
query: 查询文本 / Query text
390410
config: 配置 / Configuration
411+
filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter
391412
392413
Returns:
393414
QueryContentRequest: GPDB QueryContent 请求对象
@@ -444,8 +465,8 @@ def _build_query_content_request(
444465
request_params["hybrid_search_args"] = (
445466
self.retrieve_settings.hybrid_search_args
446467
)
447-
if self.retrieve_settings.filter is not None:
448-
request_params["filter"] = self.retrieve_settings.filter
468+
if filter is not None:
469+
request_params["filter"] = filter
449470

450471
return gpdb_models.QueryContentRequest(**request_params)
451472

@@ -513,6 +534,7 @@ async def retrieve_async(
513534
self,
514535
query: str,
515536
config: Optional[Config] = None,
537+
filter: Optional[str] = None,
516538
) -> Dict[str, Any]:
517539
"""ADB 检索(异步)/ ADB retrieval asynchronously
518540
@@ -522,6 +544,7 @@ async def retrieve_async(
522544
Args:
523545
query: 查询文本 / Query text
524546
config: 配置 / Configuration
547+
filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter
525548
526549
Returns:
527550
Dict[str, Any]: 检索结果 / Retrieval results
@@ -536,7 +559,7 @@ async def retrieve_async(
536559
client = self._get_gpdb_client(config)
537560

538561
# 构建请求 / Build request
539-
request = self._build_query_content_request(query, config)
562+
request = self._build_query_content_request(query, config, filter=filter)
540563
logger.debug(f"ADB QueryContent request: {request}")
541564

542565
# 调用 QueryContent API / Call QueryContent API
@@ -611,85 +634,96 @@ def _build_agent_storage_client(
611634
ots_instance_name=self.provider_settings.ots_instance_name,
612635
)
613636

614-
def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]:
637+
def _build_retrieval_configuration(
638+
self, filter: Optional[Dict[str, Any]] = None
639+
) -> Optional[Dict[str, Any]]:
615640
"""将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式
616641
Convert OTSRetrieveSettings to tablestore-agent-storage dict format
617642
643+
Args:
644+
filter: 运行时过滤条件 / Runtime filter
645+
618646
Returns:
619647
Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict
620648
"""
621-
if self.retrieve_settings is None:
649+
if self.retrieve_settings is None and filter is None:
622650
return None
623651

624652
config: Dict[str, Any] = {}
625653

626-
if self.retrieve_settings.search_type is not None:
627-
config["searchType"] = self.retrieve_settings.search_type
628-
629-
if self.retrieve_settings.dense_vector_search_configuration is not None:
630-
dvsc = self.retrieve_settings.dense_vector_search_configuration
631-
dv_config: Dict[str, Any] = {}
632-
if dvsc.number_of_results is not None:
633-
dv_config["numberOfResults"] = dvsc.number_of_results
634-
config["denseVectorSearchConfiguration"] = dv_config
635-
636-
if self.retrieve_settings.full_text_search_configuration is not None:
637-
ftsc = self.retrieve_settings.full_text_search_configuration
638-
ft_config: Dict[str, Any] = {}
639-
if ftsc.number_of_results is not None:
640-
ft_config["numberOfResults"] = ftsc.number_of_results
641-
config["fullTextSearchConfiguration"] = ft_config
642-
643-
if self.retrieve_settings.reranking_configuration is not None:
644-
rc = self.retrieve_settings.reranking_configuration
645-
rr_config: Dict[str, Any] = {}
646-
647-
if rc.type is not None:
648-
rr_config["type"] = rc.type
649-
if rc.number_of_results is not None:
650-
rr_config["numberOfResults"] = rc.number_of_results
651-
652-
if rc.rrf_configuration is not None:
653-
rrf: Dict[str, Any] = {}
654-
if rc.rrf_configuration.dense_vector_search_weight is not None:
655-
rrf["denseVectorSearchWeight"] = (
656-
rc.rrf_configuration.dense_vector_search_weight
657-
)
658-
if rc.rrf_configuration.full_text_search_weight is not None:
659-
rrf["fullTextSearchWeight"] = (
660-
rc.rrf_configuration.full_text_search_weight
661-
)
662-
if rc.rrf_configuration.k is not None:
663-
rrf["k"] = rc.rrf_configuration.k
664-
rr_config["rrfConfiguration"] = rrf
665-
666-
if rc.weight_configuration is not None:
667-
wc: Dict[str, Any] = {}
668-
if (
669-
rc.weight_configuration.dense_vector_search_weight
670-
is not None
671-
):
672-
wc["denseVectorSearchWeight"] = (
654+
if self.retrieve_settings is not None:
655+
if self.retrieve_settings.search_type is not None:
656+
config["searchType"] = self.retrieve_settings.search_type
657+
658+
if self.retrieve_settings.dense_vector_search_configuration is not None:
659+
dvsc = self.retrieve_settings.dense_vector_search_configuration
660+
dv_config: Dict[str, Any] = {}
661+
if dvsc.number_of_results is not None:
662+
dv_config["numberOfResults"] = dvsc.number_of_results
663+
config["denseVectorSearchConfiguration"] = dv_config
664+
665+
if self.retrieve_settings.full_text_search_configuration is not None:
666+
ftsc = self.retrieve_settings.full_text_search_configuration
667+
ft_config: Dict[str, Any] = {}
668+
if ftsc.number_of_results is not None:
669+
ft_config["numberOfResults"] = ftsc.number_of_results
670+
config["fullTextSearchConfiguration"] = ft_config
671+
672+
if self.retrieve_settings.reranking_configuration is not None:
673+
rc = self.retrieve_settings.reranking_configuration
674+
rr_config: Dict[str, Any] = {}
675+
676+
if rc.type is not None:
677+
rr_config["type"] = rc.type
678+
if rc.number_of_results is not None:
679+
rr_config["numberOfResults"] = rc.number_of_results
680+
681+
if rc.rrf_configuration is not None:
682+
rrf: Dict[str, Any] = {}
683+
if rc.rrf_configuration.dense_vector_search_weight is not None:
684+
rrf["denseVectorSearchWeight"] = (
685+
rc.rrf_configuration.dense_vector_search_weight
686+
)
687+
if rc.rrf_configuration.full_text_search_weight is not None:
688+
rrf["fullTextSearchWeight"] = (
689+
rc.rrf_configuration.full_text_search_weight
690+
)
691+
if rc.rrf_configuration.k is not None:
692+
rrf["k"] = rc.rrf_configuration.k
693+
rr_config["rrfConfiguration"] = rrf
694+
695+
if rc.weight_configuration is not None:
696+
wc: Dict[str, Any] = {}
697+
if (
673698
rc.weight_configuration.dense_vector_search_weight
674-
)
675-
if rc.weight_configuration.full_text_search_weight is not None:
676-
wc["fullTextSearchWeight"] = (
677-
rc.weight_configuration.full_text_search_weight
678-
)
679-
rr_config["weightConfiguration"] = wc
680-
681-
if rc.model_configuration is not None:
682-
mc: Dict[str, Any] = {}
683-
if rc.model_configuration.provider is not None:
684-
mc["provider"] = rc.model_configuration.provider
685-
if rc.model_configuration.model is not None:
686-
mc["model"] = rc.model_configuration.model
687-
rr_config["modelConfiguration"] = mc
699+
is not None
700+
):
701+
wc["denseVectorSearchWeight"] = (
702+
rc.weight_configuration.dense_vector_search_weight
703+
)
704+
if rc.weight_configuration.full_text_search_weight is not None:
705+
wc["fullTextSearchWeight"] = (
706+
rc.weight_configuration.full_text_search_weight
707+
)
708+
rr_config["weightConfiguration"] = wc
709+
710+
if rc.model_configuration is not None:
711+
mc: Dict[str, Any] = {}
712+
if rc.model_configuration.provider is not None:
713+
mc["provider"] = rc.model_configuration.provider
714+
if rc.model_configuration.model is not None:
715+
mc["model"] = rc.model_configuration.model
716+
rr_config["modelConfiguration"] = mc
717+
718+
config["rerankingConfiguration"] = rr_config
688719

689-
config["rerankingConfiguration"] = rr_config
720+
if self.retrieve_settings.filter is not None:
721+
config["filter"] = self.retrieve_settings.filter
690722

691-
if self.retrieve_settings.filter is not None:
692-
config["filter"] = self.retrieve_settings.filter
723+
# 运行时 filter 优先级高于 retrieve_settings.filter
724+
# Runtime filter takes precedence over retrieve_settings.filter
725+
if filter is not None:
726+
config["filter"] = filter
693727

694728
return config if config else None
695729

@@ -731,6 +765,7 @@ async def retrieve_async(
731765
self,
732766
query: str,
733767
config: Optional[Config] = None,
768+
filter: Optional[Dict[str, Any]] = None,
734769
) -> Dict[str, Any]:
735770
"""OTS 检索(异步)/ OTS retrieval asynchronously
736771
@@ -740,6 +775,7 @@ async def retrieve_async(
740775
Args:
741776
query: 查询文本 / Query text
742777
config: 配置 / Configuration
778+
filter: 运行时过滤条件 / Runtime filter
743779
744780
Returns:
745781
Dict[str, Any]: 检索结果 / Retrieval results
@@ -752,7 +788,7 @@ async def retrieve_async(
752788

753789
client = self._build_agent_storage_client(config)
754790

755-
retrieval_config = self._build_retrieval_configuration()
791+
retrieval_config = self._build_retrieval_configuration(filter=filter)
756792

757793
request: Dict[str, Any] = {
758794
"knowledgeBaseName": self.knowledge_base_name,
@@ -785,6 +821,7 @@ def get_data_api(
785821
provider: KnowledgeBaseProvider,
786822
knowledge_base_name: str,
787823
config: Optional[Config] = None,
824+
788825
provider_settings: Optional[
789826
Union[
790827
RagFlowProviderSettings,

0 commit comments

Comments
 (0)