Skip to content

Commit ed1e365

Browse files
authored
Merge pull request #110 from kaste/stub-order
2 parents 29c5ec7 + ba08077 commit ed1e365

2 files changed

Lines changed: 136 additions & 6 deletions

File tree

mockito/invocation.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import inspect
2525
import operator
2626
from collections import deque
27+
from functools import cached_property
2728
from typing import TYPE_CHECKING
2829

2930
from . import matchers, signature
@@ -114,12 +115,11 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None:
114115
self._remember_params(params_without_first_arg, named_params)
115116
self.mock.remember(self)
116117

117-
for matching_invocation in self.mock.stubbed_invocations:
118-
if matching_invocation.matches(self):
119-
matching_invocation.should_answer(self)
120-
matching_invocation.capture_arguments(self)
121-
return matching_invocation.answer_first(
122-
*params, **named_params)
118+
matching_invocation = self._find_best_matching_stubbed_invocation()
119+
if matching_invocation is not None:
120+
matching_invocation.should_answer(self)
121+
matching_invocation.capture_arguments(self)
122+
return matching_invocation.answer_first(*params, **named_params)
123123

124124
if self.strict:
125125
stubbed_invocations = [
@@ -148,6 +148,21 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None:
148148

149149
return None
150150

151+
def _find_best_matching_stubbed_invocation(self) -> StubbedInvocation | None:
152+
candidates = [
153+
candidate
154+
for candidate in self.mock.stubbed_invocations
155+
if candidate.matches(self)
156+
]
157+
158+
if not candidates:
159+
return None
160+
161+
if len(candidates) == 1:
162+
return candidates[0]
163+
164+
return max(candidates, key=lambda candidate: candidate.specificity_score)
165+
151166

152167
class RememberedPropertyAccess(RememberedInvocation):
153168
def ensure_mocked_object_has_method(self, method_name):
@@ -478,6 +493,33 @@ def __call__(self, *params: Any, **named_params: Any) -> AnswerSelector:
478493
self.mock.finish_stubbing(self)
479494
return AnswerSelector(self, self.refers_coroutine, self.discard_first_arg)
480495

496+
@cached_property
497+
def specificity_score(self) -> tuple[int, int]:
498+
quality = 0
499+
500+
for value in self.params:
501+
if value is not matchers.ARGS_SENTINEL:
502+
quality += self._specificity_score(value)
503+
504+
for key, value in self.named_params.items():
505+
if key is not matchers.KWARGS_SENTINEL:
506+
quality += self._specificity_score(value)
507+
508+
coverage = len(self.params) + len(self.named_params)
509+
return coverage, quality
510+
511+
def _specificity_score(self, value: object) -> int:
512+
if value is Ellipsis:
513+
return 0
514+
515+
if isinstance(value, matchers.Any) and value.wanted_type is None:
516+
return 0
517+
518+
if isinstance(value, matchers.Matcher):
519+
return 1
520+
521+
return 3
522+
481523
def forget_self(self) -> None:
482524
if self in self.mock.stubbed_invocations:
483525
self.mock.forget_stubbed_invocation(self)

tests/stub_specificity_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pytest
2+
3+
from mockito import any, args, kwargs, mock, when
4+
5+
6+
pytestmark = pytest.mark.usefixtures("unstub")
7+
8+
9+
class _Path:
10+
def exists(self, location):
11+
return f"orig:{location}"
12+
13+
14+
def test_literal_stub_beats_ellipsis_even_if_ellipsis_added_last():
15+
path = mock(_Path)
16+
17+
when(path).exists(".flake8").thenReturn("stubbed")
18+
when(path).exists(...).thenCallOriginalImplementation()
19+
20+
assert path.exists(".flake8") == "stubbed"
21+
assert path.exists("README.rst") == "orig:README.rst"
22+
23+
24+
def test_literal_stub_beats_ellipsis_even_if_literal_added_last():
25+
path = mock(_Path)
26+
27+
when(path).exists(...).thenCallOriginalImplementation()
28+
when(path).exists(".flake8").thenReturn("stubbed")
29+
30+
assert path.exists(".flake8") == "stubbed"
31+
assert path.exists("README.rst") == "orig:README.rst"
32+
33+
34+
def test_typed_any_is_more_specific_than_any_and_ellipsis():
35+
path = mock()
36+
37+
when(path).exists(...).thenReturn("ellipsis")
38+
when(path).exists(any()).thenReturn("any")
39+
when(path).exists(any(str)).thenReturn("typed-any")
40+
41+
assert path.exists(".flake8") == "typed-any"
42+
assert path.exists(1) == "any"
43+
44+
45+
def test_any_and_ellipsis_have_same_specificity_and_keep_last_wins_tie_break():
46+
path = mock()
47+
48+
when(path).exists(any()).thenReturn("any")
49+
when(path).exists(...).thenReturn("ellipsis")
50+
assert path.exists(1) == "ellipsis"
51+
52+
other = mock()
53+
when(other).exists(...).thenReturn("ellipsis")
54+
when(other).exists(any()).thenReturn("any")
55+
assert other.exists(1) == "any"
56+
57+
58+
def test_coverage_beats_quality_when_both_match():
59+
subject = mock()
60+
61+
when(subject).f("x", ...).thenReturn("prefix")
62+
when(subject).f(..., retry=..., headers=...).thenReturn("kwargs-shape")
63+
64+
assert subject.f("x", retry=5, headers={}) == "kwargs-shape"
65+
66+
67+
def test_literal_beats_matchers_when_coverage_is_equal():
68+
subject = mock()
69+
70+
when(subject).f("x", ...).thenReturn("prefix-fallback")
71+
when(subject).f(any(str), any(int)).thenReturn("typed-exact")
72+
73+
assert subject.f("x", 1) == "prefix-fallback"
74+
75+
76+
def test_args_and_kwargs_sentinels_have_same_weight_as_ellipsis():
77+
subject = mock()
78+
79+
when(subject).f(...).thenReturn("ellipsis")
80+
when(subject).f(*args).thenReturn("args")
81+
82+
assert subject.f(1) == "args"
83+
84+
other = mock()
85+
when(other).g(...).thenReturn("ellipsis")
86+
when(other).g(**kwargs).thenReturn("kwargs")
87+
88+
assert other.g(retry=1) == "kwargs"

0 commit comments

Comments
 (0)