@@ -735,76 +735,42 @@ def __init__(
735735 invocation : StubbedInvocation ,
736736 expects_awaitable : bool ,
737737 discard_first_arg : bool
738- ) -> None :
739- self .__impl = AnswerSelectorImpl (
740- invocation ,
741- expects_awaitable = expects_awaitable ,
742- discard_first_arg = discard_first_arg ,
743- )
744-
745- def thenReturn (self , * return_values : Any ) -> Self :
746- self .__impl .thenReturn (* return_values )
747- return self
748-
749- def thenRaise (self , * exceptions : Exception | type [Exception ]) -> Self :
750- self .__impl .thenRaise (* exceptions )
751- return self
752-
753- def thenAnswer (self , * callables : Callable ) -> Self :
754- self .__impl .thenAnswer (* callables )
755- return self
756-
757- def thenCallOriginalImplementation (self ) -> Self :
758- self .__impl .thenCallOriginalImplementation ()
759- return self
760-
761- def __getattr__ (self , method_name : str ) -> Callable [..., AnswerSelector ]:
762- return self .__impl .chain (method_name )
763-
764- def __enter__ (self ) -> None :
765- self .__impl .__enter__ ()
766-
767- def __exit__ (self , * exc_info ) -> None :
768- self .__impl .__exit__ (* exc_info )
769-
770-
771- class AnswerSelectorImpl (object ):
772- def __init__ (
773- self ,
774- invocation : StubbedInvocation ,
775- expects_awaitable : bool ,
776- discard_first_arg : bool ,
777738 ) -> None :
778739 self .invocation = invocation
779740 self .expects_awaitable = expects_awaitable
780741 self .discard_first_arg = discard_first_arg
781742
782- def thenReturn (self , * return_values : Any ) -> None :
743+ def thenReturn (self , * return_values : Any ) -> Self :
783744 for return_value in return_values or (None ,):
784- if self .expects_awaitable :
785- answer = return_awaitable (return_value )
786- else :
787- answer = return_ (return_value )
788- self .__then (answer )
745+ answer = (
746+ return_awaitable (return_value )
747+ if self .expects_awaitable
748+ else return_ (return_value )
749+ )
750+ self ._then (answer )
751+ return self
789752
790- def thenRaise (self , * exceptions : Exception | type [Exception ]) -> None :
753+ def thenRaise (self , * exceptions : Exception | type [Exception ]) -> Self :
791754 for exception in exceptions or (Exception ,):
792- if self .expects_awaitable :
793- answer = raise_awaitable (exception )
794- else :
795- answer = raise_ (exception )
796- self .__then (answer )
755+ answer = (
756+ raise_awaitable (exception )
757+ if self .expects_awaitable
758+ else raise_ (exception )
759+ )
760+ self ._then (answer )
761+ return self
797762
798- def thenAnswer (self , * callables : Callable ) -> None :
763+ def thenAnswer (self , * callables : Callable ) -> Self :
799764 for callable in callables or (return_ (None ),):
800765 answer = callable
801766 if self .discard_first_arg :
802767 answer = discard_self (answer )
803768 if self .expects_awaitable and not is_awaitable_when_called (callable ):
804769 answer = as_awaitable (answer )
805- self .__then (answer )
770+ self ._then (answer )
771+ return self
806772
807- def thenCallOriginalImplementation (self ) -> None :
773+ def thenCallOriginalImplementation (self ) -> Self :
808774 answer = self .invocation .mock .get_original_method (
809775 self .invocation .method_name
810776 )
@@ -818,8 +784,8 @@ def thenCallOriginalImplementation(self) -> None:
818784 self .invocation .method_name ,
819785 )
820786 )
821- self .__then (self ._property_descriptor_answer (answer ))
822- return
787+ self ._then (self ._property_descriptor_answer (answer ))
788+ return self
823789
824790 if answer is None :
825791 self .invocation .forget_self ()
@@ -840,21 +806,8 @@ def thenCallOriginalImplementation(self) -> None:
840806
841807 # `answer` is runtime-validated by stubbing setup and optional
842808 # unwrapping above, but mypy still sees `object` here.
843- self .__then (answer ) # type: ignore[arg-type]
844-
845- def _property_descriptor_answer (self , descriptor : Any ) -> Callable :
846- def answer (* args : Any , ** kwargs : Any ) -> Any :
847- obj , type_ = self .invocation .mock .get_current_property_access (
848- self .invocation .method_name
849- )
850- # Guarded by `hasattr(descriptor, '__get__')` in caller.
851- return descriptor .__get__ (obj , type_ )
852-
853- return answer
854-
855- def __then (self , answer : Callable ) -> None :
856- self .invocation .transition_to_value ()
857- self .invocation .add_answer (answer )
809+ self ._then (answer ) # type: ignore[arg-type]
810+ return self
858811
859812 def __enter__ (self ) -> None :
860813 pass
@@ -867,19 +820,19 @@ def __exit__(self, *exc_info) -> None:
867820 finally :
868821 self .invocation .forget_self ()
869822
870- def chain (self , method_name : str ) -> Callable [..., AnswerSelector ]:
871- def chain_invocation (* args : Any , ** kwargs : Any ) -> AnswerSelector :
872- continuation = self .invocation .transition_to_chain ()
873- verification = self .invocation .pop_verification ()
874- stub = StubbedInvocation (
875- continuation .chain_mock ,
876- method_name ,
877- verification = verification ,
878- parent_invocation = continuation .invocation ,
823+ def _property_descriptor_answer (self , descriptor : Any ) -> Callable :
824+ def answer (* args : Any , ** kwargs : Any ) -> Any :
825+ obj , type_ = self .invocation .mock .get_current_property_access (
826+ self .invocation .method_name
879827 )
880- return stub (* args , ** kwargs )
828+ # Guarded by `hasattr(descriptor, '__get__')` in caller.
829+ return descriptor .__get__ (obj , type_ )
881830
882- return chain_invocation
831+ return answer
832+
833+ def _then (self , answer : Callable ) -> None :
834+ self .invocation .transition_to_value ()
835+ self .invocation .add_answer (answer )
883836
884837
885838
0 commit comments