Skip to content

Commit ec3fdc7

Browse files
committed
test: add comprehensive tests and improve test structure for adapters
1 parent 714166a commit ec3fdc7

7 files changed

Lines changed: 170 additions & 204 deletions

File tree

AGENTS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ Runs on `http://localhost:3000` by default.
2828
5. Use English for all new comments.
2929
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory.
3030

31+
## Testing
32+
33+
When you modify functionality, add or update a corresponding test and run it locally (e.g. `uv run pytest tests/path/to/test_xxx.py --cov=astrbot.xxx`).
34+
Use `--cov-report term-missing` or similar to generate coverage information.
35+
36+
3137
## PR instructions
3238

3339
1. Title format: use conventional commit messages

astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
2+
import importlib
23
import itertools
34
import logging
45
import time
56
import uuid
6-
from collections.abc import Awaitable
7+
from collections.abc import Awaitable, Callable
78
from typing import Any, cast
89

910
import aiocqhttp
@@ -45,6 +46,7 @@ def __init__(
4546
platform_config: dict,
4647
platform_settings: dict,
4748
event_queue: asyncio.Queue,
49+
bot_factory: Callable[..., Any] | None = None,
4850
) -> None:
4951
super().__init__(platform_config, event_queue)
5052

@@ -59,14 +61,7 @@ def __init__(
5961
support_streaming_message=False,
6062
)
6163

62-
self.bot = aiocqhttp.CQHttp(
63-
use_ws_reverse=True,
64-
import_name="aiocqhttp",
65-
api_timeout_sec=180,
66-
access_token=platform_config.get(
67-
"ws_reverse_token",
68-
), # 以防旧版本配置不存在
69-
)
64+
self.bot = self._create_bot(platform_config, bot_factory=bot_factory)
7065

7166
@self.bot.on_request()
7267
async def request(event: aiocqhttp.Event) -> None:
@@ -113,6 +108,29 @@ async def private(event: aiocqhttp.Event) -> None:
113108
def on_websocket_connection(_) -> None:
114109
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
115110

111+
@staticmethod
112+
def _create_bot(
113+
platform_config: dict,
114+
bot_factory: Callable[..., Any] | None = None,
115+
) -> aiocqhttp.CQHttp:
116+
if bot_factory is None:
117+
# Resolve aiocqhttp at runtime so tests that swap sys.modules later
118+
# still affect bot creation even if this module was imported earlier.
119+
aiocqhttp_module = importlib.import_module("aiocqhttp")
120+
bot_factory = aiocqhttp_module.CQHttp
121+
122+
return cast(
123+
aiocqhttp.CQHttp,
124+
bot_factory(
125+
use_ws_reverse=True,
126+
import_name="aiocqhttp",
127+
api_timeout_sec=180,
128+
access_token=platform_config.get(
129+
"ws_reverse_token",
130+
), # 以防旧版本配置不存在
131+
),
132+
)
133+
116134
async def send_by_session(
117135
self,
118136
session: MessageSesion,

tests/conftest.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
import sys
10+
from asyncio import Queue
1011
from pathlib import Path
1112
from typing import Any
1213
from unittest.mock import AsyncMock, MagicMock
@@ -33,6 +34,10 @@ def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
3334
"""重新排序测试:单元测试优先,集成测试在后。"""
3435
unit_tests = []
3536
integration_tests = []
37+
deselected = []
38+
profile = config.getoption("--test-profile") or os.environ.get(
39+
"ASTRBOT_TEST_PROFILE", "all"
40+
)
3641

3742
for item in items:
3843
item_path = Path(str(item.path))
@@ -41,14 +46,44 @@ def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
4146
if is_integration:
4247
if item.get_closest_marker("integration") is None:
4348
item.add_marker(pytest.mark.integration)
49+
item.add_marker(pytest.mark.tier_d)
4450
integration_tests.append(item)
4551
else:
4652
if item.get_closest_marker("unit") is None:
4753
item.add_marker(pytest.mark.unit)
54+
if any(
55+
item.get_closest_marker(marker) is not None
56+
for marker in ("platform", "provider", "slow")
57+
):
58+
item.add_marker(pytest.mark.tier_c)
4859
unit_tests.append(item)
4960

5061
# 单元测试 -> 集成测试
51-
items[:] = unit_tests + integration_tests
62+
ordered_items = unit_tests + integration_tests
63+
if profile == "blocking":
64+
selected_items = []
65+
for item in ordered_items:
66+
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
67+
deselected.append(item)
68+
else:
69+
selected_items.append(item)
70+
if deselected:
71+
config.hook.pytest_deselected(items=deselected)
72+
items[:] = selected_items
73+
return
74+
75+
items[:] = ordered_items
76+
77+
78+
def pytest_addoption(parser):
79+
"""增加测试执行档位选择。"""
80+
parser.addoption(
81+
"--test-profile",
82+
action="store",
83+
default=None,
84+
choices=["all", "blocking"],
85+
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
86+
)
5287

5388

5489
def pytest_configure(config):
@@ -59,6 +94,8 @@ def pytest_configure(config):
5994
config.addinivalue_line("markers", "platform: 平台适配器测试")
6095
config.addinivalue_line("markers", "provider: LLM Provider 测试")
6196
config.addinivalue_line("markers", "db: 数据库相关测试")
97+
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
98+
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
6299

63100

64101
# ============================================================
@@ -72,6 +109,18 @@ def temp_dir(tmp_path: Path) -> Path:
72109
return tmp_path
73110

74111

112+
@pytest.fixture
113+
def event_queue() -> Queue:
114+
"""Create a shared asyncio queue fixture for tests."""
115+
return Queue()
116+
117+
118+
@pytest.fixture
119+
def platform_settings() -> dict:
120+
"""Create a shared empty platform settings fixture for adapter tests."""
121+
return {}
122+
123+
75124
@pytest.fixture
76125
def temp_data_dir(temp_dir: Path) -> Path:
77126
"""创建模拟的 data 目录结构。"""

tests/test_api_key_open_api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from astrbot.dashboard.server import AstrBotDashboard
1313

1414

15-
@pytest_asyncio.fixture(scope="module")
15+
@pytest_asyncio.fixture(scope="module", loop_scope="module")
1616
async def core_lifecycle_td(tmp_path_factory):
1717
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_api_key.db"
1818
db = SQLiteDatabase(str(tmp_db_path))
@@ -37,7 +37,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle):
3737
return server.app
3838

3939

40-
@pytest_asyncio.fixture(scope="module")
40+
@pytest_asyncio.fixture(scope="module", loop_scope="module")
4141
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
4242
test_client = app.test_client()
4343
response = await test_client.post(
@@ -258,7 +258,7 @@ async def test_open_chat_sessions_pagination(
258258
assert create_data["status"] == "ok"
259259
raw_key = create_data["data"]["api_key"]
260260

261-
creator = "alice"
261+
creator = f"alice_{uuid.uuid4().hex[:8]}"
262262
for idx in range(3):
263263
await core_lifecycle_td.db.create_platform_session(
264264
creator=creator,
@@ -276,7 +276,8 @@ async def test_open_chat_sessions_pagination(
276276
)
277277

278278
page_1_res = await test_client.get(
279-
"/api/v1/chat/sessions?page=1&page_size=2&username=alice",
279+
"/api/v1/chat/sessions?page=1&page_size=2&username="
280+
f"{creator}",
280281
headers={"X-API-Key": raw_key},
281282
)
282283
assert page_1_res.status_code == 200
@@ -286,10 +287,11 @@ async def test_open_chat_sessions_pagination(
286287
assert page_1_data["data"]["page_size"] == 2
287288
assert page_1_data["data"]["total"] == 3
288289
assert len(page_1_data["data"]["sessions"]) == 2
289-
assert all(item["creator"] == "alice" for item in page_1_data["data"]["sessions"])
290+
assert all(item["creator"] == creator for item in page_1_data["data"]["sessions"])
290291

291292
page_2_res = await test_client.get(
292-
"/api/v1/chat/sessions?page=2&page_size=2&username=alice",
293+
"/api/v1/chat/sessions?page=2&page_size=2&username="
294+
f"{creator}",
293295
headers={"X-API-Key": raw_key},
294296
)
295297
assert page_2_res.status_code == 200

tests/test_smoke.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,32 @@
77
from pathlib import Path
88

99
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
10-
from astrbot.core.pipeline.stage import Stage, registered_stages
11-
from astrbot.core.pipeline.stage_order import STAGES_ORDER
1210
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
1311
InternalAgentSubStage,
1412
)
1513
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
1614
ThirdPartyAgentSubStage,
1715
)
16+
from astrbot.core.pipeline.stage import Stage, registered_stages
17+
from astrbot.core.pipeline.stage_order import STAGES_ORDER
18+
19+
REPO_ROOT = Path(__file__).resolve().parents[1]
20+
21+
22+
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
23+
proc = subprocess.run(
24+
[sys.executable, "-c", code],
25+
cwd=REPO_ROOT,
26+
capture_output=True,
27+
text=True,
28+
check=False,
29+
)
30+
assert proc.returncode == 0, (
31+
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
32+
)
1833

1934

2035
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
21-
repo_root = Path(__file__).resolve().parents[1]
2236
code = (
2337
"import importlib;"
2438
"mods=["
@@ -30,18 +44,7 @@ def test_smoke_critical_imports_in_fresh_interpreter() -> None:
3044
"];"
3145
"[importlib.import_module(m) for m in mods]"
3246
)
33-
proc = subprocess.run(
34-
[sys.executable, "-c", code],
35-
cwd=repo_root,
36-
capture_output=True,
37-
text=True,
38-
check=False,
39-
)
40-
assert proc.returncode == 0, (
41-
"Smoke import check failed.\n"
42-
f"stdout:\n{proc.stdout}\n"
43-
f"stderr:\n{proc.stderr}\n"
44-
)
47+
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
4548

4649

4750
def test_smoke_pipeline_stage_registration_matches_order() -> None:
@@ -55,3 +58,58 @@ def test_smoke_pipeline_stage_registration_matches_order() -> None:
5558
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
5659
assert issubclass(InternalAgentSubStage, Stage)
5760
assert issubclass(ThirdPartyAgentSubStage, Stage)
61+
62+
63+
def test_pipeline_package_exports_remain_compatible() -> None:
64+
import astrbot.core.pipeline as pipeline
65+
66+
assert pipeline.ProcessStage is not None
67+
assert pipeline.RespondStage is not None
68+
assert isinstance(pipeline.STAGES_ORDER, list)
69+
assert "ProcessStage" in pipeline.STAGES_ORDER
70+
71+
72+
def test_builtin_stage_bootstrap_is_idempotent() -> None:
73+
ensure_builtin_stages_registered()
74+
before_count = len(registered_stages)
75+
stage_names = {cls.__name__ for cls in registered_stages}
76+
77+
expected_stage_names = {
78+
"WakingCheckStage",
79+
"WhitelistCheckStage",
80+
"SessionStatusCheckStage",
81+
"RateLimitStage",
82+
"ContentSafetyCheckStage",
83+
"PreProcessStage",
84+
"ProcessStage",
85+
"ResultDecorateStage",
86+
"RespondStage",
87+
}
88+
89+
assert expected_stage_names.issubset(stage_names)
90+
91+
ensure_builtin_stages_registered()
92+
assert len(registered_stages) == before_count
93+
94+
95+
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
96+
"""Regression: importing pipeline should not require cron/apscheduler modules."""
97+
code = (
98+
"import sys;"
99+
"from unittest.mock import MagicMock;"
100+
"mock_apscheduler = MagicMock();"
101+
"mock_apscheduler.schedulers = MagicMock();"
102+
"mock_apscheduler.schedulers.asyncio = MagicMock();"
103+
"mock_apscheduler.schedulers.background = MagicMock();"
104+
"sys.modules['apscheduler'] = mock_apscheduler;"
105+
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
106+
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
107+
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
108+
"import astrbot.core.pipeline as pipeline;"
109+
"assert pipeline.ProcessStage is not None;"
110+
"assert pipeline.RespondStage is not None"
111+
)
112+
_run_code_in_fresh_interpreter(
113+
code,
114+
"Pipeline import should not depend on real apscheduler package.",
115+
)

0 commit comments

Comments
 (0)