Skip to content

Commit 8c3bfbc

Browse files
committed
feat(tests): add unit tests for knowledgebase module and API
This commit introduces a comprehensive suite of unit tests for the knowledgebase module, including tests for the KnowledgeBaseClient, KnowledgeBase, and various provider settings. The tests cover creation, deletion, and update functionalities, ensuring robust validation of the knowledgebase operations. Additionally, new test files for the API and model components have been added to enhance test coverage and reliability. Co-developed-by: Aone Copilot <noreply@alibaba-inc.com> Signed-off-by: Sodawyx <sodawyx@126.com>
1 parent 84a901c commit 8c3bfbc

File tree

16 files changed

+5111
-13
lines changed

16 files changed

+5111
-13
lines changed

agentrun/knowledgebase/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""KnowledgeBase 模块 / KnowledgeBase Module"""
22

33
from .api import (
4+
ADBDataAPI,
45
BailianDataAPI,
56
get_data_api,
67
KnowledgeBaseControlAPI,
@@ -10,6 +11,8 @@
1011
from .client import KnowledgeBaseClient
1112
from .knowledgebase import KnowledgeBase
1213
from .model import (
14+
ADBProviderSettings,
15+
ADBRetrieveSettings,
1316
BailianProviderSettings,
1417
BailianRetrieveSettings,
1518
KnowledgeBaseCreateInput,
@@ -33,17 +36,20 @@
3336
"KnowledgeBaseDataAPI",
3437
"RagFlowDataAPI",
3538
"BailianDataAPI",
39+
"ADBDataAPI",
3640
"get_data_api",
3741
# enums
3842
"KnowledgeBaseProvider",
3943
# provider settings
4044
"ProviderSettings",
4145
"RagFlowProviderSettings",
4246
"BailianProviderSettings",
47+
"ADBProviderSettings",
4348
# retrieve settings
4449
"RetrieveSettings",
4550
"RagFlowRetrieveSettings",
4651
"BailianRetrieveSettings",
52+
"ADBRetrieveSettings",
4753
# api model
4854
"KnowledgeBaseCreateInput",
4955
"KnowledgeBaseUpdateInput",

agentrun/knowledgebase/__knowledgebase_async_template.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from .api.data import get_data_api
1616
from .model import (
17+
ADBProviderSettings,
18+
ADBRetrieveSettings,
1719
BailianProviderSettings,
1820
BailianRetrieveSettings,
1921
KnowledgeBaseCreateInput,
@@ -294,6 +296,54 @@ def _get_data_api(self, config: Optional[Config] = None):
294296
**self.retrieve_settings
295297
)
296298

299+
elif provider == KnowledgeBaseProvider.ADB:
300+
# ADB 设置 / ADB settings
301+
if self.provider_settings:
302+
if isinstance(self.provider_settings, ADBProviderSettings):
303+
converted_provider_settings = self.provider_settings
304+
elif isinstance(self.provider_settings, dict):
305+
# ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case
306+
# ADB provider_settings uses PascalCase keys, need to convert to snake_case
307+
converted_provider_settings = ADBProviderSettings(
308+
db_instance_id=self.provider_settings.get(
309+
"DBInstanceId", ""
310+
),
311+
namespace=self.provider_settings.get("Namespace", ""),
312+
namespace_password=self.provider_settings.get(
313+
"NamespacePassword", ""
314+
),
315+
embedding_model=self.provider_settings.get(
316+
"EmbeddingModel"
317+
),
318+
metrics=self.provider_settings.get("Metrics"),
319+
metadata=self.provider_settings.get("Metadata"),
320+
)
321+
322+
if self.retrieve_settings:
323+
if isinstance(self.retrieve_settings, ADBRetrieveSettings):
324+
converted_retrieve_settings = self.retrieve_settings
325+
elif isinstance(self.retrieve_settings, dict):
326+
# ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case
327+
# ADB retrieve_settings uses PascalCase keys, need to convert to snake_case
328+
converted_retrieve_settings = ADBRetrieveSettings(
329+
top_k=self.retrieve_settings.get("TopK"),
330+
use_full_text_retrieval=self.retrieve_settings.get(
331+
"UseFullTextRetrieval"
332+
),
333+
rerank_factor=self.retrieve_settings.get(
334+
"RerankFactor"
335+
),
336+
recall_window=self.retrieve_settings.get(
337+
"RecallWindow"
338+
),
339+
hybrid_search=self.retrieve_settings.get(
340+
"HybridSearch"
341+
),
342+
hybrid_search_args=self.retrieve_settings.get(
343+
"HybridSearchArgs"
344+
),
345+
)
346+
297347
return get_data_api(
298348
provider=provider,
299349
knowledge_base_name=self.knowledge_base_name or "",

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 219 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
提供知识库检索功能的数据链路 API。
44
Provides data API for knowledge base retrieval operations.
55
6-
根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。
7-
Dispatches to different implementations based on provider type (ragflow / bailian).
6+
根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。
7+
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
88
"""
99

1010
from abc import ABC, abstractmethod
1111
from typing import Any, Dict, List, Optional, Union
1212

1313
from alibabacloud_bailian20231229 import models as bailian_models
14+
from alibabacloud_gpdb20160503 import models as gpdb_models
1415
import httpx
1516

1617
from agentrun.utils.config import Config
@@ -19,6 +20,8 @@
1920
from agentrun.utils.log import logger
2021

2122
from ..model import (
23+
ADBProviderSettings,
24+
ADBRetrieveSettings,
2225
BailianProviderSettings,
2326
BailianRetrieveSettings,
2427
KnowledgeBaseProvider,
@@ -347,15 +350,210 @@ async def retrieve_async(
347350
}
348351

349352

353+
class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI):
354+
"""ADB (AnalyticDB for PostgreSQL) 知识库数据链路 API / ADB KnowledgeBase Data API
355+
356+
实现 ADB 知识库的检索逻辑,通过 GPDB SDK 调用 QueryContent 接口。
357+
Implements retrieval logic for ADB knowledge base via GPDB SDK QueryContent API.
358+
"""
359+
360+
def __init__(
361+
self,
362+
knowledge_base_name: str,
363+
config: Optional[Config] = None,
364+
provider_settings: Optional[ADBProviderSettings] = None,
365+
retrieve_settings: Optional[ADBRetrieveSettings] = None,
366+
):
367+
"""初始化 ADB 知识库数据链路 API / Initialize ADB KnowledgeBase Data API
368+
369+
Args:
370+
knowledge_base_name: 知识库名称 / Knowledge base name
371+
config: 配置 / Configuration
372+
provider_settings: ADB 提供商设置 / ADB provider settings
373+
retrieve_settings: ADB 检索设置 / ADB retrieve settings
374+
"""
375+
KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
376+
ControlAPI.__init__(self, config)
377+
self.provider_settings = provider_settings
378+
self.retrieve_settings = retrieve_settings
379+
380+
def _build_query_content_request(
381+
self, query: str, config: Optional[Config] = None
382+
) -> gpdb_models.QueryContentRequest:
383+
"""构建 QueryContent 请求 / Build QueryContent request
384+
385+
Args:
386+
query: 查询文本 / Query text
387+
config: 配置 / Configuration
388+
389+
Returns:
390+
QueryContentRequest: GPDB QueryContent 请求对象
391+
"""
392+
if self.provider_settings is None:
393+
raise ValueError("provider_settings is required for ADB retrieval")
394+
395+
cfg = Config.with_configs(self.config, config)
396+
397+
# 构建基础请求参数 / Build base request parameters
398+
request_params: Dict[str, Any] = {
399+
"content": query,
400+
"dbinstance_id": self.provider_settings.db_instance_id,
401+
"namespace": self.provider_settings.namespace,
402+
"namespace_password": self.provider_settings.namespace_password,
403+
"collection": self.knowledge_base_name,
404+
"region_id": cfg.get_region_id(),
405+
}
406+
407+
# 添加可选的提供商设置 / Add optional provider settings
408+
if self.provider_settings.metrics is not None:
409+
request_params["metrics"] = self.provider_settings.metrics
410+
411+
# 添加检索设置 / Add retrieve settings
412+
if self.retrieve_settings:
413+
if self.retrieve_settings.top_k is not None:
414+
request_params["top_k"] = self.retrieve_settings.top_k
415+
if self.retrieve_settings.use_full_text_retrieval is not None:
416+
request_params["use_full_text_retrieval"] = (
417+
self.retrieve_settings.use_full_text_retrieval
418+
)
419+
if self.retrieve_settings.rerank_factor is not None:
420+
request_params["rerank_factor"] = (
421+
self.retrieve_settings.rerank_factor
422+
)
423+
if self.retrieve_settings.recall_window is not None:
424+
request_params["recall_window"] = (
425+
self.retrieve_settings.recall_window
426+
)
427+
if self.retrieve_settings.hybrid_search is not None:
428+
request_params["hybrid_search"] = (
429+
self.retrieve_settings.hybrid_search
430+
)
431+
if self.retrieve_settings.hybrid_search_args is not None:
432+
request_params["hybrid_search_args"] = (
433+
self.retrieve_settings.hybrid_search_args
434+
)
435+
436+
return gpdb_models.QueryContentRequest(**request_params)
437+
438+
def _parse_query_content_response(
439+
self, response: gpdb_models.QueryContentResponse, query: str
440+
) -> Dict[str, Any]:
441+
"""解析 QueryContent 响应 / Parse QueryContent response
442+
443+
Args:
444+
response: GPDB QueryContent 响应对象
445+
query: 原始查询文本 / Original query text
446+
447+
Returns:
448+
Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results
449+
"""
450+
all_matches: List[Dict[str, Any]] = []
451+
452+
if response.body and response.body.matches:
453+
match_list = response.body.matches.match_list or []
454+
for match in match_list:
455+
all_matches.append({
456+
"content": (
457+
match.content if hasattr(match, "content") else None
458+
),
459+
"score": match.score if hasattr(match, "score") else None,
460+
"id": match.id if hasattr(match, "id") else None,
461+
"file_name": (
462+
match.file_name if hasattr(match, "file_name") else None
463+
),
464+
"metadata": (
465+
match.metadata if hasattr(match, "metadata") else None
466+
),
467+
"rerank_score": (
468+
match.rerank_score
469+
if hasattr(match, "rerank_score")
470+
else None
471+
),
472+
"retrieval_source": (
473+
match.retrieval_source
474+
if hasattr(match, "retrieval_source")
475+
else None
476+
),
477+
})
478+
479+
return {
480+
"data": all_matches,
481+
"query": query,
482+
"knowledge_base_name": self.knowledge_base_name,
483+
"request_id": (
484+
response.body.request_id
485+
if response.body and hasattr(response.body, "request_id")
486+
else None
487+
),
488+
}
489+
490+
async def retrieve_async(
491+
self,
492+
query: str,
493+
config: Optional[Config] = None,
494+
) -> Dict[str, Any]:
495+
"""ADB 检索(异步)/ ADB retrieval asynchronously
496+
497+
通过 GPDB SDK 调用 QueryContent 接口进行知识库检索。
498+
Retrieves from ADB knowledge base via GPDB SDK QueryContent API.
499+
500+
Args:
501+
query: 查询文本 / Query text
502+
config: 配置 / Configuration
503+
504+
Returns:
505+
Dict[str, Any]: 检索结果 / Retrieval results
506+
"""
507+
try:
508+
if self.provider_settings is None:
509+
raise ValueError(
510+
"provider_settings is required for ADB retrieval"
511+
)
512+
513+
# 获取 GPDB 客户端 / Get GPDB client
514+
client = self._get_gpdb_client(config)
515+
516+
# 构建请求 / Build request
517+
request = self._build_query_content_request(query, config)
518+
logger.debug(f"ADB QueryContent request: {request}")
519+
520+
# 调用 QueryContent API / Call QueryContent API
521+
response = await client.query_content_async(request)
522+
logger.debug(f"ADB QueryContent response: {response}")
523+
524+
# 解析并返回结果 / Parse and return results
525+
return self._parse_query_content_response(response, query)
526+
527+
except Exception as e:
528+
logger.warning(
529+
"Failed to retrieve from ADB knowledge base "
530+
f"'{self.knowledge_base_name}': {e}"
531+
)
532+
return {
533+
"data": f"Failed to retrieve: {e}",
534+
"query": query,
535+
"knowledge_base_name": self.knowledge_base_name,
536+
"error": True,
537+
}
538+
539+
350540
def get_data_api(
351541
provider: KnowledgeBaseProvider,
352542
knowledge_base_name: str,
353543
config: Optional[Config] = None,
354544
provider_settings: Optional[
355-
Union[RagFlowProviderSettings, BailianProviderSettings]
545+
Union[
546+
RagFlowProviderSettings,
547+
BailianProviderSettings,
548+
ADBProviderSettings,
549+
]
356550
] = None,
357551
retrieve_settings: Optional[
358-
Union[RagFlowRetrieveSettings, BailianRetrieveSettings]
552+
Union[
553+
RagFlowRetrieveSettings,
554+
BailianRetrieveSettings,
555+
ADBRetrieveSettings,
556+
]
359557
] = None,
360558
credential_name: Optional[str] = None,
361559
) -> KnowledgeBaseDataAPI:
@@ -410,5 +608,22 @@ def get_data_api(
410608
provider_settings=bailian_provider_settings,
411609
retrieve_settings=bailian_retrieve_settings,
412610
)
611+
elif provider == KnowledgeBaseProvider.ADB or provider == "adb":
612+
adb_provider_settings = (
613+
provider_settings
614+
if isinstance(provider_settings, ADBProviderSettings)
615+
else None
616+
)
617+
adb_retrieve_settings = (
618+
retrieve_settings
619+
if isinstance(retrieve_settings, ADBRetrieveSettings)
620+
else None
621+
)
622+
return ADBDataAPI(
623+
knowledge_base_name,
624+
config,
625+
provider_settings=adb_provider_settings,
626+
retrieve_settings=adb_retrieve_settings,
627+
)
413628
else:
414629
raise ValueError(f"Unsupported provider type: {provider}")

agentrun/knowledgebase/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .control import KnowledgeBaseControlAPI
44
from .data import (
5+
ADBDataAPI,
56
BailianDataAPI,
67
get_data_api,
78
KnowledgeBaseDataAPI,
@@ -15,5 +16,6 @@
1516
"KnowledgeBaseDataAPI",
1617
"RagFlowDataAPI",
1718
"BailianDataAPI",
19+
"ADBDataAPI",
1820
"get_data_api",
1921
]

0 commit comments

Comments
 (0)