Skip to content

Commit 7b731eb

Browse files
whatevertogoclaude
andauthored
test: enhance test framework with comprehensive fixtures and mocks (#5354)
* test: enhance test framework with comprehensive fixtures and mocks - Add shared mock builders for aiocqhttp, discord, telegram - Add test helpers for platform configs and mock objects - Expand conftest.py with test profile support - Update coverage test workflow configuration Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(tests): 移动并重构模拟 LLM 响应和消息组件函数 * fix(tests): 优化 pytest_runtest_setup 中的标记检查逻辑 --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 28bfb3b commit 7b731eb

File tree

12 files changed

+1259
-1
lines changed

12 files changed

+1259
-1
lines changed

.github/workflows/coverage_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
mkdir -p data/temp
3838
export TESTING=true
3939
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
40-
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
40+
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
4141
4242
- name: Upload results to Codecov
4343
uses: codecov/codecov-action@v5

tests/conftest.py

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
"""
2+
AstrBot 测试配置
3+
4+
提供共享的 pytest fixtures 和测试工具。
5+
"""
6+
7+
import json
8+
import os
9+
import sys
10+
from asyncio import Queue
11+
from pathlib import Path
12+
from typing import Any
13+
from unittest.mock import AsyncMock, MagicMock
14+
15+
import pytest
16+
import pytest_asyncio
17+
18+
# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义
19+
from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component
20+
21+
# 将项目根目录添加到 sys.path
22+
PROJECT_ROOT = Path(__file__).parent.parent
23+
if str(PROJECT_ROOT) not in sys.path:
24+
sys.path.insert(0, str(PROJECT_ROOT))
25+
26+
# 设置测试环境变量
27+
os.environ.setdefault("TESTING", "true")
28+
os.environ.setdefault("ASTRBOT_TEST_MODE", "true")
29+
30+
31+
# ============================================================
32+
# 测试收集和排序
33+
# ============================================================
34+
35+
36+
def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
37+
"""重新排序测试:单元测试优先,集成测试在后。"""
38+
unit_tests = []
39+
integration_tests = []
40+
deselected = []
41+
profile = config.getoption("--test-profile") or os.environ.get(
42+
"ASTRBOT_TEST_PROFILE", "all"
43+
)
44+
45+
for item in items:
46+
item_path = Path(str(item.path))
47+
is_integration = "integration" in item_path.parts
48+
49+
if is_integration:
50+
if item.get_closest_marker("integration") is None:
51+
item.add_marker(pytest.mark.integration)
52+
item.add_marker(pytest.mark.tier_d)
53+
integration_tests.append(item)
54+
else:
55+
if item.get_closest_marker("unit") is None:
56+
item.add_marker(pytest.mark.unit)
57+
if any(
58+
item.get_closest_marker(marker) is not None
59+
for marker in ("platform", "provider", "slow")
60+
):
61+
item.add_marker(pytest.mark.tier_c)
62+
unit_tests.append(item)
63+
64+
# 单元测试 -> 集成测试
65+
ordered_items = unit_tests + integration_tests
66+
if profile == "blocking":
67+
selected_items = []
68+
for item in ordered_items:
69+
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
70+
deselected.append(item)
71+
else:
72+
selected_items.append(item)
73+
if deselected:
74+
config.hook.pytest_deselected(items=deselected)
75+
items[:] = selected_items
76+
return
77+
78+
items[:] = ordered_items
79+
80+
81+
def pytest_addoption(parser):
82+
"""增加测试执行档位选择。"""
83+
parser.addoption(
84+
"--test-profile",
85+
action="store",
86+
default=None,
87+
choices=["all", "blocking"],
88+
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
89+
)
90+
91+
92+
def pytest_configure(config):
93+
"""注册自定义标记。"""
94+
config.addinivalue_line("markers", "unit: 单元测试")
95+
config.addinivalue_line("markers", "integration: 集成测试")
96+
config.addinivalue_line("markers", "slow: 慢速测试")
97+
config.addinivalue_line("markers", "platform: 平台适配器测试")
98+
config.addinivalue_line("markers", "provider: LLM Provider 测试")
99+
config.addinivalue_line("markers", "db: 数据库相关测试")
100+
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
101+
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
102+
103+
104+
# ============================================================
105+
# 临时目录和文件 Fixtures
106+
# ============================================================
107+
108+
109+
@pytest.fixture
110+
def temp_dir(tmp_path: Path) -> Path:
111+
"""创建临时目录用于测试。"""
112+
return tmp_path
113+
114+
115+
@pytest.fixture
116+
def event_queue() -> Queue:
117+
"""Create a shared asyncio queue fixture for tests."""
118+
return Queue()
119+
120+
121+
@pytest.fixture
122+
def platform_settings() -> dict:
123+
"""Create a shared empty platform settings fixture for adapter tests."""
124+
return {}
125+
126+
127+
@pytest.fixture
128+
def temp_data_dir(temp_dir: Path) -> Path:
129+
"""创建模拟的 data 目录结构。"""
130+
data_dir = temp_dir / "data"
131+
data_dir.mkdir()
132+
133+
# 创建必要的子目录
134+
(data_dir / "config").mkdir()
135+
(data_dir / "plugins").mkdir()
136+
(data_dir / "temp").mkdir()
137+
(data_dir / "attachments").mkdir()
138+
139+
return data_dir
140+
141+
142+
@pytest.fixture
143+
def temp_config_file(temp_data_dir: Path) -> Path:
144+
"""创建临时配置文件。"""
145+
config_path = temp_data_dir / "config" / "cmd_config.json"
146+
default_config = {
147+
"provider": [],
148+
"platform": [],
149+
"provider_settings": {},
150+
"default_personality": None,
151+
"timezone": "Asia/Shanghai",
152+
}
153+
config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8")
154+
return config_path
155+
156+
157+
@pytest.fixture
158+
def temp_db_file(temp_data_dir: Path) -> Path:
159+
"""创建临时数据库文件路径。"""
160+
return temp_data_dir / "test.db"
161+
162+
163+
# ============================================================
164+
# Mock Fixtures
165+
# ============================================================
166+
167+
168+
@pytest.fixture
169+
def mock_provider():
170+
"""创建模拟的 Provider。"""
171+
provider = MagicMock()
172+
provider.provider_config = {
173+
"id": "test-provider",
174+
"type": "openai_chat_completion",
175+
"model": "gpt-4o-mini",
176+
}
177+
provider.get_model = MagicMock(return_value="gpt-4o-mini")
178+
provider.text_chat = AsyncMock()
179+
provider.text_chat_stream = AsyncMock()
180+
provider.terminate = AsyncMock()
181+
return provider
182+
183+
184+
@pytest.fixture
185+
def mock_platform():
186+
"""创建模拟的 Platform。"""
187+
platform = MagicMock()
188+
platform.platform_name = "test_platform"
189+
platform.platform_meta = MagicMock()
190+
platform.platform_meta.support_proactive_message = False
191+
platform.send_message = AsyncMock()
192+
platform.terminate = AsyncMock()
193+
return platform
194+
195+
196+
@pytest.fixture
197+
def mock_conversation():
198+
"""创建模拟的 Conversation。"""
199+
from astrbot.core.db.po import ConversationV2
200+
201+
return ConversationV2(
202+
conversation_id="test-conv-id",
203+
platform_id="test_platform",
204+
user_id="test_user",
205+
content=[],
206+
persona_id=None,
207+
)
208+
209+
210+
@pytest.fixture
211+
def mock_event():
212+
"""创建模拟的 AstrMessageEvent。"""
213+
event = MagicMock()
214+
event.unified_msg_origin = "test_umo"
215+
event.session_id = "test_session"
216+
event.message_str = "Hello, world!"
217+
event.message_obj = MagicMock()
218+
event.message_obj.message = []
219+
event.message_obj.sender = MagicMock()
220+
event.message_obj.sender.user_id = "test_user"
221+
event.message_obj.sender.nickname = "Test User"
222+
event.message_obj.group_id = None
223+
event.message_obj.group = None
224+
event.get_platform_name = MagicMock(return_value="test_platform")
225+
event.get_platform_id = MagicMock(return_value="test_platform")
226+
event.get_group_id = MagicMock(return_value=None)
227+
event.get_extra = MagicMock(return_value=None)
228+
event.set_extra = MagicMock()
229+
event.trace = MagicMock()
230+
event.platform_meta = MagicMock()
231+
event.platform_meta.support_proactive_message = False
232+
return event
233+
234+
235+
# ============================================================
236+
# 配置 Fixtures
237+
# ============================================================
238+
239+
240+
@pytest.fixture
241+
def astrbot_config(temp_config_file: Path):
242+
"""创建 AstrBotConfig 实例。"""
243+
from astrbot.core.config.astrbot_config import AstrBotConfig
244+
245+
config = AstrBotConfig()
246+
config._config_path = str(temp_config_file) # noqa: SLF001
247+
return config
248+
249+
250+
@pytest.fixture
251+
def main_agent_build_config():
252+
"""创建 MainAgentBuildConfig 实例。"""
253+
from astrbot.core.astr_main_agent import MainAgentBuildConfig
254+
255+
return MainAgentBuildConfig(
256+
tool_call_timeout=60,
257+
tool_schema_mode="full",
258+
provider_wake_prefix="",
259+
streaming_response=True,
260+
sanitize_context_by_modalities=False,
261+
kb_agentic_mode=False,
262+
file_extract_enabled=False,
263+
context_limit_reached_strategy="truncate_by_turns",
264+
llm_safety_mode=True,
265+
computer_use_runtime="local",
266+
add_cron_tools=True,
267+
)
268+
269+
270+
# ============================================================
271+
# 数据库 Fixtures
272+
# ============================================================
273+
274+
275+
@pytest_asyncio.fixture
276+
async def temp_db(temp_db_file: Path):
277+
"""创建临时数据库实例。"""
278+
from astrbot.core.db.sqlite import SQLiteDatabase
279+
280+
db = SQLiteDatabase(str(temp_db_file))
281+
try:
282+
yield db
283+
finally:
284+
await db.engine.dispose()
285+
if temp_db_file.exists():
286+
temp_db_file.unlink()
287+
288+
289+
# ============================================================
290+
# Context Fixtures
291+
# ============================================================
292+
293+
294+
@pytest_asyncio.fixture
295+
async def mock_context(
296+
astrbot_config,
297+
temp_db,
298+
mock_provider,
299+
mock_platform,
300+
):
301+
"""创建模拟的插件上下文。"""
302+
from asyncio import Queue
303+
304+
from astrbot.core.star.context import Context
305+
306+
event_queue = Queue()
307+
308+
provider_manager = MagicMock()
309+
provider_manager.get_using_provider = MagicMock(return_value=mock_provider)
310+
provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider)
311+
312+
platform_manager = MagicMock()
313+
conversation_manager = MagicMock()
314+
message_history_manager = MagicMock()
315+
persona_manager = MagicMock()
316+
persona_manager.personas_v3 = []
317+
astrbot_config_mgr = MagicMock()
318+
knowledge_base_manager = MagicMock()
319+
cron_manager = MagicMock()
320+
subagent_orchestrator = None
321+
322+
context = Context(
323+
event_queue,
324+
astrbot_config,
325+
temp_db,
326+
provider_manager,
327+
platform_manager,
328+
conversation_manager,
329+
message_history_manager,
330+
persona_manager,
331+
astrbot_config_mgr,
332+
knowledge_base_manager,
333+
cron_manager,
334+
subagent_orchestrator,
335+
)
336+
337+
return context
338+
339+
340+
# ============================================================
341+
# Provider Request Fixtures
342+
# ============================================================
343+
344+
345+
@pytest.fixture
346+
def provider_request():
347+
"""创建 ProviderRequest 实例。"""
348+
from astrbot.core.provider.entities import ProviderRequest
349+
350+
return ProviderRequest(
351+
prompt="Hello",
352+
session_id="test_session",
353+
image_urls=[],
354+
contexts=[],
355+
system_prompt="You are a helpful assistant.",
356+
)
357+
358+
359+
# ============================================================
360+
# 跳过条件
361+
# ============================================================
362+
363+
364+
def pytest_runtest_setup(item):
365+
"""在测试运行前检查跳过条件。"""
366+
# 跳过需要 API Key 但未设置的 Provider 测试
367+
if item.get_closest_marker("provider"):
368+
if not os.environ.get("TEST_PROVIDER_API_KEY"):
369+
pytest.skip("TEST_PROVIDER_API_KEY not set")
370+
371+
# 跳过需要特定平台的测试
372+
if item.get_closest_marker("platform"):
373+
required_platform = None
374+
marker = item.get_closest_marker("platform")
375+
if marker and marker.args:
376+
required_platform = marker.args[0]
377+
378+
if required_platform and not os.environ.get(
379+
f"TEST_{required_platform.upper()}_ENABLED"
380+
):
381+
pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set")

0 commit comments

Comments
 (0)