Skip to content

Commit 4d8d6d2

Browse files
authored
Merge pull request #123 from kaste/same-ish
2 parents 28bc30b + f15ef7e commit 4d8d6d2

6 files changed

Lines changed: 541 additions & 17 deletions

File tree

mockito/invocation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from collections import deque
2828
from typing import TYPE_CHECKING, Union
2929

30-
from . import matchers, signature
30+
from . import matchers, sameish, signature
3131
from . import verification as verificationModule
3232
from .mock_registry import mock_registry
3333
from .utils import contains_strict
@@ -628,6 +628,21 @@ def transition_to_chain(self) -> ChainContinuation:
628628
continuation = self.get_continuation()
629629

630630
if isinstance(continuation, ChainContinuation):
631+
if (
632+
continuation.invocation is not self
633+
and sameish.invocations_have_distinct_captors(
634+
self,
635+
continuation.invocation,
636+
)
637+
):
638+
self.forget_self()
639+
raise InvocationError(
640+
"'%s' is already configured with a different captor "
641+
"instance for the same selector. Reuse the same "
642+
"captor() / call_captor() object across chain branches."
643+
% self.method_name
644+
)
645+
631646
self.rollback_if_not_configured_by(continuation)
632647
return continuation
633648

mockito/matchers.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@
6262
from abc import ABC, abstractmethod
6363
import functools
6464
import re
65+
from typing import TYPE_CHECKING
66+
67+
if TYPE_CHECKING:
68+
try:
69+
from typing import TypeGuard
70+
except ImportError:
71+
from typing_extensions import TypeGuard
72+
6573
builtin_any = any
6674

6775
__all__ = [
@@ -473,15 +481,15 @@ def __repr__(self):
473481
return "<CaptorKwargsSentinel: %r>" % self.captor
474482

475483

476-
def is_call_captor(value):
484+
def is_call_captor(value: object) -> 'TypeGuard[CallCaptor]':
477485
return isinstance(value, CallCaptor)
478486

479487

480-
def is_captor_args_sentinel(value):
488+
def is_captor_args_sentinel(value: object) -> 'TypeGuard[CaptorArgsSentinel]':
481489
return isinstance(value, CaptorArgsSentinel)
482490

483491

484-
def is_captor_kwargs_sentinel(value):
492+
def is_captor_kwargs_sentinel(value: object) -> 'TypeGuard[CaptorKwargsSentinel]':
485493
return isinstance(value, CaptorKwargsSentinel)
486494

487495

mockito/mocking.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from dataclasses import dataclass
2929
from typing import Any, AsyncIterator, Callable, Iterable, Iterator, cast
3030

31-
from . import invocation, signature, utils
31+
from . import invocation, sameish, signature, utils
3232
from . import verification as verificationModule
3333
from .mock_registry import mock_registry
3434
from .patching import Patch, patcher
@@ -407,12 +407,12 @@ def set_continuation(self, continuation: invocation.ConfiguredContinuation) -> N
407407
def _sameish_invocations(
408408
self, same: invocation.StubbedInvocation
409409
) -> list[invocation.StubbedInvocation]:
410-
"""Find prior stubs that are *mutually* signature-compatible.
410+
"""Find prior stubs that are signature-compatible.
411411
412412
This is used only for continuation bookkeeping (value-vs-chain mode),
413-
not for runtime call dispatch. We intentionally do a symmetric check
414-
(`a.matches(b)` and `b.matches(a)`) to approximate "same signature"
415-
despite one-way matchers like `any()`.
413+
not for runtime call dispatch. The comparison is structural and avoids
414+
executing matcher predicates, so `arg_that(...)` and other custom
415+
matchers cannot crash internal equivalence probing.
416416
417417
Why this exists: repeated selectors such as
418418
@@ -439,13 +439,7 @@ def _invocations_are_sameish(
439439
left: invocation.StubbedInvocation,
440440
right: invocation.StubbedInvocation,
441441
) -> bool:
442-
# Be conservative in internal equivalence probing: user predicates from
443-
# `arg_that` can throw when evaluated against matcher/sentinel objects.
444-
# In this phase, exceptions should mean "not equivalent", not failure.
445-
try:
446-
return left.matches(right) and right.matches(left)
447-
except Exception:
448-
return False
442+
return sameish.invocations_are_sameish(left, right)
449443

450444
def get_original_method(self, method_name: str) -> object | None:
451445
return self._original_methods.get(method_name, None)

mockito/sameish.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from . import matchers
6+
7+
if TYPE_CHECKING:
8+
from .invocation import StubbedInvocation
9+
10+
11+
def invocations_are_sameish(
12+
left: StubbedInvocation,
13+
right: StubbedInvocation,
14+
) -> bool:
15+
"""Structural signature-compatibility checks for continuation bookkeeping.
16+
17+
Intentionally avoids executing user-provided matcher predicates
18+
(e.g. `arg_that(...)) while comparing stub signatures.
19+
"""
20+
21+
return (
22+
_params_are_sameish(left.params, right.params)
23+
and _named_params_are_sameish(
24+
left.named_params,
25+
right.named_params,
26+
)
27+
)
28+
29+
30+
def invocations_have_distinct_captors(
31+
left: StubbedInvocation,
32+
right: StubbedInvocation,
33+
) -> bool:
34+
"""Return True when equivalent selectors bind different captor instances."""
35+
36+
for left_value, right_value in zip(left.params, right.params):
37+
if _values_bind_distinct_captors(left_value, right_value):
38+
return True
39+
40+
for key in set(left.named_params) & set(right.named_params):
41+
if _values_bind_distinct_captors(
42+
left.named_params[key],
43+
right.named_params[key],
44+
):
45+
return True
46+
47+
return False
48+
49+
50+
def _params_are_sameish(left: tuple, right: tuple) -> bool:
51+
if len(left) != len(right):
52+
return False
53+
54+
return all(
55+
_values_are_sameish(left_value, right_value)
56+
for left_value, right_value in zip(left, right)
57+
)
58+
59+
60+
def _named_params_are_sameish(left: dict, right: dict) -> bool:
61+
if set(left) != set(right):
62+
return False
63+
64+
return all(
65+
_values_are_sameish(left[key], right[key])
66+
for key in left
67+
)
68+
69+
70+
def _values_are_sameish(left: object, right: object) -> bool:
71+
if left is right:
72+
return True
73+
74+
if left is Ellipsis or right is Ellipsis:
75+
return left is right
76+
77+
if matchers.is_call_captor(left) and matchers.is_call_captor(right):
78+
return True
79+
80+
if matchers.is_call_captor(left) or matchers.is_call_captor(right):
81+
return False
82+
83+
if (
84+
matchers.is_captor_args_sentinel(left)
85+
and matchers.is_captor_args_sentinel(right)
86+
):
87+
return _values_are_sameish(left.captor.matcher, right.captor.matcher)
88+
89+
if (
90+
matchers.is_captor_kwargs_sentinel(left)
91+
and matchers.is_captor_kwargs_sentinel(right)
92+
):
93+
return _values_are_sameish(left.captor.matcher, right.captor.matcher)
94+
95+
if (
96+
matchers.is_captor_args_sentinel(left)
97+
or matchers.is_captor_args_sentinel(right)
98+
or matchers.is_captor_kwargs_sentinel(left)
99+
or matchers.is_captor_kwargs_sentinel(right)
100+
):
101+
return False
102+
103+
if isinstance(left, matchers.Matcher) and isinstance(right, matchers.Matcher):
104+
return _matchers_are_sameish(left, right)
105+
106+
if isinstance(left, matchers.Matcher) or isinstance(right, matchers.Matcher):
107+
return False
108+
109+
return _equals_or_identity(left, right)
110+
111+
112+
def _matchers_are_sameish( # noqa: C901
113+
left: matchers.Matcher,
114+
right: matchers.Matcher,
115+
) -> bool:
116+
if left is right:
117+
return True
118+
119+
if type(left) is not type(right):
120+
return False
121+
122+
if isinstance(left, matchers.Any) and isinstance(right, matchers.Any):
123+
return _equals_or_identity(left.wanted_type, right.wanted_type)
124+
125+
if (
126+
isinstance(left, matchers.ValueMatcher)
127+
and isinstance(right, matchers.ValueMatcher)
128+
):
129+
return _values_are_sameish(left.value, right.value)
130+
131+
if (
132+
isinstance(left, (matchers.And, matchers.Or))
133+
and isinstance(right, (matchers.And, matchers.Or))
134+
):
135+
return _params_are_sameish(
136+
tuple(left.matchers),
137+
tuple(right.matchers),
138+
)
139+
140+
if isinstance(left, matchers.Not) and isinstance(right, matchers.Not):
141+
return _values_are_sameish(left.matcher, right.matcher)
142+
143+
if isinstance(left, matchers.ArgThat) and isinstance(right, matchers.ArgThat):
144+
return left.predicate is right.predicate
145+
146+
if isinstance(left, matchers.Contains) and isinstance(right, matchers.Contains):
147+
return _values_are_sameish(left.sub, right.sub)
148+
149+
if isinstance(left, matchers.Matches) and isinstance(right, matchers.Matches):
150+
return (
151+
left.regex.pattern == right.regex.pattern
152+
and left.flags == right.flags
153+
)
154+
155+
if (
156+
isinstance(left, matchers.ArgumentCaptor)
157+
and isinstance(right, matchers.ArgumentCaptor)
158+
):
159+
return _values_are_sameish(left.matcher, right.matcher)
160+
161+
return _equals_or_identity(left, right)
162+
163+
164+
def _values_bind_distinct_captors(left: object, right: object) -> bool:
165+
left_binding = _captor_binding(left)
166+
right_binding = _captor_binding(right)
167+
168+
return (
169+
left_binding is not None
170+
and right_binding is not None
171+
and left_binding is not right_binding
172+
)
173+
174+
175+
def _captor_binding(value: object) -> object | None:
176+
if matchers.is_call_captor(value):
177+
return value
178+
179+
if isinstance(value, matchers.ArgumentCaptor):
180+
return value
181+
182+
if matchers.is_captor_args_sentinel(value):
183+
return value.captor
184+
185+
if matchers.is_captor_kwargs_sentinel(value):
186+
return value.captor
187+
188+
return None
189+
190+
191+
def _equals_or_identity(left: object, right: object) -> bool:
192+
try:
193+
return left == right
194+
except Exception:
195+
return left is right

0 commit comments

Comments
 (0)