@@ -574,55 +574,49 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
574574 raise ValueError (f"{ m } is not a CallableModel: { type (m )} " )
575575
576576 # Extract the generic types from the class definition
577- generic_base = None
578- for base in cls .__mro__ [1 :]:
579- if issubclass (base , CallableModelGenericType ):
580- # Found the generic base class, it should
581- # have either generic parameters or context/result
582- if hasattr (generic_base , "_context_type" ) and hasattr (generic_base , "_result_type" ):
583- generic_base = base
584- break
585- elif base .__pydantic_generic_metadata__ ["args" ]:
586- generic_base = base
587- break
588- # else continue
589-
590- if generic_base :
591- if hasattr (generic_base , "_context_type" ) and hasattr (generic_base , "_result_type" ):
592- # cls is subclass of generic_base which defines context_type and result_type
593- new_context_type = generic_base ._context_type
594- new_result_type = generic_base ._result_type
595- elif generic_base .__pydantic_generic_metadata__ ["args" ]:
596- # cls is subclass of generic_base which defines the generic types
597- # so use these as the context and result types
598- subtypes = generic_base .__pydantic_generic_metadata__ ["args" ]
599- if len (subtypes ) != 2 :
600- raise ValueError ("CallableModelGenericType must have exactly two generic type parameters: ContextType and ResultType" )
601- new_context_type = subtypes [0 ]
602- new_result_type = subtypes [1 ]
603- else :
604- raise ValueError (
605- "CallableModelGenericType must either define context_type and result_type properties, or have generic type parameters"
606- )
607- # Validate that the model's context_type and result_type match
608- orig_context_typ = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
609- orig_return_typ = _cached_signature (cls .__call__ ).return_annotation
610- if orig_context_typ is not Signature .empty and orig_context_typ != new_context_type :
611- raise TypeError (
612- f"Context type annotation { orig_context_typ } on __call__ does not match context_type { new_context_type } defined by CallableModelGenericType"
613- )
614- if orig_return_typ is not Signature .empty and orig_return_typ != new_result_type :
615- raise TypeError (
616- f"Return type annotation { orig_return_typ } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
617- )
618-
619- # Set on class
620- cls ._context_type = new_context_type
621- cls ._result_type = new_result_type
622-
623- else :
624- subtypes = cls .__pydantic_generic_metadata__ ["args" ]
625- if subtypes :
626- TypeAdapter (Type [subtypes [0 ]]).validate_python (m .context_type )
627- TypeAdapter (Type [subtypes [1 ]]).validate_python (m .result_type )
577+ if not hasattr (cls , "_context_type" ) or not hasattr (cls , "_result_type" ):
578+ new_context_type = None
579+ new_result_type = None
580+ for base in cls .__mro__ [1 :]:
581+ if issubclass (base , CallableModelGenericType ):
582+ # Found the generic base class, it should
583+ # have either generic parameters or context/result
584+ if new_context_type is None and hasattr (base , "_context_type" ) and issubclass (base ._context_type , ContextBase ):
585+ new_context_type = base ._context_type
586+ if new_result_type is None and hasattr (base , "_result_type" ) and issubclass (base ._result_type , ResultBase ):
587+ new_result_type = base ._result_type
588+ if base .__pydantic_generic_metadata__ ["args" ]:
589+ for arg in base .__pydantic_generic_metadata__ ["args" ]:
590+ if new_context_type is None and isinstance (arg , type ) and issubclass (arg , ContextBase ):
591+ new_context_type = arg
592+ elif new_result_type is None and isinstance (arg , type ) and issubclass (arg , ResultBase ):
593+ # NOTE: ContextBase inherits from ResultBase, so order matters here!
594+ new_result_type = arg
595+ if new_context_type and new_result_type :
596+ break
597+ if new_context_type is not None :
598+ # Validate that the model's context_type match
599+ orig_context_typ = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
600+ if orig_context_typ is not Signature .empty and orig_context_typ != new_context_type :
601+ raise TypeError (
602+ f"Context type annotation { orig_context_typ } on __call__ does not match context_type { new_context_type } defined by CallableModelGenericType"
603+ )
604+ # Set on class
605+ cls ._context_type = new_context_type
606+
607+ if new_result_type is not None :
608+ # Validate that the model's result_type match
609+ orig_return_typ = _cached_signature (cls .__call__ ).return_annotation
610+ if orig_return_typ is not Signature .empty and orig_return_typ != new_result_type :
611+ raise TypeError (
612+ f"Return type annotation { orig_return_typ } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
613+ )
614+
615+ # Set on class
616+ cls ._result_type = new_result_type
617+
618+ subtypes = cls .__pydantic_generic_metadata__ ["args" ]
619+ if subtypes :
620+ TypeAdapter (Type [subtypes [0 ]]).validate_python (m .context_type )
621+ TypeAdapter (Type [subtypes [1 ]]).validate_python (m .result_type )
628622 return m
0 commit comments