Skip to content

Commit 0121702

Browse files
author
Nijat Khanbabayev
committed
Simplify Flow.model code, update docs
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent f3886fe commit 0121702

9 files changed

Lines changed: 899 additions & 538 deletions

File tree

ccflow/callable.py

Lines changed: 81 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,24 @@
1616
import logging
1717
from functools import lru_cache, wraps
1818
from inspect import Signature, isclass, signature
19-
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin
19+
from typing import (
20+
TYPE_CHECKING,
21+
Annotated,
22+
Any,
23+
Callable,
24+
ClassVar,
25+
Dict,
26+
Generic,
27+
List,
28+
Optional,
29+
Tuple,
30+
Type,
31+
TypeVar,
32+
Union,
33+
cast,
34+
get_args,
35+
get_origin,
36+
)
2037

2138
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
2239
from typing_extensions import override
@@ -65,11 +82,11 @@ def _cached_signature(fn):
6582
return signature(fn)
6683

6784

68-
def _callable_qualname(fn: Callable[..., Any]) -> str:
69-
return getattr(fn, "__qualname__", type(fn).__qualname__)
70-
71-
7285
def _declared_type_matches(actual: Any, expected: Any) -> bool:
86+
while get_origin(actual) is Annotated:
87+
actual = get_args(actual)[0]
88+
while get_origin(expected) is Annotated:
89+
expected = get_args(expected)[0]
7390
if isinstance(expected, TypeVar):
7491
return True
7592
if get_origin(expected) is Union:
@@ -293,25 +310,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
293310
if not isinstance(model, CallableModel):
294311
raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.")
295312

296-
# Check if this is an auto_context decorated method
297-
has_auto_context = hasattr(fn, "__auto_context__")
298-
if has_auto_context:
299-
method_context_type = fn.__auto_context__
300-
else:
301-
method_context_type = model.context_type
302-
303-
# Validate context type (skip for auto contexts which are always valid ContextBase subclasses)
304-
if not has_auto_context:
305-
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
306-
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
307-
):
308-
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")
309-
310-
# Validate result type - use __result_type__ for auto contexts if available
311-
if has_auto_context and hasattr(fn, "__result_type__"):
312-
method_result_type = fn.__result_type__
313-
else:
314-
method_result_type = model.result_type
313+
method_context_type = getattr(fn, "__auto_context__", model.context_type)
314+
method_result_type = getattr(fn, "__result_type__", model.result_type)
315+
316+
if (not isclass(method_context_type) or not issubclass(method_context_type, ContextBase)) and not (
317+
get_origin(method_context_type) is Union and type(None) in get_args(method_context_type)
318+
):
319+
raise TypeError(f"Context type {method_context_type} must be a subclass of ContextBase")
315320
if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not (
316321
get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type))
317322
):
@@ -332,12 +337,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
332337
raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'")
333338

334339
# Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
335-
if not isinstance(context, method_context_type):
336-
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
337-
coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
338-
else:
339-
coerce_context_type = method_context_type
340-
context = coerce_context_type.model_validate(context)
340+
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
341+
if context is not None:
342+
method_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
343+
if not isinstance(context, method_context_type):
344+
context = method_context_type.model_validate(context)
345+
elif not isinstance(context, method_context_type):
346+
context = method_context_type.model_validate(context)
341347

342348
if fn != getattr(model.__class__, fn.__name__).__wrapped__:
343349
# This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
@@ -356,7 +362,6 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
356362
wrap_any.get_options = self.get_options
357363
wrap_any.get_evaluation_context = get_evaluation_context
358364

359-
# Preserve auto context attributes for introspection
360365
if hasattr(fn, "__auto_context__"):
361366
wrap_any.__auto_context__ = fn.__auto_context__
362367
if hasattr(fn, "__result_type__"):
@@ -476,19 +481,12 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
476481
# with infrastructure expecting DateContext instances.
477482
478483
"""
479-
# Extract auto_context option (not part of FlowOptions)
480-
# Can be: False, True, or a ContextBase subclass
481484
auto_context = kwargs.pop("auto_context", False)
482-
483-
# Determine if auto_context is enabled and extract parent class if provided
484485
if auto_context is False:
485-
auto_context_enabled = False
486486
context_parent = None
487487
elif auto_context is True:
488-
auto_context_enabled = True
489-
context_parent = None
488+
context_parent = ContextBase
490489
elif isclass(auto_context) and issubclass(auto_context, ContextBase):
491-
auto_context_enabled = True
492490
context_parent = auto_context
493491
else:
494492
raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}")
@@ -501,7 +499,7 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
501499
else:
502500
# Arguments to decorator, this is just returning the decorator
503501
# Note that the code below is executed only once
504-
if auto_context_enabled:
502+
if context_parent is not None:
505503
# Return a decorator that first applies auto_context, then FlowOptions
506504
def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
507505
wrapped = _apply_auto_context(fn, parent=context_parent)
@@ -589,6 +587,13 @@ def load_prices(
589587

590588
return flow_model(*args, **kwargs)
591589

590+
@staticmethod
591+
def transform(*args, **kwargs):
592+
"""Decorator that turns a top-level function into a serializable with_inputs() transform factory."""
593+
from .flow_model import flow_transform
594+
595+
return flow_transform(*args, **kwargs)
596+
592597

593598
# *****************************************************************************
594599
# Define "Evaluators" and associated types
@@ -631,7 +636,16 @@ def _context_validator(cls, values: Any, handler: Any, info: Any):
631636
if isinstance(values, dict):
632637
model = values.get("model")
633638
if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type):
634-
values["context"] = model.context_type.model_validate(values.get("context"))
639+
ctx_type = model.context_type
640+
ctx_value = values.get("context")
641+
# Handle Optional[ContextType] — if context is None, keep it; otherwise validate through the inner type
642+
if get_origin(ctx_type) is Union and type(None) in get_args(ctx_type):
643+
if ctx_value is not None:
644+
inner_type = [t for t in get_args(ctx_type) if t is not type(None)][0]
645+
if not isinstance(ctx_value, inner_type):
646+
values["context"] = inner_type.model_validate(ctx_value)
647+
else:
648+
values["context"] = ctx_type.model_validate(ctx_value)
635649

636650
# Apply standard pydantic validation
637651
context = handler(values)
@@ -965,31 +979,41 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult:
965979
model = MyCallable()
966980
model(x=42, y="hello") # Works with kwargs
967981
"""
968-
sig = signature(func)
982+
from .flow_model import _callable_qualname, _resolved_flow_signature
983+
984+
sig = _resolved_flow_signature(
985+
func,
986+
skip_self=True,
987+
require_return_annotation=True,
988+
annotation_error_suffix=" when auto_context=True",
989+
return_error_suffix=" when auto_context=True",
990+
function_name=_callable_qualname(func),
991+
)
969992
base_class = parent or ContextBase
970993

971-
if sig.return_annotation is inspect.Signature.empty:
972-
raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True")
973-
974994
# Validate parent fields are in function signature
975995
if parent is not None:
976996
parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys())
977-
sig_params = set(sig.parameters.keys()) - {"self"}
997+
sig_params = set(sig.parameters)
978998
missing = parent_fields - sig_params
979999
if missing:
9801000
raise TypeError(f"Parent context fields {missing} must be included in function signature")
9811001

982-
# Build fields from parameters (skip 'self'), pydantic validates types
983-
fields = {}
984-
for name, param in sig.parameters.items():
985-
if name == "self":
986-
continue
987-
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
988-
raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True")
989-
if param.annotation is inspect.Parameter.empty:
990-
raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True")
991-
default = ... if param.default is inspect.Parameter.empty else param.default
992-
fields[name] = (param.annotation, default)
1002+
# Validate parent field type compatibility
1003+
from .flow_model import _context_type_annotations_compatible
1004+
1005+
for fname in parent_fields:
1006+
parent_annotation = parent.model_fields[fname].annotation
1007+
func_annotation = sig.parameters[fname].annotation
1008+
if func_annotation is inspect.Parameter.empty:
1009+
continue
1010+
if not _context_type_annotations_compatible(func_annotation, parent_annotation):
1011+
raise TypeError(
1012+
f"auto_context field '{fname}' has annotation {func_annotation!r} which is incompatible "
1013+
f"with parent field annotation {parent_annotation!r}"
1014+
)
1015+
1016+
fields = {name: (param.annotation, ... if param.default is inspect.Parameter.empty else param.default) for name, param in sig.parameters.items()}
9931017

9941018
# Create auto context class
9951019
auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields)

ccflow/context.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar
66

77
from deprecated import deprecated
8-
from pydantic import ConfigDict, field_validator, model_validator
8+
from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator
99

1010
from .base import ContextBase
1111
from .exttypes import Frequency
@@ -106,19 +106,30 @@ class FlowContext(ContextBase):
106106
"""
107107

108108
model_config = ConfigDict(extra="allow", frozen=True)
109+
_frozen_hash_key: Hashable | None = PrivateAttr(default=None)
110+
_hash_value: int | None = PrivateAttr(default=None)
111+
112+
def _hash_key(self) -> Hashable:
113+
if self._frozen_hash_key is None:
114+
self._frozen_hash_key = _freeze_for_hash(self.model_dump(mode="python"))
115+
return self._frozen_hash_key
109116

110117
def __eq__(self, other: Any) -> bool:
118+
if self is other:
119+
return True
111120
if not isinstance(other, FlowContext):
112121
return False
113-
return self.model_dump(mode="python") == other.model_dump(mode="python")
122+
return self._hash_key() == other._hash_key()
114123

115124
def __hash__(self) -> int:
116-
return hash(_freeze_for_hash(self.model_dump(mode="python")))
125+
if self._hash_value is None:
126+
self._hash_value = hash(self._hash_key())
127+
return self._hash_value
117128

118129

119130
def _freeze_for_hash(value: Any) -> Hashable:
120131
if isinstance(value, Mapping):
121-
return tuple(sorted((key, _freeze_for_hash(item)) for key, item in value.items()))
132+
return tuple(sorted(((key, _freeze_for_hash(item)) for key, item in value.items()), key=lambda item: repr(item[0])))
122133
if isinstance(value, (list, tuple)):
123134
return tuple(_freeze_for_hash(item) for item in value)
124135
if isinstance(value, (set, frozenset)):
@@ -130,7 +141,7 @@ def _freeze_for_hash(value: Any) -> Hashable:
130141
except TypeError as exc:
131142
if hasattr(value, "__dict__"):
132143
return (type(value), _freeze_for_hash(vars(value)))
133-
raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}") from exc
144+
raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}: {value!r}") from exc
134145
return value
135146

136147

0 commit comments

Comments
 (0)