Skip to content

Commit d180f35

Browse files
author
Nijat Khanbabayev
committed
Add extra stuff, need clean-up
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent 2696fce commit d180f35

7 files changed

Lines changed: 1042 additions & 93 deletions

File tree

ccflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .context import *
1313
from .dep import *
1414
from .enums import Enum
15+
from .flow_model import FlowAPI, BoundModel, Lazy
1516
from .global_state import *
1617
from .local_persistence import *
1718
from .models import *

ccflow/callable.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,10 @@ def _resolve_deps_and_call(model, context, fn):
312312
# Get Dep-annotated fields for this model class
313313
dep_fields = _get_dep_fields(model.__class__)
314314

315-
if not dep_fields:
315+
# Check if model has custom deps (from @func.deps decorator)
316+
has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False)
317+
318+
if not dep_fields and not has_custom_deps:
316319
return fn(model, context)
317320

318321
# Get dependencies from __deps__
@@ -324,27 +327,37 @@ def _resolve_deps_and_call(model, context, fn):
324327

325328
# Resolve dependencies and store in context var
326329
resolved_values = {}
327-
for field_name, dep in dep_fields.items():
328-
field_value = getattr(model, field_name, None)
329-
if field_value is None:
330-
continue
331-
332-
# Check if field is a CallableModel that needs resolution
333-
if not isinstance(field_value, _CallableModel):
334-
continue # Already a resolved value, skip
335330

336-
# Check if this field is in __deps__ (for custom transforms)
337-
if id(field_value) in dep_map:
338-
dep_model, contexts = dep_map[id(field_value)]
339-
# Call dependency with the (transformed) context
331+
# If custom deps, resolve ALL CallableModel fields from dep_map
332+
if has_custom_deps:
333+
for dep_model, contexts in deps_result:
340334
resolved = dep_model(contexts[0]) if contexts else dep_model(context)
341-
else:
342-
# Not in __deps__, use Dep annotation transform directly
343-
transformed_ctx = dep.apply(context)
344-
resolved = field_value(transformed_ctx)
335+
# Unwrap GenericResult if present (consistent with auto-detected deps)
336+
if hasattr(resolved, 'value'):
337+
resolved = resolved.value
338+
resolved_values[id(dep_model)] = resolved
339+
else:
340+
# Standard path: iterate over Dep-annotated fields
341+
for field_name, dep in dep_fields.items():
342+
field_value = getattr(model, field_name, None)
343+
if field_value is None:
344+
continue
345+
346+
# Check if field is a CallableModel that needs resolution
347+
if not isinstance(field_value, _CallableModel):
348+
continue # Already a resolved value, skip
349+
350+
# Check if this field is in __deps__ (for custom transforms)
351+
if id(field_value) in dep_map:
352+
dep_model, contexts = dep_map[id(field_value)]
353+
# Call dependency with the (transformed) context
354+
resolved = dep_model(contexts[0]) if contexts else dep_model(context)
355+
else:
356+
# Not in __deps__, use Dep annotation transform directly
357+
transformed_ctx = dep.apply(context)
358+
resolved = field_value(transformed_ctx)
345359

346-
# Store resolved value keyed by the CallableModel's id
347-
resolved_values[id(field_value)] = resolved
360+
resolved_values[id(field_value)] = resolved
348361

349362
# Store in context var and call function
350363
current_store = _resolved_deps.get()

ccflow/context.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import warnings
44
from datetime import date, datetime
5-
from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar
5+
from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar
66

77
from deprecated import deprecated
8-
from pydantic import field_validator, model_validator
8+
from pydantic import ConfigDict, field_validator, model_validator
99

1010
from .base import ContextBase
1111
from .exttypes import Frequency
@@ -15,6 +15,7 @@
1515

1616

1717
__all__ = (
18+
"FlowContext",
1819
"NullContext",
1920
"GenericContext",
2021
"DateContext",
@@ -93,6 +94,42 @@
9394
# Starting 0.8.0 Nullcontext is an alias to ContextBase
9495
NullContext = ContextBase
9596

97+
98+
class FlowContext(ContextBase):
99+
"""Universal context for @Flow.model functions.
100+
101+
Instead of generating a new ContextBase subclass for each @Flow.model,
102+
this single class with extra="allow" serves as the universal carrier.
103+
Validation happens via TypedDict + TypeAdapter at compute() time.
104+
105+
This design avoids:
106+
- Proliferation of dynamic _funcname_Context classes
107+
- Class registration overhead for serialization
108+
- Pickling issues with Ray/distributed computing
109+
110+
Fields are stored in __pydantic_extra__ and accessed via __getattr__.
111+
"""
112+
113+
model_config = ConfigDict(extra="allow", frozen=True)
114+
115+
def __getattr__(self, name: str) -> Any:
116+
"""Access fields stored in __pydantic_extra__."""
117+
# Use object.__getattribute__ to avoid infinite recursion
118+
try:
119+
extra = object.__getattribute__(self, "__pydantic_extra__")
120+
if extra is not None and name in extra:
121+
return extra[name]
122+
except AttributeError:
123+
pass
124+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
125+
126+
def __repr__(self) -> str:
127+
"""Show all fields including extra fields."""
128+
extra = object.__getattribute__(self, "__pydantic_extra__") or {}
129+
fields = ", ".join(f"{k}={v!r}" for k, v in extra.items())
130+
return f"FlowContext({fields})"
131+
132+
96133
C = TypeVar("C", bound=Hashable)
97134

98135

0 commit comments

Comments
 (0)