Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion agentrun/integration/builtin/knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,19 @@ def __init__(
" grouped by knowledge base name."
),
)
def search_document(self, query: str) -> Dict[str, Any]:
def search_document(
self,
query: str,
metadata_filters: Optional[Any] = None,
) -> Dict[str, Any]:
"""检索文档 / Search documents

根据查询文本从配置的知识库中检索相关文档。
Retrieves relevant documents from configured knowledge bases based on query text.

Args:
query: 查询文本 / Query text
metadata_filters: 元数据过滤条件 / Metadata filter conditions

Returns:
Dict[str, Any]: 检索结果,包含各知识库的检索结果 /
Expand All @@ -96,6 +101,7 @@ def search_document(self, query: str) -> Dict[str, Any]:
query=query,
knowledge_base_names=self.knowledge_base_names,
config=self.config,
metadata_filters=metadata_filters,
)


Expand Down
183 changes: 110 additions & 73 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def retrieve_async(
self,
query: str,
config: Optional[Config] = None,
metadata_filters: Optional[Any] = None,
) -> Dict[str, Any]:
"""检索知识库(异步)/ Retrieve from knowledge base (async)

Expand Down Expand Up @@ -129,11 +130,14 @@ async def _get_api_key_async(self, config: Optional[Config] = None) -> str:
)
return credential.credential_secret

def _build_request_body(self, query: str) -> Dict[str, Any]:
def _build_request_body(
self, query: str, metadata_condition: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""构建请求体 / Build request body

Args:
query: 查询文本 / Query text
metadata_condition: 元数据过滤条件 / Metadata condition filter

Returns:
Dict[str, Any]: 请求体 / Request body
Expand Down Expand Up @@ -163,18 +167,24 @@ def _build_request_body(self, query: str) -> Dict[str, Any]:
if self.retrieve_settings.cross_languages is not None:
body["cross_languages"] = self.retrieve_settings.cross_languages

# 添加运行时元数据过滤条件 / Add runtime metadata condition filter
if metadata_condition is not None:
body["metadata_condition"] = metadata_condition

return body

async def retrieve_async(
self,
query: str,
config: Optional[Config] = None,
metadata_condition: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""RagFlow 检索(异步)/ RagFlow retrieval (async)

Args:
query: 查询文本 / Query text
config: 配置 / Configuration
metadata_condition: 运行时元数据过滤条件 / Runtime metadata condition filter

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand All @@ -195,7 +205,7 @@ async def retrieve_async(
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
body = self._build_request_body(query)
body = self._build_request_body(query, metadata_condition=metadata_condition)

# 发送请求 / Send request
async with httpx.AsyncClient(
Expand Down Expand Up @@ -261,12 +271,14 @@ async def retrieve_async(
self,
query: str,
config: Optional[Config] = None,
search_filters: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
"""百炼检索(异步)/ Bailian retrieval (async)

Args:
query: 查询文本 / Query text
config: 配置 / Configuration
search_filters: 运行时元数据过滤条件 / Runtime metadata search filters

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand Down Expand Up @@ -304,6 +316,11 @@ async def retrieve_async(
self.retrieve_settings.rerank_top_n
)

# 添加运行时元数据过滤条件 / Add runtime metadata search filters
if search_filters is not None:
request_params["search_filters"] = search_filters
request_params["is_displayed_chunk_content"] = True

# 获取百炼客户端 / Get Bailian client
client = self._get_bailian_client(config)

Expand Down Expand Up @@ -381,13 +398,17 @@ def __init__(
self.retrieve_settings = retrieve_settings

def _build_query_content_request(
self, query: str, config: Optional[Config] = None
self,
query: str,
config: Optional[Config] = None,
filter: Optional[str] = None,
) -> gpdb_models.QueryContentRequest:
"""构建 QueryContent 请求 / Build QueryContent request

Args:
query: 查询文本 / Query text
config: 配置 / Configuration
filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter

Returns:
QueryContentRequest: GPDB QueryContent 请求对象
Expand Down Expand Up @@ -444,8 +465,8 @@ def _build_query_content_request(
request_params["hybrid_search_args"] = (
self.retrieve_settings.hybrid_search_args
)
if self.retrieve_settings.filter is not None:
request_params["filter"] = self.retrieve_settings.filter
if filter is not None:
request_params["filter"] = filter

return gpdb_models.QueryContentRequest(**request_params)

Expand Down Expand Up @@ -513,6 +534,7 @@ async def retrieve_async(
self,
query: str,
config: Optional[Config] = None,
filter: Optional[str] = None,
) -> Dict[str, Any]:
"""ADB 检索(异步)/ ADB retrieval asynchronously

Expand All @@ -522,6 +544,7 @@ async def retrieve_async(
Args:
query: 查询文本 / Query text
config: 配置 / Configuration
filter: 运行时 SQL WHERE 过滤条件 / Runtime SQL WHERE filter

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand All @@ -536,7 +559,7 @@ async def retrieve_async(
client = self._get_gpdb_client(config)

# 构建请求 / Build request
request = self._build_query_content_request(query, config)
request = self._build_query_content_request(query, config, filter=filter)
logger.debug(f"ADB QueryContent request: {request}")

# 调用 QueryContent API / Call QueryContent API
Expand Down Expand Up @@ -611,85 +634,96 @@ def _build_agent_storage_client(
ots_instance_name=self.provider_settings.ots_instance_name,
)

def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]:
def _build_retrieval_configuration(
self, filter: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
"""将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式
Convert OTSRetrieveSettings to tablestore-agent-storage dict format

Args:
filter: 运行时过滤条件 / Runtime filter

Returns:
Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict
"""
if self.retrieve_settings is None:
if self.retrieve_settings is None and filter is None:
return None

config: Dict[str, Any] = {}

if self.retrieve_settings.search_type is not None:
config["searchType"] = self.retrieve_settings.search_type

if self.retrieve_settings.dense_vector_search_configuration is not None:
dvsc = self.retrieve_settings.dense_vector_search_configuration
dv_config: Dict[str, Any] = {}
if dvsc.number_of_results is not None:
dv_config["numberOfResults"] = dvsc.number_of_results
config["denseVectorSearchConfiguration"] = dv_config

if self.retrieve_settings.full_text_search_configuration is not None:
ftsc = self.retrieve_settings.full_text_search_configuration
ft_config: Dict[str, Any] = {}
if ftsc.number_of_results is not None:
ft_config["numberOfResults"] = ftsc.number_of_results
config["fullTextSearchConfiguration"] = ft_config

if self.retrieve_settings.reranking_configuration is not None:
rc = self.retrieve_settings.reranking_configuration
rr_config: Dict[str, Any] = {}

if rc.type is not None:
rr_config["type"] = rc.type
if rc.number_of_results is not None:
rr_config["numberOfResults"] = rc.number_of_results

if rc.rrf_configuration is not None:
rrf: Dict[str, Any] = {}
if rc.rrf_configuration.dense_vector_search_weight is not None:
rrf["denseVectorSearchWeight"] = (
rc.rrf_configuration.dense_vector_search_weight
)
if rc.rrf_configuration.full_text_search_weight is not None:
rrf["fullTextSearchWeight"] = (
rc.rrf_configuration.full_text_search_weight
)
if rc.rrf_configuration.k is not None:
rrf["k"] = rc.rrf_configuration.k
rr_config["rrfConfiguration"] = rrf

if rc.weight_configuration is not None:
wc: Dict[str, Any] = {}
if (
rc.weight_configuration.dense_vector_search_weight
is not None
):
wc["denseVectorSearchWeight"] = (
if self.retrieve_settings is not None:
if self.retrieve_settings.search_type is not None:
config["searchType"] = self.retrieve_settings.search_type

if self.retrieve_settings.dense_vector_search_configuration is not None:
dvsc = self.retrieve_settings.dense_vector_search_configuration
dv_config: Dict[str, Any] = {}
if dvsc.number_of_results is not None:
dv_config["numberOfResults"] = dvsc.number_of_results
config["denseVectorSearchConfiguration"] = dv_config

if self.retrieve_settings.full_text_search_configuration is not None:
ftsc = self.retrieve_settings.full_text_search_configuration
ft_config: Dict[str, Any] = {}
if ftsc.number_of_results is not None:
ft_config["numberOfResults"] = ftsc.number_of_results
config["fullTextSearchConfiguration"] = ft_config

if self.retrieve_settings.reranking_configuration is not None:
rc = self.retrieve_settings.reranking_configuration
rr_config: Dict[str, Any] = {}

if rc.type is not None:
rr_config["type"] = rc.type
if rc.number_of_results is not None:
rr_config["numberOfResults"] = rc.number_of_results

if rc.rrf_configuration is not None:
rrf: Dict[str, Any] = {}
if rc.rrf_configuration.dense_vector_search_weight is not None:
rrf["denseVectorSearchWeight"] = (
rc.rrf_configuration.dense_vector_search_weight
)
if rc.rrf_configuration.full_text_search_weight is not None:
rrf["fullTextSearchWeight"] = (
rc.rrf_configuration.full_text_search_weight
)
if rc.rrf_configuration.k is not None:
rrf["k"] = rc.rrf_configuration.k
rr_config["rrfConfiguration"] = rrf

if rc.weight_configuration is not None:
wc: Dict[str, Any] = {}
if (
rc.weight_configuration.dense_vector_search_weight
)
if rc.weight_configuration.full_text_search_weight is not None:
wc["fullTextSearchWeight"] = (
rc.weight_configuration.full_text_search_weight
)
rr_config["weightConfiguration"] = wc

if rc.model_configuration is not None:
mc: Dict[str, Any] = {}
if rc.model_configuration.provider is not None:
mc["provider"] = rc.model_configuration.provider
if rc.model_configuration.model is not None:
mc["model"] = rc.model_configuration.model
rr_config["modelConfiguration"] = mc
is not None
):
wc["denseVectorSearchWeight"] = (
rc.weight_configuration.dense_vector_search_weight
)
if rc.weight_configuration.full_text_search_weight is not None:
wc["fullTextSearchWeight"] = (
rc.weight_configuration.full_text_search_weight
)
rr_config["weightConfiguration"] = wc

if rc.model_configuration is not None:
mc: Dict[str, Any] = {}
if rc.model_configuration.provider is not None:
mc["provider"] = rc.model_configuration.provider
if rc.model_configuration.model is not None:
mc["model"] = rc.model_configuration.model
rr_config["modelConfiguration"] = mc

config["rerankingConfiguration"] = rr_config

config["rerankingConfiguration"] = rr_config
if self.retrieve_settings.filter is not None:
config["filter"] = self.retrieve_settings.filter

if self.retrieve_settings.filter is not None:
config["filter"] = self.retrieve_settings.filter
# 运行时 filter 优先级高于 retrieve_settings.filter
# Runtime filter takes precedence over retrieve_settings.filter
if filter is not None:
config["filter"] = filter

return config if config else None

Expand Down Expand Up @@ -731,6 +765,7 @@ async def retrieve_async(
self,
query: str,
config: Optional[Config] = None,
filter: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""OTS 检索(异步)/ OTS retrieval asynchronously

Expand All @@ -740,6 +775,7 @@ async def retrieve_async(
Args:
query: 查询文本 / Query text
config: 配置 / Configuration
filter: 运行时过滤条件 / Runtime filter

Returns:
Dict[str, Any]: 检索结果 / Retrieval results
Expand All @@ -752,7 +788,7 @@ async def retrieve_async(

client = self._build_agent_storage_client(config)

retrieval_config = self._build_retrieval_configuration()
retrieval_config = self._build_retrieval_configuration(filter=filter)

request: Dict[str, Any] = {
"knowledgeBaseName": self.knowledge_base_name,
Expand Down Expand Up @@ -785,6 +821,7 @@ def get_data_api(
provider: KnowledgeBaseProvider,
knowledge_base_name: str,
config: Optional[Config] = None,

provider_settings: Optional[
Union[
RagFlowProviderSettings,
Expand Down
Loading
Loading