|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import contextlib |
| 4 | +import unittest.mock |
| 5 | +from collections.abc import Iterator |
| 6 | +from dataclasses import dataclass, field |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +from .errors import DoNotMockError |
| 10 | +from .protected import ProtectedFunc |
| 11 | + |
| 12 | + |
| 13 | +_original_mock_init = unittest.mock.NonCallableMock.__init__ |
| 14 | +_original_patch_enter = unittest.mock._patch.__enter__ # pyright: ignore[reportPrivateUsage,reportUnknownVariableType,reportUnknownMemberType] |
| 15 | +_original_patch_dict_enter = unittest.mock._patch_dict.__enter__ # pyright: ignore[reportPrivateUsage] |
| 16 | + |
| 17 | + |
| 18 | +@dataclass |
| 19 | +class _ActiveGuard: |
| 20 | + test_name: str |
| 21 | + block_all: bool = False |
| 22 | + protected: list[ProtectedFunc] = field(default_factory=lambda: []) |
| 23 | + |
| 24 | + |
| 25 | +_active_guard: _ActiveGuard | None = None |
| 26 | + |
| 27 | + |
| 28 | +def _guarded_mock_init(self: Any, *args: Any, **kwargs: Any) -> None: |
| 29 | + guard = _active_guard |
| 30 | + if guard is not None and guard.block_all: |
| 31 | + raise DoNotMockError( |
| 32 | + f"\nMocking is not allowed in test '{guard.test_name}' (decorated with @pytest.do_not_mock).\n" |
| 33 | + ) |
| 34 | + _original_mock_init(self, *args, **kwargs) |
| 35 | + |
| 36 | + |
| 37 | +def _guarded_patch_enter(self: Any) -> Any: |
| 38 | + guard = _active_guard |
| 39 | + if guard is None: |
| 40 | + return _original_patch_enter(self) # pyright: ignore[reportUnknownVariableType] |
| 41 | + |
| 42 | + if guard.block_all: |
| 43 | + raise DoNotMockError( |
| 44 | + f"\nPatching is not allowed in test '{guard.test_name}' (decorated with @pytest.do_not_mock).\n" |
| 45 | + ) |
| 46 | + |
| 47 | + # Targeted mode — only block patches that hit a protected function |
| 48 | + try: |
| 49 | + target = self.getter() |
| 50 | + except Exception: |
| 51 | + return _original_patch_enter(self) # pyright: ignore[reportUnknownVariableType] |
| 52 | + |
| 53 | + for pf in guard.protected: |
| 54 | + if pf.matches_patch_target(target, self.attribute): |
| 55 | + raise DoNotMockError( |
| 56 | + f"\nTest '{guard.test_name}' marked '{pf.name}' with @pytest.do_not_mock\n" |
| 57 | + f"but it is being patched.\n" |
| 58 | + f"\nPlease remove the patch for '{pf.module_path}'.\n" |
| 59 | + ) |
| 60 | + return _original_patch_enter(self) # pyright: ignore[reportUnknownVariableType] |
| 61 | + |
| 62 | + |
| 63 | +def _guarded_patch_dict_enter(self: Any) -> Any: |
| 64 | + guard = _active_guard |
| 65 | + if guard is not None and guard.block_all: |
| 66 | + raise DoNotMockError( |
| 67 | + f"\nPatching is not allowed in test '{guard.test_name}' (decorated with @pytest.do_not_mock).\n" |
| 68 | + ) |
| 69 | + return _original_patch_dict_enter(self) |
| 70 | + |
| 71 | + |
| 72 | +@contextlib.contextmanager |
| 73 | +def mock_guard( |
| 74 | + test_name: str, |
| 75 | + *, |
| 76 | + block_all: bool = False, |
| 77 | + protected: list[ProtectedFunc] | None = None, |
| 78 | +) -> Iterator[None]: |
| 79 | + """Install mock guards for the duration of a ``with`` block, then restore originals.""" |
| 80 | + global _active_guard # noqa: PLW0603 |
| 81 | + |
| 82 | + _active_guard = _ActiveGuard(test_name=test_name, block_all=block_all, protected=protected or []) |
| 83 | + |
| 84 | + unittest.mock.NonCallableMock.__init__ = _guarded_mock_init # type: ignore[method-assign] |
| 85 | + unittest.mock._patch.__enter__ = _guarded_patch_enter # type: ignore[method-assign] |
| 86 | + unittest.mock._patch_dict.__enter__ = _guarded_patch_dict_enter # type: ignore[method-assign] |
| 87 | + try: |
| 88 | + yield |
| 89 | + finally: |
| 90 | + unittest.mock.NonCallableMock.__init__ = _original_mock_init # type: ignore[method-assign] |
| 91 | + unittest.mock._patch.__enter__ = _original_patch_enter # type: ignore[method-assign] |
| 92 | + unittest.mock._patch_dict.__enter__ = _original_patch_dict_enter # type: ignore[method-assign] |
| 93 | + _active_guard = None |
0 commit comments