Skip to content

Commit 0ec047d

Browse files
authored
Merge pull request #112 from kaste/chains-2
2 parents cd7ccc4 + db37a8f commit 0ec047d

6 files changed

Lines changed: 264 additions & 184 deletions

File tree

mockito/invocation.py

Lines changed: 35 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -735,76 +735,42 @@ def __init__(
735735
invocation: StubbedInvocation,
736736
expects_awaitable: bool,
737737
discard_first_arg: bool
738-
) -> None:
739-
self.__impl = AnswerSelectorImpl(
740-
invocation,
741-
expects_awaitable=expects_awaitable,
742-
discard_first_arg=discard_first_arg,
743-
)
744-
745-
def thenReturn(self, *return_values: Any) -> Self:
746-
self.__impl.thenReturn(*return_values)
747-
return self
748-
749-
def thenRaise(self, *exceptions: Exception | type[Exception]) -> Self:
750-
self.__impl.thenRaise(*exceptions)
751-
return self
752-
753-
def thenAnswer(self, *callables: Callable) -> Self:
754-
self.__impl.thenAnswer(*callables)
755-
return self
756-
757-
def thenCallOriginalImplementation(self) -> Self:
758-
self.__impl.thenCallOriginalImplementation()
759-
return self
760-
761-
def __getattr__(self, method_name: str) -> Callable[..., AnswerSelector]:
762-
return self.__impl.chain(method_name)
763-
764-
def __enter__(self) -> None:
765-
self.__impl.__enter__()
766-
767-
def __exit__(self, *exc_info) -> None:
768-
self.__impl.__exit__(*exc_info)
769-
770-
771-
class AnswerSelectorImpl(object):
772-
def __init__(
773-
self,
774-
invocation: StubbedInvocation,
775-
expects_awaitable: bool,
776-
discard_first_arg: bool,
777738
) -> None:
778739
self.invocation = invocation
779740
self.expects_awaitable = expects_awaitable
780741
self.discard_first_arg = discard_first_arg
781742

782-
def thenReturn(self, *return_values: Any) -> None:
743+
def thenReturn(self, *return_values: Any) -> Self:
783744
for return_value in return_values or (None,):
784-
if self.expects_awaitable:
785-
answer = return_awaitable(return_value)
786-
else:
787-
answer = return_(return_value)
788-
self.__then(answer)
745+
answer = (
746+
return_awaitable(return_value)
747+
if self.expects_awaitable
748+
else return_(return_value)
749+
)
750+
self._then(answer)
751+
return self
789752

790-
def thenRaise(self, *exceptions: Exception | type[Exception]) -> None:
753+
def thenRaise(self, *exceptions: Exception | type[Exception]) -> Self:
791754
for exception in exceptions or (Exception,):
792-
if self.expects_awaitable:
793-
answer = raise_awaitable(exception)
794-
else:
795-
answer = raise_(exception)
796-
self.__then(answer)
755+
answer = (
756+
raise_awaitable(exception)
757+
if self.expects_awaitable
758+
else raise_(exception)
759+
)
760+
self._then(answer)
761+
return self
797762

798-
def thenAnswer(self, *callables: Callable) -> None:
763+
def thenAnswer(self, *callables: Callable) -> Self:
799764
for callable in callables or (return_(None),):
800765
answer = callable
801766
if self.discard_first_arg:
802767
answer = discard_self(answer)
803768
if self.expects_awaitable and not is_awaitable_when_called(callable):
804769
answer = as_awaitable(answer)
805-
self.__then(answer)
770+
self._then(answer)
771+
return self
806772

807-
def thenCallOriginalImplementation(self) -> None:
773+
def thenCallOriginalImplementation(self) -> Self:
808774
answer = self.invocation.mock.get_original_method(
809775
self.invocation.method_name
810776
)
@@ -818,8 +784,8 @@ def thenCallOriginalImplementation(self) -> None:
818784
self.invocation.method_name,
819785
)
820786
)
821-
self.__then(self._property_descriptor_answer(answer))
822-
return
787+
self._then(self._property_descriptor_answer(answer))
788+
return self
823789

824790
if answer is None:
825791
self.invocation.forget_self()
@@ -840,21 +806,8 @@ def thenCallOriginalImplementation(self) -> None:
840806

841807
# `answer` is runtime-validated by stubbing setup and optional
842808
# unwrapping above, but mypy still sees `object` here.
843-
self.__then(answer) # type: ignore[arg-type]
844-
845-
def _property_descriptor_answer(self, descriptor: Any) -> Callable:
846-
def answer(*args: Any, **kwargs: Any) -> Any:
847-
obj, type_ = self.invocation.mock.get_current_property_access(
848-
self.invocation.method_name
849-
)
850-
# Guarded by `hasattr(descriptor, '__get__')` in caller.
851-
return descriptor.__get__(obj, type_)
852-
853-
return answer
854-
855-
def __then(self, answer: Callable) -> None:
856-
self.invocation.transition_to_value()
857-
self.invocation.add_answer(answer)
809+
self._then(answer) # type: ignore[arg-type]
810+
return self
858811

859812
def __enter__(self) -> None:
860813
pass
@@ -867,19 +820,19 @@ def __exit__(self, *exc_info) -> None:
867820
finally:
868821
self.invocation.forget_self()
869822

870-
def chain(self, method_name: str) -> Callable[..., AnswerSelector]:
871-
def chain_invocation(*args: Any, **kwargs: Any) -> AnswerSelector:
872-
continuation = self.invocation.transition_to_chain()
873-
verification = self.invocation.pop_verification()
874-
stub = StubbedInvocation(
875-
continuation.chain_mock,
876-
method_name,
877-
verification=verification,
878-
parent_invocation=continuation.invocation,
823+
def _property_descriptor_answer(self, descriptor: Any) -> Callable:
824+
def answer(*args: Any, **kwargs: Any) -> Any:
825+
obj, type_ = self.invocation.mock.get_current_property_access(
826+
self.invocation.method_name
879827
)
880-
return stub(*args, **kwargs)
828+
# Guarded by `hasattr(descriptor, '__get__')` in caller.
829+
return descriptor.__get__(obj, type_)
881830

882-
return chain_invocation
831+
return answer
832+
833+
def _then(self, answer: Callable) -> None:
834+
self.invocation.transition_to_value()
835+
self.invocation.add_answer(answer)
883836

884837

885838

0 commit comments

Comments
 (0)