diff --git a/pyproject.toml b/pyproject.toml index aafae9f..879df13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,12 +130,15 @@ disable = ["C0103", "R0903", "W0613"] [tool.pytest.ini_options] minversion = "7.0" -addopts = "-ra -q --strict-markers --strict-config" +addopts = "-ra --strict-markers --strict-config --asyncio-mode=auto --cov=src --cov-report=term-missing --cov-branch" testpaths = ["tests"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] -asyncio_mode = "auto" +markers = [ + "unit: fast-running isolated tests", + "integration: end-to-end tests hitting the FastAPI stack", +] [tool.coverage.run] source = ["src"] diff --git a/tests/README.md b/tests/README.md index aa9208c..4cca68a 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,137 +1,71 @@ # 测试套件说明 -## 🧪 最小可行测试方案 - -本项目实现了最小可行测试方案,专注于核心功能的质量保证。 - -### ✅ 已实现的测试 - -#### 1. JWT认证功能测试 (100% 覆盖率) -- **文件**: `test_simple_jwt.py`, `test_core_functionality.py` -- **覆盖内容**: - - 令牌创建和验证 - - 访问令牌和刷新令牌机制 - - 令牌类型安全验证 - - 过期令牌检测 - - 无效令牌处理 - -#### 2. 密码安全测试 (89% 覆盖率) -- **文件**: `test_core_functionality.py` -- **覆盖内容**: - - 密码哈希加密 - - 密码验证 - - 盐值随机性验证 - - 不同密码产生不同哈希 - -#### 3. 配置安全测试 (80% 覆盖率) -- **文件**: `test_core_functionality.py` -- **覆盖内容**: - - SECRET_KEY强度验证 - - JWT配置检查 - - 令牌过期时间配置验证 - -#### 4. 数据验证测试 (100% 覆盖率) -- **文件**: `test_core_functionality.py` -- **覆盖内容**: - - Pydantic Schema验证 - - 凭据数据验证 - - JWT载荷验证 - -### 🚀 运行测试 - -#### 运行核心功能测试 -```bash -# 运行核心功能测试 -uv run pytest tests/test_core_functionality.py -v +为了让项目更贴近测试驱动开发(TDD)的实践,我们对测试目录进行了分层,并提供统一的覆盖率统计方式。测试按职责拆分为**单元测试**与**集成测试**两大类: + +- `tests/unit/`:纯函数、工具方法等的快速测试,执行迅速且无外部依赖。 +- `tests/integration/`:通过 FastAPI 应用和临时数据库运行的端到端测试,覆盖真实业务流程。 -# 运行JWT专项测试 -uv run pytest tests/test_simple_jwt.py -v +## 运行方式 -# 运行所有测试并生成覆盖率报告 -uv run pytest tests/test_core_functionality.py tests/test_simple_jwt.py --cov=src --cov-report=term-missing --cov-report=html +### 1. 仅运行单元测试 +```bash +uv run pytest tests/unit ``` -#### CI/CD 自动测试 -项目已配置GitHub Actions自动测试,每次push和PR都会自动运行: -- 代码风格检查 (ruff) -- 类型检查 (mypy) -- 单元测试 (pytest) -- 测试覆盖率报告 - -### 📊 测试覆盖率 - -当前整体覆盖率:**14%** - -**核心模块覆盖率**: -- `utils/jwt.py`: **100%** ✅ -- `schemas/login.py`: **100%** ✅ -- `utils/password.py`: **89%** ✅ -- `settings/config.py`: **80%** ✅ - -### 🔧 测试配置 - -#### pytest配置 (pyproject.toml) -```toml -[tool.pytest.ini_options] -minversion = "7.0" -addopts = "-ra -q --strict-markers --strict-config" -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" +### 2. 仅运行集成测试 +```bash +uv run pytest tests/integration ``` -#### 覆盖率配置 -```toml -[tool.coverage.run] -source = ["src"] -omit = [ - "*/migrations/*", - "*/tests/*", - "*/__init__.py", -] +> 集成测试会自动启动基于 SQLite 的临时数据库,并复用现有的异步客户端、Token 等高阶夹具。 + +### 3. 运行所有测试并生成覆盖率 +```bash +uv run pytest ``` -### 🎯 测试重点 +`pyproject.toml` 中的 `addopts` 默认启用了 `--cov`、`--cov-report=term-missing` 与 `--cov-branch`,执行任意测试命令都会自动输出覆盖率详情,并标注缺失的分支或语句。 + +## 目录结构概览 -#### ✅ 已覆盖的关键安全功能 -1. **身份认证**: JWT令牌的创建、验证、过期处理 -2. **密码安全**: 哈希加密、验证、盐值处理 -3. **配置安全**: 密钥强度、过期时间配置 -4. **数据验证**: 输入数据格式验证 +``` +tests/ +├── README.md # 本说明文档 +├── __init__.py +├── conftest.py # 全局夹具与环境配置 +├── integration/ +│ ├── conftest.py # 仅集成测试需要的数据库初始化 +│ └── test_*.py # API、权限、数据库等集成测试 +└── unit/ + └── test_*.py # 工具函数与纯逻辑单元测试 +``` -#### 🚧 可扩展的测试方向 -1. **API端点测试**: 需要解决依赖问题后可添加 -2. **数据库集成测试**: 需要测试数据库配置 -3. **缓存功能测试**: 需要Redis测试环境 -4. **权限控制测试**: 需要用户角色数据 +## 新增单元测试覆盖的核心能力 -### 🐛 已知问题 +- `utils.password`: 哈希、验证与随机密码生成。 +- `utils.jwt`: 令牌创建、刷新、验证及异常场景。 +- `utils.response_adapter`: 新旧响应模型之间的适配逻辑。 +- `utils.cache`: 缓存键生成与装饰器逻辑(通过 FakeCacheManager 纯内存模拟)。 -#### Python 3.13 兼容性 -- **aioredis问题**: 当前使用redis.asyncio替代 -- **类型注解**: 使用Optional[T]替代T | None语法 +这些测试均使用 `pytest.mark.unit` 标记,可单独运行并在毫秒级完成,为 TDD 循环提供快速反馈。 -#### 依赖隔离 -- 使用独立测试文件避免复杂导入链 -- Mock复杂依赖(Redis, 数据库)进行单元测试 +## 常见命令速查 -### 📝 最佳实践 +| 目标 | 命令 | +| --- | --- | +| 仅运行单元测试 | `uv run pytest -m unit` | +| 仅运行集成测试 | `uv run pytest -m integration` | +| 生成 HTML 覆盖率报告 | `uv run pytest --cov-report=html` | +| 查看最后一次测试的覆盖率明细 | `xdg-open htmlcov/index.html` | -1. **最小可行原则**: 专注核心功能,避免过度测试 -2. **安全优先**: 重点测试认证、授权、加密功能 -3. **CI/CD集成**: 自动化测试流程 -4. **覆盖率监控**: 追踪核心模块的测试覆盖率 -5. **文档同步**: 测试用例即文档,说明功能预期行为 +> 由于在 `pyproject.toml` 中启用了 `--strict-markers`,若使用新的自定义标记,请记得将其添加到配置中。 -### 🔗 相关文件 +## CI/CD 集成 -- `tests/test_core_functionality.py` - 核心功能测试 -- `tests/test_simple_jwt.py` - JWT专项测试 -- `.github/workflows/ci.yml` - CI/CD配置 -- `pyproject.toml` - 测试和覆盖率配置 +GitHub Actions 会复用上述配置自动执行: ---- +1. Ruff、mypy 等静态检查。 +2. Pytest 全量测试(单元 + 集成)。 +3. 覆盖率统计,并在终端输出缺失语句。 -**最小可行测试方案确保了核心安全功能的质量,为项目提供了可靠的质量保证基础。** 🚀 +通过分层设计与高覆盖率的单元测试,项目具备了良好的 TDD 基础,可以在编码前先编写失败的测试,再迭代实现直至全部通过。🚀 diff --git a/tests/conftest.py b/tests/conftest.py index 77e6ab9..b9de035 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,8 +38,32 @@ f"pytest-asyncio is required for async tests but could not be installed: {exc}" ) -if "pytest_asyncio" in sys.modules: # pragma: no cover - plugin auto-registration helper - pytest_plugins = ("pytest_asyncio",) +try: # pragma: no cover - ensure pytest-cov is available for coverage reporting + import pytest_cov # type: ignore # noqa: F401 +except ModuleNotFoundError: # pragma: no cover + try: + subprocess.run( + [sys.executable, "-m", "pip", "install", "pytest-cov>=4.1"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + warnings.warn("Installed pytest-cov dynamically to enable coverage reporting.") + import pytest_cov # type: ignore # noqa: F401 + except Exception as exc: # pragma: no cover + warnings.warn( + f"pytest-cov is required for coverage reporting but could not be installed: {exc}" + ) + +loaded_plugins = [] +if "pytest_asyncio" in sys.modules: + loaded_plugins.append("pytest_asyncio") +if "pytest_cov" in sys.modules: + loaded_plugins.append("pytest_cov") + +if loaded_plugins: # pragma: no cover - plugin auto-registration helper + pytest_plugins = tuple(loaded_plugins) from src import app from tortoise import Tortoise @@ -58,7 +82,7 @@ def event_loop(): loop.close() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") async def setup_database(): """设置测试数据库""" # 使用临时SQLite数据库 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..06c798a --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,11 @@ +"""Integration test specific fixtures and markers.""" + +import pytest + +pytestmark = pytest.mark.integration + + +@pytest.fixture(scope="session", autouse=True) +async def _ensure_database(setup_database): + """Ensure the test database is initialised for integration tests only.""" + yield diff --git a/tests/test_auth_api.py b/tests/integration/test_auth_api.py similarity index 100% rename from tests/test_auth_api.py rename to tests/integration/test_auth_api.py diff --git a/tests/test_core_functionality.py b/tests/integration/test_core_functionality.py similarity index 100% rename from tests/test_core_functionality.py rename to tests/integration/test_core_functionality.py diff --git a/tests/test_crud_operations.py b/tests/integration/test_crud_operations.py similarity index 100% rename from tests/test_crud_operations.py rename to tests/integration/test_crud_operations.py diff --git a/tests/test_database_cache.py b/tests/integration/test_database_cache.py similarity index 100% rename from tests/test_database_cache.py rename to tests/integration/test_database_cache.py diff --git a/tests/test_health_endpoints.py b/tests/integration/test_health_endpoints.py similarity index 100% rename from tests/test_health_endpoints.py rename to tests/integration/test_health_endpoints.py diff --git a/tests/test_jwt_auth.py b/tests/integration/test_jwt_auth.py similarity index 100% rename from tests/test_jwt_auth.py rename to tests/integration/test_jwt_auth.py diff --git a/tests/test_permissions.py b/tests/integration/test_permissions.py similarity index 100% rename from tests/test_permissions.py rename to tests/integration/test_permissions.py diff --git a/tests/test_simple_jwt.py b/tests/integration/test_simple_jwt.py similarity index 100% rename from tests/test_simple_jwt.py rename to tests/integration/test_simple_jwt.py diff --git a/tests/unit/test_cache_utils.py b/tests/unit/test_cache_utils.py new file mode 100644 index 0000000..ea0d7ae --- /dev/null +++ b/tests/unit/test_cache_utils.py @@ -0,0 +1,78 @@ +"""Unit tests for the cache helper utilities.""" + +from __future__ import annotations + +import pytest + +from utils import cache + +pytestmark = pytest.mark.unit + + +class FakeCacheManager: + """Minimal async cache backend used for testing the decorator.""" + + def __init__(self): + self.store: dict[str, object] = {} + self.ttl: int | None = None + + def cache_key(self, prefix: str, *args, **kwargs) -> str: + parts = [prefix] + parts.extend(str(arg) for arg in args) + if kwargs: + parts.extend(f"{k}:{v}" for k, v in sorted(kwargs.items())) + return ":".join(parts) + + async def get(self, key: str): + return self.store.get(key) + + async def set(self, key: str, value, ttl: int | None = None): + applied_ttl = ttl or cache.settings.CACHE_TTL + self.store[key] = value + self.ttl = applied_ttl + return True + + +def test_cache_key_orders_kwargs(): + manager = cache.CacheManager() + + key = manager.cache_key("user", 1, region="cn", level=3) + + assert key == "user:1:level:3:region:cn" + + +@pytest.mark.asyncio +async def test_cached_decorator_stores_and_reuses_results(monkeypatch): + fake = FakeCacheManager() + monkeypatch.setattr(cache, "cache_manager", fake) + + call_count = 0 + + @cache.cached("expensive") + async def expensive_call(arg): + nonlocal call_count + call_count += 1 + return {"value": arg} + + first = await expensive_call(5) + second = await expensive_call(5) + + assert first == {"value": 5} + assert second == {"value": 5} + assert call_count == 1 + assert fake.store["expensive:5"] == {"value": 5} + assert fake.ttl == cache.settings.CACHE_TTL + + +@pytest.mark.asyncio +async def test_cached_decorator_supports_custom_key_func(monkeypatch): + fake = FakeCacheManager() + monkeypatch.setattr(cache, "cache_manager", fake) + + @cache.cached("user", key_func=lambda user_id: f"user:{user_id}") + async def load_user(user_id): + return {"id": user_id} + + await load_user(99) + + assert fake.store["user:99"] == {"id": 99} diff --git a/tests/unit/test_jwt_utils.py b/tests/unit/test_jwt_utils.py new file mode 100644 index 0000000..6a8bcbf --- /dev/null +++ b/tests/unit/test_jwt_utils.py @@ -0,0 +1,90 @@ +"""Unit tests for JWT helper functions.""" + +from datetime import UTC, datetime, timedelta + +import jwt as pyjwt +import pytest + +from schemas.login import JWTPayload +from utils import jwt as jwt_utils + +pytestmark = pytest.mark.unit + +TEST_SECRET = "test-secret-key-32-characters-long!!" + + +@pytest.fixture(autouse=True) +def override_settings(monkeypatch): + """Ensure deterministic cryptographic settings during tests.""" + monkeypatch.setattr(jwt_utils.settings, "SECRET_KEY", TEST_SECRET) + monkeypatch.setattr(jwt_utils.settings, "JWT_ALGORITHM", "HS256") + monkeypatch.setattr(jwt_utils.settings, "JWT_ACCESS_TOKEN_EXPIRE_MINUTES", 15) + monkeypatch.setattr(jwt_utils.settings, "JWT_REFRESH_TOKEN_EXPIRE_DAYS", 7) + + +def _decode_token(token: str) -> dict: + """Decode a token without enforcing expiration checks.""" + return pyjwt.decode(token, TEST_SECRET, algorithms=["HS256"], options={"verify_exp": False}) + + +def test_create_access_token_encodes_expected_payload(): + payload = JWTPayload( + user_id=1, + username="alice", + is_superuser=False, + exp=datetime.now(UTC) + timedelta(minutes=10), + ) + + token = jwt_utils.create_access_token(data=payload) + decoded = _decode_token(token) + + assert decoded["user_id"] == 1 + assert decoded["username"] == "alice" + assert decoded["token_type"] == "access" + + +def test_create_refresh_token_sets_refresh_type(): + token = jwt_utils.create_refresh_token(user_id=7, username="bob", is_superuser=True) + decoded = _decode_token(token) + + assert decoded["token_type"] == "refresh" + assert decoded["user_id"] == 7 + assert decoded["username"] == "bob" + + +def test_verify_token_enforces_token_type(): + payload = { + "user_id": 2, + "username": "carol", + "is_superuser": False, + "exp": datetime.now(UTC) + timedelta(minutes=5), + "token_type": "access", + } + token = pyjwt.encode(payload, TEST_SECRET, algorithm="HS256") + + with pytest.raises(pyjwt.InvalidTokenError): + jwt_utils.verify_token(token, token_type="refresh") + + +def test_verify_token_rejects_expired_tokens(): + payload = { + "user_id": 2, + "username": "dave", + "is_superuser": False, + "exp": datetime.now(UTC) - timedelta(seconds=1), + "token_type": "access", + } + token = pyjwt.encode(payload, TEST_SECRET, algorithm="HS256") + + with pytest.raises(pyjwt.ExpiredSignatureError): + jwt_utils.verify_token(token) + + +def test_create_token_pair_produces_distinct_tokens(): + access_token, refresh_token = jwt_utils.create_token_pair( + user_id=5, username="erin", is_superuser=False + ) + + assert access_token != refresh_token + assert _decode_token(access_token)["token_type"] == "access" + assert _decode_token(refresh_token)["token_type"] == "refresh" diff --git a/tests/unit/test_password_utils.py b/tests/unit/test_password_utils.py new file mode 100644 index 0000000..eaf7b16 --- /dev/null +++ b/tests/unit/test_password_utils.py @@ -0,0 +1,25 @@ +"""Unit tests for password utility helpers.""" + +import pytest + +from utils import password + +pytestmark = pytest.mark.unit + + +def test_get_password_hash_and_verify_roundtrip(): + """Password hashes should verify for the original secret and reject others.""" + secret = "SuperSecret!" + + hashed = password.get_password_hash(secret) + + assert hashed != secret + assert password.verify_password(secret, hashed) + assert not password.verify_password("NotTheSame", hashed) + + +def test_generate_password_uses_passlib(monkeypatch): + """The password generator should delegate to passlib's helper.""" + monkeypatch.setattr(password.pwd, "genword", lambda: "generated") + + assert password.generate_password() == "generated" diff --git a/tests/unit/test_response_adapter.py b/tests/unit/test_response_adapter.py new file mode 100644 index 0000000..238052a --- /dev/null +++ b/tests/unit/test_response_adapter.py @@ -0,0 +1,65 @@ +"""Unit tests for the response adapter helper.""" + +import json + +import pytest + +from schemas.base import Fail, Success, SuccessExtra +from utils.response_adapter import adapt_response + +pytestmark = pytest.mark.unit + + +def test_adapt_response_handles_json_response(): + response = Success(code=201, msg="created", data={"id": 1}) + + adapted = adapt_response(response) + + assert adapted == {"code": 201, "msg": "created", "data": {"id": 1}} + + +def test_adapt_response_handles_pure_dataclasses(): + class DummyResponse: + code = 418 + msg = "teapot" + data = {"answer": 42} + total = 10 + page = 2 + page_size = 5 + + adapted = adapt_response(DummyResponse()) + + assert adapted == { + "code": 418, + "msg": "teapot", + "data": {"answer": 42}, + "total": 10, + "page": 2, + "page_size": 5, + } + + +def test_adapt_response_handles_success_extra_metadata(): + response = SuccessExtra( + data=[1, 2], + total=2, + page=1, + page_size=2, + extra="value", + ) + + adapted = adapt_response(response) + + body = json.loads(response.body) + assert adapted == body + assert adapted["extra"] == "value" + + +def test_adapt_response_handles_failures(): + response = Fail(code=400, msg="bad request", data={"error": "missing"}) + + assert adapt_response(response) == { + "code": 400, + "msg": "bad request", + "data": {"error": "missing"}, + } diff --git a/tests/unit/test_sensitive_word_filter.py b/tests/unit/test_sensitive_word_filter.py new file mode 100644 index 0000000..9cef8c8 --- /dev/null +++ b/tests/unit/test_sensitive_word_filter.py @@ -0,0 +1,106 @@ +import importlib +import sys +import types + +import pytest + + +class FakeAutomaton: + def __init__(self): + self._words: list[tuple[str, tuple[int, str]]] = [] + + def add_word(self, word: str, value: tuple[int, str]) -> None: + self._words.append((word, value)) + + def make_automaton(self) -> None: # pragma: no cover - nothing to build for the fake + return None + + def iter(self, text: str): + lower = text.lower() + for word, value in self._words: + start = lower.find(word) + if start != -1: + end_index = start + len(word) - 1 + yield end_index, value + + +@pytest.fixture() +def filter_module(monkeypatch): + fake_settings = types.SimpleNamespace( + ENABLE_SENSITIVE_WORD_FILTER=True, + SENSITIVE_WORD_RESPONSE="命中敏感词", + SENSITIVE_WORDS=["敏感词", "违禁"], + ) + + fake_settings_module = types.ModuleType("settings.config") + fake_settings_module.settings = fake_settings + fake_settings_package = types.ModuleType("settings") + fake_settings_package.config = fake_settings_module + + monkeypatch.setitem(sys.modules, "settings.config", fake_settings_module) + monkeypatch.setitem(sys.modules, "settings", fake_settings_package) + monkeypatch.setitem(sys.modules, "ahocorasick", types.SimpleNamespace(Automaton=FakeAutomaton)) + + sys.modules.pop("src.utils.sensitive_word_filter", None) + module = importlib.import_module("src.utils.sensitive_word_filter") + try: + yield module, fake_settings + finally: + sys.modules.pop("src.utils.sensitive_word_filter", None) + + +def test_contains_sensitive_word_detects_match(filter_module): + module, _ = filter_module + filter_instance = module.SensitiveWordFilter() + + has_match, word = filter_instance.contains_sensitive_word("这里包含敏感词内容") + + assert has_match is True + assert word == "敏感词" + + +def test_filter_text_returns_response_for_sensitive_content(filter_module): + module, _ = filter_module + filter_instance = module.SensitiveWordFilter() + + assert filter_instance.filter_text("安全内容") == "安全内容" + assert filter_instance.filter_text("违禁行为") == "命中敏感词" + + +def test_streaming_chunk_with_sensitive_answer_is_blocked(filter_module): + module, _ = filter_module + filter_instance = module.SensitiveWordFilter() + + chunk = "data: {\"answer\": \"包含敏感词\"}" + assert filter_instance.filter_streaming_chunk(chunk) is None + + +def test_streaming_chunk_with_invalid_json_falls_back_to_text(filter_module): + module, _ = filter_module + filter_instance = module.SensitiveWordFilter() + + chunk = "data: 原始违禁文本" + assert filter_instance.filter_streaming_chunk(chunk) is None + + +def test_reload_sensitive_words_uses_updated_list(filter_module): + module, settings_stub = filter_module + filter_instance = module.SensitiveWordFilter() + + settings_stub.SENSITIVE_WORDS = ["全新词"] + assert filter_instance.reload_sensitive_words() is True + + has_match, word = filter_instance.contains_sensitive_word("这里有全新词") + assert has_match is True + assert word == "全新词" + + +def test_disabled_filter_bypasses_checks(monkeypatch, filter_module): + module, settings_stub = filter_module + monkeypatch.setattr(settings_stub, "ENABLE_SENSITIVE_WORD_FILTER", False) + + filter_instance = module.SensitiveWordFilter() + + assert filter_instance.enabled is False + assert filter_instance.contains_sensitive_word("违禁") == (False, None) + assert filter_instance.filter_streaming_chunk("data: {\"answer\": \"违禁\"}") == "data: {\"answer\": \"违禁\"}"