Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- '3.11'
dependencies:
- ''
- '"pandas<2" "numpy<2" "xarray<2025.09.0" "dask<2024.7.0"'
- '"pandas<2" "numpy<2" "xarray<2025.09.0"'
- '"pandas<3"'
- '"pandas<4"'

Expand Down
49 changes: 49 additions & 0 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from .exttypes.pyobjectpath import PyObjectPath
from .local_persistence import register_ccflow_import_path, sync_to_module
from .utils.tokenize import DefaultTokenizer, Tokenizer, normalize_token

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -195,6 +196,42 @@ def type_(self) -> PyObjectPath:
# We want to track under what names a model has been registered
_registrations: List[Tuple["ModelRegistry", str]] = PrivateAttr(default_factory=list)

# Tokenization support
__ccflow_tokenizer__: ClassVar[Tokenizer] = DefaultTokenizer.with_bytecode()
_model_token: Optional[str] = PrivateAttr(default=None)

@property
def model_token(self) -> str:
"""Return a deterministic content hash of this model.

Token caching is controlled by ``cache_token`` in model_config.
By default, only frozen models cache their token (safe — immutable,
never stale). Mutable models recompute on every access.
Set ``cache_token=True`` on a mutable model to opt in to caching
(e.g. for large data that is expensive to tokenize and won't change).
"""
cache = self.model_config.get("cache_token", self.model_config.get("frozen", False))
if cache and self._model_token is not None:
return self._model_token
token = self.__ccflow_tokenizer__.tokenize(self)
if cache:
self.__pydantic_private__["_model_token"] = token
return token

@model_validator(mode="after")
def _clear_token_cache(self):
"""Clear the cached token on construction and field assignment."""
if self.model_config.get("cache_token", self.model_config.get("frozen", False)):
self.__pydantic_private__["_model_token"] = None
return self

def model_copy(self, *, update=None, deep=False):
"""Override model_copy to clear the stale token cache on the copy."""
copy = super().model_copy(update=update, deep=deep)
if update and copy.__pydantic_private__ is not None:
copy.__pydantic_private__["_model_token"] = None
return copy

model_config = ConfigDict(
# Note that validate_assignment only partially works: https://github.com/pydantic/pydantic/issues/7105
validate_assignment=True,
Expand Down Expand Up @@ -316,6 +353,18 @@ def __getstate__(self):
def __setstate__(self, state):
state["__pydantic_fields_set__"] = set(state["__pydantic_fields_set__"])
super().__setstate__(state)
# Clear stale token cache from pickle
if self.__pydantic_private__ is not None and "_model_token" in self.__pydantic_private__:
self.__pydantic_private__["_model_token"] = None


# Register ccflow BaseModel-specific normalize_token handler
# Delegates to the model's tokenizer so normalization is consistent
# regardless of whether the model is accessed via model_token or
# encountered as a value inside a container.
@normalize_token.register(BaseModel)
def _normalize_ccflow_basemodel(obj):
    """Normalize a ccflow BaseModel for tokenization.

    Delegates to the model's own tokenizer so normalization stays
    consistent whether the model is hashed directly via ``model_token``
    or encountered as a value inside a container.
    """
    tokenizer = obj.__ccflow_tokenizer__
    return tokenizer.normalize(obj)


class _ModelRegistryData(PydanticBaseModel):
Expand Down
74 changes: 73 additions & 1 deletion ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"FlowOptionsDeps",
"FlowOptionsOverride",
"ModelEvaluationContext",
"TransparentModelEvaluationContext",
"EvaluatorBase",
"Evaluator",
"WrapperModel",
Expand Down Expand Up @@ -262,7 +263,7 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di
if as_dict:
return dict(model=evaluator, context=evaluation_context)
else:
return ModelEvaluationContext(model=evaluator, context=evaluation_context)
return evaluator.make_evaluation_context(evaluation_context)

# The decorator implementation
def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs):
Expand Down Expand Up @@ -450,6 +451,40 @@ class ModelEvaluationContext(
# Otherwise, the validation will re-run fully despite the models already being validated on construction
# TODO: Make the instance check compatible with the generic types instead of the base type

    @property
    def model_token(self) -> str:
        """Compute a cache-key token for this MEC chain.

        Walks the MEC chain, strips ``TransparentModelEvaluationContext``
        layers, and tokenizes the innermost context plus any opaque evaluators.
        """
        # Unlike BaseModel.model_token, caching defaults to ON here.
        cache = self.model_config.get("cache_token", True)
        if cache and self._model_token is not None:
            return self._model_token

        fn = self.fn
        non_transparent = []  # models of layers that must affect the token
        current = self
        # Walk inward while the wrapped context is itself a MEC layer.
        while isinstance(current.context, ModelEvaluationContext):
            # "__call__" is treated as "no explicit fn" and is skipped over.
            # NOTE(review): when multiple layers set an explicit fn, an inner
            # layer's fn overwrites an outer one's — confirm intended precedence.
            fn = current.fn if current.fn != "__call__" else fn
            if not isinstance(current, TransparentModelEvaluationContext):
                non_transparent.append(current.model)
            current = current.context

        # Build a canonical representation from the innermost MEC
        from .utils.tokenize import normalize_token

        inner_norm = normalize_token(current)
        # Fall back to the innermost fn if no walked layer set one explicitly.
        effective_fn = fn if fn != "__call__" else current.fn
        parts = (inner_norm, effective_fn)
        if non_transparent:
            # Opaque evaluator layers change the result, so fold them in.
            parts = parts + (tuple(normalize_token(e) for e in non_transparent),)
        token = self.__ccflow_tokenizer__.hash_canonical(parts)

        if cache:
            self.__pydantic_private__["_model_token"] = token
        return token

@model_validator(mode="wrap")
def _context_validator(cls, values, handler, info):
"""Override _context_validator from parent"""
Expand Down Expand Up @@ -510,10 +545,47 @@ def __deps__(self, context: ModelEvaluationContext) -> GraphDepList:
def __exit__(self):
pass

def is_transparent(self, context: ModelEvaluationContext) -> bool:
"""Whether this evaluator does NOT modify the return value for the given context.

Transparent evaluators may add side effects (logging, caching, timing,
dependency ordering) but always return the same value as ``context()``.
This allows cache key computation and dependency graph deduplication to
skip these layers.

Override this method to return ``True`` for evaluators that are always
transparent, or implement context-dependent logic for evaluators that
are only sometimes transparent.
"""
return False

def make_evaluation_context(self, context: ModelEvaluationContext, **kwargs) -> ModelEvaluationContext:
"""Create a ModelEvaluationContext wrapping this evaluator around the given context.

Returns a ``TransparentModelEvaluationContext`` when ``is_transparent(context)``
is ``True``, signaling that this layer can be skipped for cache key computation.
"""
if self.is_transparent(context):
return TransparentModelEvaluationContext(model=self, context=context, **kwargs)
return ModelEvaluationContext(model=self, context=context, **kwargs)


class TransparentModelEvaluationContext(ModelEvaluationContext):
    """A ModelEvaluationContext layer that is safe to skip for cache key computation.

    Created by ``EvaluatorBase.make_evaluation_context()`` when the evaluator's
    ``is_transparent()`` returns ``True``. Signals that this evaluator layer does
    not modify the return value and can be ignored when computing cache keys or
    deduplicating dependency graph nodes. The subclass adds no behavior of its
    own; it exists purely as a marker type detected via ``isinstance``.
    """


class Evaluator(EvaluatorBase):
"""A higher-order model that evaluates a function on a CallableModel and a Context."""

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Always transparent: ``__call__`` returns ``context()`` unchanged."""
        return True

    @override
    def __call__(self, context: ModelEvaluationContext) -> ResultType:
        """Evaluate the wrapped context directly, without modifying the result."""
        return context()
Expand Down
48 changes: 41 additions & 7 deletions ccflow/evaluators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Set, Union

import dask.base
from pydantic import Field, PrivateAttr, field_validator
from typing_extensions import override

from ..base import BaseModel, make_lazy_result
from ..callable import CallableModel, ContextBase, EvaluatorBase, ModelEvaluationContext, ResultType
from ..callable import (
CallableModel,
ContextBase,
EvaluatorBase,
ModelEvaluationContext,
ResultType,
)

__all__ = [
"cache_key",
Expand Down Expand Up @@ -53,16 +58,25 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato


class MultiEvaluator(EvaluatorBase):
    """An evaluator that combines multiple evaluators.

    Each child evaluator is wrapped in a ModelEvaluationContext using its own
    ``make_evaluation_context()`` method, so transparent children produce
    ``TransparentModelEvaluationContext`` layers that can be skipped during
    cache key computation.
    """

    evaluators: List[EvaluatorBase] = Field(
        description="The list of evaluators to combine. The first evaluator in the list will be called first during evaluation."
    )

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Transparent only when every child evaluator is transparent for this context."""
        return all(e.is_transparent(context) for e in self.evaluators)

    @override
    def __call__(self, context: ModelEvaluationContext) -> ResultType:
        # Wrap each evaluator around the context exactly once (the stray
        # duplicate ModelEvaluationContext wrap was a merge artifact that
        # would have wrapped every evaluator twice), then invoke the chain.
        for evaluator in self.evaluators:
            context = evaluator.make_evaluation_context(context, options=context.options)
        return context()


Expand All @@ -71,6 +85,9 @@ class FallbackEvaluator(EvaluatorBase):

evaluators: List[EvaluatorBase] = Field(description="The list of evaluators to try (in order).")

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Transparent only when every candidate evaluator is transparent for this context."""
        return all(e.is_transparent(context) for e in self.evaluators)

@override
def __call__(self, context: ModelEvaluationContext) -> ResultType:
for evaluator in self.evaluators:
Expand Down Expand Up @@ -120,6 +137,9 @@ class LoggingEvaluator(EvaluatorBase):
log_result: bool = Field(False, description="Whether to log the result of the evaluation")
format_config: FormatConfig = Field(FormatConfig(), description="Configuration for formatting the result of the evaluation if log_result=True")

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Always transparent: logging is a side effect and does not modify the return value."""
        return True

@field_validator("log_level", mode="before")
@classmethod
def _validate_log_level(cls, v: Union[int, str]) -> int:
Expand Down Expand Up @@ -195,13 +215,18 @@ def _format_result(self, result: ResultType) -> str:


def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes:
    """Returns a key suitable for use in caching and dependency graph deduplication.

    For ``ModelEvaluationContext`` inputs, strips ``TransparentModelEvaluationContext``
    layers (evaluators that don't modify the return value) so that the key depends
    only on the underlying model, context, fn, options, and any non-transparent
    evaluators in the chain.

    Args:
        flow_obj: The object to be tokenized to form the cache key.

    Returns:
        A UTF-8 encoded byte string derived from the object's ``model_token``.

    Raises:
        TypeError: If ``flow_obj`` is not one of the supported flow types.
    """
    if isinstance(flow_obj, (ModelEvaluationContext, ContextBase, CallableModel)):
        # model_token handles transparent-layer stripping and per-model
        # tokenizer configuration; the old dask.base.tokenize(model_dump(...))
        # path was a merge artifact that made this line unreachable.
        return flow_obj.model_token.encode("utf-8")
    else:
        raise TypeError(f"object of type {type(flow_obj)} cannot be serialized by this function!")

Expand All @@ -213,8 +238,14 @@ class MemoryCacheEvaluator(EvaluatorBase):
_cache: Dict[bytes, ResultType] = PrivateAttr({})
_ids: Dict[bytes, ModelEvaluationContext] = PrivateAttr({})

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Always transparent: caching returns the same value ``context()`` would produce."""
        return True

    def key(self, context: ModelEvaluationContext) -> bytes:
        """Function to convert a ModelEvaluationContext to a cache key.

        Delegates to ``cache_key()`` which strips transparent evaluator layers,
        so two MEC chains differing only by transparent layers share one entry.
        """
        return cache_key(context)

@property
Expand Down Expand Up @@ -289,6 +320,9 @@ class GraphEvaluator(EvaluatorBase):

_is_evaluating: bool = PrivateAttr(False)

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Always transparent: dependency ordering is a side effect only — presumably
        the result is still exactly ``context()``; confirm against ``__call__``."""
        return True

@override
def __call__(self, context: ModelEvaluationContext) -> ResultType:
import graphlib
Expand Down
Loading
Loading