Skip to content

Commit e80a6ce

Browse files
authored
Merge pull request #141 from Point72/tkp/gen2
Allow partial generics in inheritance tree
2 parents dabafc6 + 7098baa commit e80a6ce

2 files changed

Lines changed: 74 additions & 52 deletions

File tree

ccflow/callable.py

Lines changed: 45 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -574,55 +574,49 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
574574
raise ValueError(f"{m} is not a CallableModel: {type(m)}")
575575

576576
# Extract the generic types from the class definition
577-
generic_base = None
578-
for base in cls.__mro__[1:]:
579-
if issubclass(base, CallableModelGenericType):
580-
# Found the generic base class, it should
581-
# have either generic parameters or context/result
582-
if hasattr(generic_base, "_context_type") and hasattr(generic_base, "_result_type"):
583-
generic_base = base
584-
break
585-
elif base.__pydantic_generic_metadata__["args"]:
586-
generic_base = base
587-
break
588-
# else continue
589-
590-
if generic_base:
591-
if hasattr(generic_base, "_context_type") and hasattr(generic_base, "_result_type"):
592-
# cls is subclass of generic_base which defines context_type and result_type
593-
new_context_type = generic_base._context_type
594-
new_result_type = generic_base._result_type
595-
elif generic_base.__pydantic_generic_metadata__["args"]:
596-
# cls is subclass of generic_base which defines the generic types
597-
# so use these as the context and result types
598-
subtypes = generic_base.__pydantic_generic_metadata__["args"]
599-
if len(subtypes) != 2:
600-
raise ValueError("CallableModelGenericType must have exactly two generic type parameters: ContextType and ResultType")
601-
new_context_type = subtypes[0]
602-
new_result_type = subtypes[1]
603-
else:
604-
raise ValueError(
605-
"CallableModelGenericType must either define context_type and result_type properties, or have generic type parameters"
606-
)
607-
# Validate that the model's context_type and result_type match
608-
orig_context_typ = _cached_signature(cls.__call__).parameters["context"].annotation
609-
orig_return_typ = _cached_signature(cls.__call__).return_annotation
610-
if orig_context_typ is not Signature.empty and orig_context_typ != new_context_type:
611-
raise TypeError(
612-
f"Context type annotation {orig_context_typ} on __call__ does not match context_type {new_context_type} defined by CallableModelGenericType"
613-
)
614-
if orig_return_typ is not Signature.empty and orig_return_typ != new_result_type:
615-
raise TypeError(
616-
f"Return type annotation {orig_return_typ} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
617-
)
618-
619-
# Set on class
620-
cls._context_type = new_context_type
621-
cls._result_type = new_result_type
622-
623-
else:
624-
subtypes = cls.__pydantic_generic_metadata__["args"]
625-
if subtypes:
626-
TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type)
627-
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)
577+
if not hasattr(cls, "_context_type") or not hasattr(cls, "_result_type"):
578+
new_context_type = None
579+
new_result_type = None
580+
for base in cls.__mro__[1:]:
581+
if issubclass(base, CallableModelGenericType):
582+
# Found the generic base class, it should
583+
# have either generic parameters or context/result
584+
if new_context_type is None and hasattr(base, "_context_type") and issubclass(base._context_type, ContextBase):
585+
new_context_type = base._context_type
586+
if new_result_type is None and hasattr(base, "_result_type") and issubclass(base._result_type, ResultBase):
587+
new_result_type = base._result_type
588+
if base.__pydantic_generic_metadata__["args"]:
589+
for arg in base.__pydantic_generic_metadata__["args"]:
590+
if new_context_type is None and isinstance(arg, type) and issubclass(arg, ContextBase):
591+
new_context_type = arg
592+
elif new_result_type is None and isinstance(arg, type) and issubclass(arg, ResultBase):
593+
# NOTE: ContextBase inherits from ResultBase, so order matters here!
594+
new_result_type = arg
595+
if new_context_type and new_result_type:
596+
break
597+
if new_context_type is not None:
598+
# Validate that the model's context_type match
599+
orig_context_typ = _cached_signature(cls.__call__).parameters["context"].annotation
600+
if orig_context_typ is not Signature.empty and orig_context_typ != new_context_type:
601+
raise TypeError(
602+
f"Context type annotation {orig_context_typ} on __call__ does not match context_type {new_context_type} defined by CallableModelGenericType"
603+
)
604+
# Set on class
605+
cls._context_type = new_context_type
606+
607+
if new_result_type is not None:
608+
# Validate that the model's result_type match
609+
orig_return_typ = _cached_signature(cls.__call__).return_annotation
610+
if orig_return_typ is not Signature.empty and orig_return_typ != new_result_type:
611+
raise TypeError(
612+
f"Return type annotation {orig_return_typ} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
613+
)
614+
615+
# Set on class
616+
cls._result_type = new_result_type
617+
618+
subtypes = cls.__pydantic_generic_metadata__["args"]
619+
if subtypes:
620+
TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type)
621+
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)
628622
return m

ccflow/tests/test_callable.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, TypeVar
1+
from typing import Generic, List, TypeVar
22
from unittest import TestCase
33

44
from pydantic import ValidationError
@@ -365,6 +365,34 @@ def __call__(self, context: NullContext) -> GenericResult[int]:
365365
res2 = m2(NullContext())
366366
self.assertEqual(res2.value, 42)
367367

368+
def test_use_as_base_class_mixed_annotations(self):
369+
class Base(CallableModelGenericType[ContextType, ResultType], Generic[ContextType, ResultType]): ...
370+
371+
class Next(Base[ContextType, ResultType], Generic[ContextType, ResultType]): ...
372+
373+
class Partial(Next[NullContext, ResultType], Generic[ResultType]): ...
374+
375+
class Last(Partial[GenericResult[int]]):
376+
@Flow.call
377+
def __call__(self, context: NullContext) -> GenericResult[int]:
378+
return GenericResult[int](value=42)
379+
380+
Last()
381+
382+
def test_use_as_base_class_mixed_annotations_reversed(self):
383+
class Base(CallableModelGenericType[ContextType, ResultType], Generic[ContextType, ResultType]): ...
384+
385+
class Next(Base[ContextType, ResultType], Generic[ContextType, ResultType]): ...
386+
387+
class Partial(Next[ContextType, GenericResult[int]], Generic[ContextType]): ...
388+
389+
class Last(Partial[NullContext]):
390+
@Flow.call
391+
def __call__(self, context: NullContext) -> GenericResult[int]:
392+
return GenericResult[int](value=42)
393+
394+
Last()
395+
368396
def test_use_as_base_class_conflict(self):
369397
class MyCallable(CallableModelGenericType[NullContext, GenericResult[int]]):
370398
@Flow.call

0 commit comments

Comments
 (0)