|
| 1 | +""" |
| 2 | +AstrBot 测试配置 |
| 3 | +
|
| 4 | +提供共享的 pytest fixtures 和测试工具。 |
| 5 | +""" |
| 6 | + |
| 7 | +import json |
| 8 | +import os |
| 9 | +import sys |
| 10 | +from asyncio import Queue |
| 11 | +from pathlib import Path |
| 12 | +from typing import Any |
| 13 | +from unittest.mock import AsyncMock, MagicMock |
| 14 | + |
| 15 | +import pytest |
| 16 | +import pytest_asyncio |
| 17 | + |
| 18 | +# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义 |
| 19 | +from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component |
| 20 | + |
| 21 | +# 将项目根目录添加到 sys.path |
| 22 | +PROJECT_ROOT = Path(__file__).parent.parent |
| 23 | +if str(PROJECT_ROOT) not in sys.path: |
| 24 | + sys.path.insert(0, str(PROJECT_ROOT)) |
| 25 | + |
| 26 | +# 设置测试环境变量 |
| 27 | +os.environ.setdefault("TESTING", "true") |
| 28 | +os.environ.setdefault("ASTRBOT_TEST_MODE", "true") |
| 29 | + |
| 30 | + |
| 31 | +# ============================================================ |
| 32 | +# 测试收集和排序 |
| 33 | +# ============================================================ |
| 34 | + |
| 35 | + |
| 36 | +def pytest_collection_modifyitems(session, config, items): # noqa: ARG001 |
| 37 | + """重新排序测试:单元测试优先,集成测试在后。""" |
| 38 | + unit_tests = [] |
| 39 | + integration_tests = [] |
| 40 | + deselected = [] |
| 41 | + profile = config.getoption("--test-profile") or os.environ.get( |
| 42 | + "ASTRBOT_TEST_PROFILE", "all" |
| 43 | + ) |
| 44 | + |
| 45 | + for item in items: |
| 46 | + item_path = Path(str(item.path)) |
| 47 | + is_integration = "integration" in item_path.parts |
| 48 | + |
| 49 | + if is_integration: |
| 50 | + if item.get_closest_marker("integration") is None: |
| 51 | + item.add_marker(pytest.mark.integration) |
| 52 | + item.add_marker(pytest.mark.tier_d) |
| 53 | + integration_tests.append(item) |
| 54 | + else: |
| 55 | + if item.get_closest_marker("unit") is None: |
| 56 | + item.add_marker(pytest.mark.unit) |
| 57 | + if any( |
| 58 | + item.get_closest_marker(marker) is not None |
| 59 | + for marker in ("platform", "provider", "slow") |
| 60 | + ): |
| 61 | + item.add_marker(pytest.mark.tier_c) |
| 62 | + unit_tests.append(item) |
| 63 | + |
| 64 | + # 单元测试 -> 集成测试 |
| 65 | + ordered_items = unit_tests + integration_tests |
| 66 | + if profile == "blocking": |
| 67 | + selected_items = [] |
| 68 | + for item in ordered_items: |
| 69 | + if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"): |
| 70 | + deselected.append(item) |
| 71 | + else: |
| 72 | + selected_items.append(item) |
| 73 | + if deselected: |
| 74 | + config.hook.pytest_deselected(items=deselected) |
| 75 | + items[:] = selected_items |
| 76 | + return |
| 77 | + |
| 78 | + items[:] = ordered_items |
| 79 | + |
| 80 | + |
| 81 | +def pytest_addoption(parser): |
| 82 | + """增加测试执行档位选择。""" |
| 83 | + parser.addoption( |
| 84 | + "--test-profile", |
| 85 | + action="store", |
| 86 | + default=None, |
| 87 | + choices=["all", "blocking"], |
| 88 | + help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.", |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +def pytest_configure(config): |
| 93 | + """注册自定义标记。""" |
| 94 | + config.addinivalue_line("markers", "unit: 单元测试") |
| 95 | + config.addinivalue_line("markers", "integration: 集成测试") |
| 96 | + config.addinivalue_line("markers", "slow: 慢速测试") |
| 97 | + config.addinivalue_line("markers", "platform: 平台适配器测试") |
| 98 | + config.addinivalue_line("markers", "provider: LLM Provider 测试") |
| 99 | + config.addinivalue_line("markers", "db: 数据库相关测试") |
| 100 | + config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)") |
| 101 | + config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)") |
| 102 | + |
| 103 | + |
| 104 | +# ============================================================ |
| 105 | +# 临时目录和文件 Fixtures |
| 106 | +# ============================================================ |
| 107 | + |
| 108 | + |
| 109 | +@pytest.fixture |
| 110 | +def temp_dir(tmp_path: Path) -> Path: |
| 111 | + """创建临时目录用于测试。""" |
| 112 | + return tmp_path |
| 113 | + |
| 114 | + |
| 115 | +@pytest.fixture |
| 116 | +def event_queue() -> Queue: |
| 117 | + """Create a shared asyncio queue fixture for tests.""" |
| 118 | + return Queue() |
| 119 | + |
| 120 | + |
| 121 | +@pytest.fixture |
| 122 | +def platform_settings() -> dict: |
| 123 | + """Create a shared empty platform settings fixture for adapter tests.""" |
| 124 | + return {} |
| 125 | + |
| 126 | + |
| 127 | +@pytest.fixture |
| 128 | +def temp_data_dir(temp_dir: Path) -> Path: |
| 129 | + """创建模拟的 data 目录结构。""" |
| 130 | + data_dir = temp_dir / "data" |
| 131 | + data_dir.mkdir() |
| 132 | + |
| 133 | + # 创建必要的子目录 |
| 134 | + (data_dir / "config").mkdir() |
| 135 | + (data_dir / "plugins").mkdir() |
| 136 | + (data_dir / "temp").mkdir() |
| 137 | + (data_dir / "attachments").mkdir() |
| 138 | + |
| 139 | + return data_dir |
| 140 | + |
| 141 | + |
| 142 | +@pytest.fixture |
| 143 | +def temp_config_file(temp_data_dir: Path) -> Path: |
| 144 | + """创建临时配置文件。""" |
| 145 | + config_path = temp_data_dir / "config" / "cmd_config.json" |
| 146 | + default_config = { |
| 147 | + "provider": [], |
| 148 | + "platform": [], |
| 149 | + "provider_settings": {}, |
| 150 | + "default_personality": None, |
| 151 | + "timezone": "Asia/Shanghai", |
| 152 | + } |
| 153 | + config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8") |
| 154 | + return config_path |
| 155 | + |
| 156 | + |
| 157 | +@pytest.fixture |
| 158 | +def temp_db_file(temp_data_dir: Path) -> Path: |
| 159 | + """创建临时数据库文件路径。""" |
| 160 | + return temp_data_dir / "test.db" |
| 161 | + |
| 162 | + |
| 163 | +# ============================================================ |
| 164 | +# Mock Fixtures |
| 165 | +# ============================================================ |
| 166 | + |
| 167 | + |
| 168 | +@pytest.fixture |
| 169 | +def mock_provider(): |
| 170 | + """创建模拟的 Provider。""" |
| 171 | + provider = MagicMock() |
| 172 | + provider.provider_config = { |
| 173 | + "id": "test-provider", |
| 174 | + "type": "openai_chat_completion", |
| 175 | + "model": "gpt-4o-mini", |
| 176 | + } |
| 177 | + provider.get_model = MagicMock(return_value="gpt-4o-mini") |
| 178 | + provider.text_chat = AsyncMock() |
| 179 | + provider.text_chat_stream = AsyncMock() |
| 180 | + provider.terminate = AsyncMock() |
| 181 | + return provider |
| 182 | + |
| 183 | + |
| 184 | +@pytest.fixture |
| 185 | +def mock_platform(): |
| 186 | + """创建模拟的 Platform。""" |
| 187 | + platform = MagicMock() |
| 188 | + platform.platform_name = "test_platform" |
| 189 | + platform.platform_meta = MagicMock() |
| 190 | + platform.platform_meta.support_proactive_message = False |
| 191 | + platform.send_message = AsyncMock() |
| 192 | + platform.terminate = AsyncMock() |
| 193 | + return platform |
| 194 | + |
| 195 | + |
| 196 | +@pytest.fixture |
| 197 | +def mock_conversation(): |
| 198 | + """创建模拟的 Conversation。""" |
| 199 | + from astrbot.core.db.po import ConversationV2 |
| 200 | + |
| 201 | + return ConversationV2( |
| 202 | + conversation_id="test-conv-id", |
| 203 | + platform_id="test_platform", |
| 204 | + user_id="test_user", |
| 205 | + content=[], |
| 206 | + persona_id=None, |
| 207 | + ) |
| 208 | + |
| 209 | + |
| 210 | +@pytest.fixture |
| 211 | +def mock_event(): |
| 212 | + """创建模拟的 AstrMessageEvent。""" |
| 213 | + event = MagicMock() |
| 214 | + event.unified_msg_origin = "test_umo" |
| 215 | + event.session_id = "test_session" |
| 216 | + event.message_str = "Hello, world!" |
| 217 | + event.message_obj = MagicMock() |
| 218 | + event.message_obj.message = [] |
| 219 | + event.message_obj.sender = MagicMock() |
| 220 | + event.message_obj.sender.user_id = "test_user" |
| 221 | + event.message_obj.sender.nickname = "Test User" |
| 222 | + event.message_obj.group_id = None |
| 223 | + event.message_obj.group = None |
| 224 | + event.get_platform_name = MagicMock(return_value="test_platform") |
| 225 | + event.get_platform_id = MagicMock(return_value="test_platform") |
| 226 | + event.get_group_id = MagicMock(return_value=None) |
| 227 | + event.get_extra = MagicMock(return_value=None) |
| 228 | + event.set_extra = MagicMock() |
| 229 | + event.trace = MagicMock() |
| 230 | + event.platform_meta = MagicMock() |
| 231 | + event.platform_meta.support_proactive_message = False |
| 232 | + return event |
| 233 | + |
| 234 | + |
| 235 | +# ============================================================ |
| 236 | +# 配置 Fixtures |
| 237 | +# ============================================================ |
| 238 | + |
| 239 | + |
| 240 | +@pytest.fixture |
| 241 | +def astrbot_config(temp_config_file: Path): |
| 242 | + """创建 AstrBotConfig 实例。""" |
| 243 | + from astrbot.core.config.astrbot_config import AstrBotConfig |
| 244 | + |
| 245 | + config = AstrBotConfig() |
| 246 | + config._config_path = str(temp_config_file) # noqa: SLF001 |
| 247 | + return config |
| 248 | + |
| 249 | + |
| 250 | +@pytest.fixture |
| 251 | +def main_agent_build_config(): |
| 252 | + """创建 MainAgentBuildConfig 实例。""" |
| 253 | + from astrbot.core.astr_main_agent import MainAgentBuildConfig |
| 254 | + |
| 255 | + return MainAgentBuildConfig( |
| 256 | + tool_call_timeout=60, |
| 257 | + tool_schema_mode="full", |
| 258 | + provider_wake_prefix="", |
| 259 | + streaming_response=True, |
| 260 | + sanitize_context_by_modalities=False, |
| 261 | + kb_agentic_mode=False, |
| 262 | + file_extract_enabled=False, |
| 263 | + context_limit_reached_strategy="truncate_by_turns", |
| 264 | + llm_safety_mode=True, |
| 265 | + computer_use_runtime="local", |
| 266 | + add_cron_tools=True, |
| 267 | + ) |
| 268 | + |
| 269 | + |
| 270 | +# ============================================================ |
| 271 | +# 数据库 Fixtures |
| 272 | +# ============================================================ |
| 273 | + |
| 274 | + |
| 275 | +@pytest_asyncio.fixture |
| 276 | +async def temp_db(temp_db_file: Path): |
| 277 | + """创建临时数据库实例。""" |
| 278 | + from astrbot.core.db.sqlite import SQLiteDatabase |
| 279 | + |
| 280 | + db = SQLiteDatabase(str(temp_db_file)) |
| 281 | + try: |
| 282 | + yield db |
| 283 | + finally: |
| 284 | + await db.engine.dispose() |
| 285 | + if temp_db_file.exists(): |
| 286 | + temp_db_file.unlink() |
| 287 | + |
| 288 | + |
| 289 | +# ============================================================ |
| 290 | +# Context Fixtures |
| 291 | +# ============================================================ |
| 292 | + |
| 293 | + |
| 294 | +@pytest_asyncio.fixture |
| 295 | +async def mock_context( |
| 296 | + astrbot_config, |
| 297 | + temp_db, |
| 298 | + mock_provider, |
| 299 | + mock_platform, |
| 300 | +): |
| 301 | + """创建模拟的插件上下文。""" |
| 302 | + from asyncio import Queue |
| 303 | + |
| 304 | + from astrbot.core.star.context import Context |
| 305 | + |
| 306 | + event_queue = Queue() |
| 307 | + |
| 308 | + provider_manager = MagicMock() |
| 309 | + provider_manager.get_using_provider = MagicMock(return_value=mock_provider) |
| 310 | + provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider) |
| 311 | + |
| 312 | + platform_manager = MagicMock() |
| 313 | + conversation_manager = MagicMock() |
| 314 | + message_history_manager = MagicMock() |
| 315 | + persona_manager = MagicMock() |
| 316 | + persona_manager.personas_v3 = [] |
| 317 | + astrbot_config_mgr = MagicMock() |
| 318 | + knowledge_base_manager = MagicMock() |
| 319 | + cron_manager = MagicMock() |
| 320 | + subagent_orchestrator = None |
| 321 | + |
| 322 | + context = Context( |
| 323 | + event_queue, |
| 324 | + astrbot_config, |
| 325 | + temp_db, |
| 326 | + provider_manager, |
| 327 | + platform_manager, |
| 328 | + conversation_manager, |
| 329 | + message_history_manager, |
| 330 | + persona_manager, |
| 331 | + astrbot_config_mgr, |
| 332 | + knowledge_base_manager, |
| 333 | + cron_manager, |
| 334 | + subagent_orchestrator, |
| 335 | + ) |
| 336 | + |
| 337 | + return context |
| 338 | + |
| 339 | + |
| 340 | +# ============================================================ |
| 341 | +# Provider Request Fixtures |
| 342 | +# ============================================================ |
| 343 | + |
| 344 | + |
| 345 | +@pytest.fixture |
| 346 | +def provider_request(): |
| 347 | + """创建 ProviderRequest 实例。""" |
| 348 | + from astrbot.core.provider.entities import ProviderRequest |
| 349 | + |
| 350 | + return ProviderRequest( |
| 351 | + prompt="Hello", |
| 352 | + session_id="test_session", |
| 353 | + image_urls=[], |
| 354 | + contexts=[], |
| 355 | + system_prompt="You are a helpful assistant.", |
| 356 | + ) |
| 357 | + |
| 358 | + |
| 359 | +# ============================================================ |
| 360 | +# 跳过条件 |
| 361 | +# ============================================================ |
| 362 | + |
| 363 | + |
| 364 | +def pytest_runtest_setup(item): |
| 365 | + """在测试运行前检查跳过条件。""" |
| 366 | + # 跳过需要 API Key 但未设置的 Provider 测试 |
| 367 | + if item.get_closest_marker("provider"): |
| 368 | + if not os.environ.get("TEST_PROVIDER_API_KEY"): |
| 369 | + pytest.skip("TEST_PROVIDER_API_KEY not set") |
| 370 | + |
| 371 | + # 跳过需要特定平台的测试 |
| 372 | + if item.get_closest_marker("platform"): |
| 373 | + required_platform = None |
| 374 | + marker = item.get_closest_marker("platform") |
| 375 | + if marker and marker.args: |
| 376 | + required_platform = marker.args[0] |
| 377 | + |
| 378 | + if required_platform and not os.environ.get( |
| 379 | + f"TEST_{required_platform.upper()}_ENABLED" |
| 380 | + ): |
| 381 | + pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set") |
0 commit comments