Skip to content

Commit c0bc58d

Browse files
authored
Merge pull request #117 from kaste/better-patch
2 parents 8e0a2ac + 8004a23 commit c0bc58d

10 files changed

Lines changed: 841 additions & 66 deletions

File tree

CHANGES.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ Release 2.0.0
8282
mock.do(1, 2, x=3)
8383
assert call.value == ((1, 2), {"x": 3})
8484

85+
- Added `patch_attr` and `patch_dict` for non-callable monkeypatch-style use cases
86+
(e.g. `sys.stdout`, `sys.argv`, and environment/config dictionaries) with
87+
context-manager support and restoration through `unstub`.
88+
8589

8690

8791

docs/index.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ State-of-the-art, high-five argument matchers::
7878
when(math).sqrt(not_(number)).thenRaise(
7979
TypeError('argument must be a number'))
8080

81+
Captors::
82+
83+
args, kwargs = captor(), captor()
84+
when(mamma).said(*args, **kwargs)
85+
# use it ...
86+
assert args.value == ("Knock", "You", "Out")
87+
88+
8189
No need to `verify` (`assert_called_with`) all the time::
8290

8391
# Different arguments, different answers
@@ -120,6 +128,13 @@ Full async/await support::
120128
when(module_under_test).http_get('https://example.com', ...).thenReturn('Yep!')
121129

122130

131+
Convenience::
132+
133+
with patch_attr("sys.argv", ["foo", "bar"]):
134+
with patch_attr("sys.stdout", StringIO()) as stdout: ...
135+
with patch_dict(os.environ, {"user": "bob"}): ...
136+
137+
123138
Read
124139
----
125140

docs/the-functions.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
The functions
55
=============
66

7-
Stable entrypoints are: :func:`when`, :func:`mock`, :func:`unstub`, :func:`verify`, :func:`spy`. New function introduced in v1 are: :func:`when2`, :func:`expect`, :func:`verifyExpectedInteractions`, :func:`verifyStubbedInvocationsAreUsed`, :func:`patch`
7+
Stable entrypoints are: :func:`when`, :func:`mock`, :func:`unstub`, :func:`verify`, :func:`spy`. New function introduced in v1 are: :func:`when2`, :func:`expect`, :func:`verifyExpectedInteractions`, :func:`verifyStubbedInvocationsAreUsed`, :func:`patch`, :func:`patch_attr`, :func:`patch_dict`
88

99
.. autofunction:: when
1010
.. autofunction:: when2
1111
.. autofunction:: patch
12+
.. autofunction:: patch_attr
13+
.. autofunction:: patch_dict
1214
.. autofunction:: expect
1315
.. autofunction:: mock
1416
.. autofunction:: unstub

mockito/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
when,
2626
when2,
2727
patch,
28+
patch_attr,
29+
patch_dict,
2830
expect,
2931
unstub,
3032
forget_invocations,
@@ -61,6 +63,8 @@
6163
'when',
6264
'when2',
6365
'patch',
66+
'patch_attr',
67+
'patch_dict',
6468
'expect',
6569
'ensureNoUnverifiedInteractions',
6670
'verify',

mockito/mocking.py

Lines changed: 29 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from . import invocation, signature, utils
3232
from . import verification as verificationModule
3333
from .mock_registry import mock_registry
34+
from .patching import Patch, patcher
3435

3536

3637
__all__ = ['mock']
@@ -41,8 +42,6 @@
4142
)
4243
SUPPORTS_MARKCOROUTINEFUNCTION = hasattr(inspect, "markcoroutinefunction")
4344

44-
_MISSING_ATTRIBUTE = object()
45-
4645
_CONFIG_ASYNC_PREFIX = "async "
4746
_ASYNC_BY_PROTOCOL_METHODS = {"__aenter__", "__aexit__", "__anext__"}
4847

@@ -324,7 +323,7 @@ def __init__(
324323
self.stubbed_invocations: deque[invocation.StubbedInvocation] = deque()
325324

326325
self._original_methods: dict[str, object | None] = {}
327-
self._methods_to_unstub: dict[str, object] = {}
326+
self._methods_to_unstub: dict[str, Patch] = {}
328327
self._signatures_store: dict[str, signature.Signature | None] = {}
329328
self._property_access_context: \
330329
list[tuple[str, object | None, object]] = []
@@ -447,30 +446,13 @@ def _get_original_method_before_stub(
447446
if self.spec is None:
448447
return None, False
449448

450-
try:
451-
return self.spec.__dict__[method_name], True
452-
except (AttributeError, KeyError):
453-
# If the attr is not directly in __dict__, class specs should use
454-
# static lookup so inherited descriptors are preserved as
455-
# descriptors (instead of triggering __get__ via getattr).
456-
if inspect.isclass(self.spec):
457-
try:
458-
return inspect.getattr_static(self.spec, method_name), False
459-
except AttributeError:
460-
# If static lookup misses (e.g. metaclass __getattr__),
461-
# fall back to dynamic lookup.
462-
pass
463-
464-
# For instance specs, keep dynamic getattr so existing
465-
# bound-method/spying behavior stays unchanged.
466-
return getattr(self.spec, method_name, None), False
467-
468-
def set_method(self, method_name: str, new_method: object) -> None:
469-
setattr(self.mocked_obj, method_name, new_method)
449+
return utils.get_original_attribute(self.spec, method_name, default=None)
470450

471451
def replace_method(
472-
self, method_name: str, original_method: object | None
473-
) -> None:
452+
self,
453+
method_name: str,
454+
original_method: object | None,
455+
) -> Patch:
474456
discard_first_arg = self._takes_implicit_self_or_cls(original_method)
475457

476458
def new_mocked_method(*args, **kwargs):
@@ -506,44 +488,36 @@ def new_mocked_method(*args, **kwargs):
506488
):
507489
new_mocked_method = staticmethod(new_mocked_method)
508490

509-
self.set_method(method_name, new_mocked_method)
491+
return patcher.patch_attribute(
492+
self.mocked_obj,
493+
method_name,
494+
new_mocked_method,
495+
allow_unstub_by_replacement=False,
496+
)
510497

511498
def stub(self, method_name: str) -> None:
512499
try:
513500
self._methods_to_unstub[method_name]
514501
except KeyError:
515-
(
516-
original_method,
517-
was_in_spec
518-
) = self._get_original_method_before_stub(method_name)
519-
if was_in_spec:
520-
# This indicates the original method was found directly on
521-
# the spec object and should therefore be restored by unstub
522-
self._methods_to_unstub[method_name] = original_method
523-
else:
524-
self._methods_to_unstub[method_name] = _MISSING_ATTRIBUTE
525-
502+
original_method, _ = self._get_original_method_before_stub(method_name)
526503
self._original_methods[method_name] = original_method
527-
self.replace_method(method_name, original_method)
504+
self._methods_to_unstub[method_name] = self.replace_method(
505+
method_name,
506+
original_method,
507+
)
528508

529509
def stub_property(self, method_name: str) -> None:
530510
try:
531511
self._methods_to_unstub[method_name]
532512
except KeyError:
533-
(
534-
original_method,
535-
was_in_spec
536-
) = self._get_original_method_before_stub(method_name)
537-
513+
original_method, _ = self._get_original_method_before_stub(method_name)
538514
self._original_methods[method_name] = original_method
539-
self.set_method(method_name, _mocked_property(self, method_name))
540-
541-
if was_in_spec:
542-
# This indicates the original method was found directly on
543-
# the spec object and should therefore be restored by unstub
544-
self._methods_to_unstub[method_name] = original_method
545-
else:
546-
self._methods_to_unstub[method_name] = _MISSING_ATTRIBUTE
515+
self._methods_to_unstub[method_name] = patcher.patch_attribute(
516+
self.mocked_obj,
517+
method_name,
518+
_mocked_property(self, method_name),
519+
allow_unstub_by_replacement=False,
520+
)
547521

548522

549523
def forget_stubbed_invocation(
@@ -558,26 +532,18 @@ def forget_stubbed_invocation(
558532
inv.method_name == invocation.method_name
559533
for inv in self.stubbed_invocations
560534
):
561-
original_method = self._methods_to_unstub.pop(
562-
invocation.method_name
563-
)
564-
self.restore_method(invocation.method_name, original_method)
535+
patch = self._methods_to_unstub.pop(invocation.method_name)
536+
patch.restore_and_unregister()
565537

566538
if self.stubbed_invocations:
567539
return
568540

569541
mock_registry.unstub(self.mocked_obj)
570542

571-
def restore_method(self, method_name: str, original_method: object) -> None:
572-
if original_method is _MISSING_ATTRIBUTE:
573-
delattr(self.mocked_obj, method_name)
574-
else:
575-
self.set_method(method_name, original_method)
576-
577543
def unstub(self) -> None:
578544
while self._methods_to_unstub:
579-
method_name, original_method = self._methods_to_unstub.popitem()
580-
self.restore_method(method_name, original_method)
545+
_, patch = self._methods_to_unstub.popitem()
546+
patch.restore_and_unregister()
581547
self.stubbed_invocations = deque()
582548
self.invocations = []
583549
self._methods_marked_as_coroutine = set()

mockito/mockito.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@
1919
# THE SOFTWARE.
2020

2121
from __future__ import annotations
22+
from collections.abc import Iterable, MutableMapping
2223
import operator
23-
from typing import Iterable
2424

2525
from . import invocation
2626
from . import verification
2727

2828
from .utils import deprecated, get_obj, get_obj_attr_tuple
2929
from .mocking import Chain, Mock
3030
from .mock_registry import mock_registry
31+
from .patching import restore_patch_contextmanager, patcher
3132
from .verification import VerificationError
3233

3334

@@ -304,6 +305,100 @@ def patch(fn, attr_or_replacement, replacement=None):
304305
theMock, name, strict=False)(Ellipsis).thenAnswer(replacement)
305306

306307

308+
def patch_attr(obj_or_path, attr_or_replacement, replacement=OMITTED):
309+
"""Patch/replace an attribute with a concrete value.
310+
311+
Unlike :func:`patch`, this does *not* record interactions and does not
312+
expose verification. It is intended for simple attribute replacement like
313+
``sys.stdout`` or ``sys.argv``.
314+
315+
Two ways to call this. Either::
316+
317+
patch_attr('sys.stdout', StringIO()) # two arguments
318+
# OR
319+
patch_attr(sys, 'stdout', StringIO()) # three arguments
320+
321+
``with`` context management is supported and restores the original value
322+
on ``__exit__``. ``__enter__`` returns the replacement object.
323+
324+
.. note:: You must :func:`unstub` after patching, or use `with`
325+
statement.
326+
327+
"""
328+
if replacement is OMITTED:
329+
replacement = attr_or_replacement
330+
obj, name = get_obj_attr_tuple(obj_or_path)
331+
else:
332+
obj, name = obj_or_path, attr_or_replacement
333+
334+
patch = patcher.patch_attribute(
335+
obj,
336+
name,
337+
replacement,
338+
allow_unstub_by_replacement=True,
339+
)
340+
return restore_patch_contextmanager(patch, replacement)
341+
342+
343+
def patch_dict(mapping_or_path, values=None, *, clear=False, remove=None, **kwargs):
344+
"""Patch/update a dict-like object in place.
345+
346+
This is a convenience function for test-time dictionary patching,
347+
especially for mutable global maps like ``os.environ``.
348+
349+
Usage::
350+
351+
patch_dict(os.environ, {'USER': 'foo'})
352+
patch_dict(os.environ, [('USER', 'foo')])
353+
patch_dict(os.environ, USER='foo')
354+
patch_dict(os.environ, remove={'USER', 'PATH'})
355+
patch_dict(os.environ, remove=all)
356+
patch_dict(os.environ, clear=True)
357+
patch_dict('os.environ', {'USER': 'foo'})
358+
359+
``with`` context management is supported and restores the original mapping
360+
state on ``__exit__``. ``__enter__`` returns the patched mapping.
361+
362+
``values`` can be any value accepted by ``dict(values)``.
363+
``kwargs`` are merged into ``values`` and take precedence.
364+
365+
.. note:: You must :func:`unstub` after patching, or use `with`
366+
statement.
367+
368+
"""
369+
mapping = (
370+
get_obj(mapping_or_path)
371+
if isinstance(mapping_or_path, str)
372+
else mapping_or_path
373+
)
374+
375+
if not isinstance(mapping, MutableMapping):
376+
raise TypeError("target must be a mutable mapping")
377+
378+
if remove is all:
379+
clear = True
380+
remove = None
381+
382+
normalized_remove: tuple[object, ...]
383+
if remove is None:
384+
normalized_remove = ()
385+
elif isinstance(remove, (str, bytes)):
386+
normalized_remove = (remove,)
387+
elif not isinstance(remove, Iterable):
388+
raise TypeError("remove must be iterable, all, or None")
389+
else:
390+
normalized_remove = tuple(remove)
391+
392+
updates = {} if values is None else dict(values)
393+
updates.update(kwargs)
394+
patch = patcher.patch_dictionary(
395+
mapping,
396+
updates,
397+
clear=clear,
398+
remove=normalized_remove,
399+
)
400+
return restore_patch_contextmanager(patch, mapping)
401+
307402

308403
def expect(obj, strict=True,
309404
times=None, atleast=None, atmost=None, between=None):
@@ -348,7 +443,7 @@ def __getattr__(self, method_name):
348443

349444

350445
def unstub(*objs):
351-
"""Unstubs all stubbed methods and functions
446+
"""Unstubs all stubbed methods, functions, and patched attributes.
352447
353448
If you don't pass in any argument, *all* registered mocks and
354449
patched modules, classes etc. will be unstubbed.
@@ -363,8 +458,10 @@ def unstub(*objs):
363458
if isinstance(obj, str):
364459
obj = get_obj(obj)
365460
mock_registry.unstub(obj)
461+
patcher.unstub_matching(obj)
366462
else:
367463
mock_registry.unstub_all()
464+
patcher.unstub_all()
368465

369466

370467
def forget_invocations(*objs):

0 commit comments

Comments
 (0)