diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 7c396e0..adf7526 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -35,7 +35,7 @@ jobs: - '3.11' dependencies: - '' - - '"pandas<2" "numpy<2" "xarray<2025.09.0" "dask<2024.7.0"' + - '"pandas<2" "numpy<2" "xarray<2025.09.0"' - '"pandas<3"' - '"pandas<4"' diff --git a/ccflow/base.py b/ccflow/base.py index 0962d38..001dd40 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -31,6 +31,7 @@ from .exttypes.pyobjectpath import PyObjectPath from .local_persistence import register_ccflow_import_path, sync_to_module +from .utils.tokenize import DefaultTokenizer, Tokenizer, normalize_token log = logging.getLogger(__name__) @@ -195,6 +196,42 @@ def type_(self) -> PyObjectPath: # We want to track under what names a model has been registered _registrations: List[Tuple["ModelRegistry", str]] = PrivateAttr(default_factory=list) + # Tokenization support + __ccflow_tokenizer__: ClassVar[Tokenizer] = DefaultTokenizer.with_bytecode() + _model_token: Optional[str] = PrivateAttr(default=None) + + @property + def model_token(self) -> str: + """Return a deterministic content hash of this model. + + Token caching is controlled by ``cache_token`` in model_config. + By default, only frozen models cache their token (safe — immutable, + never stale). Mutable models recompute on every access. + Set ``cache_token=True`` on a mutable model to opt in to caching + (e.g. for large data that is expensive to tokenize and won't change). + """ + cache = self.model_config.get("cache_token", self.model_config.get("frozen", False)) + if cache and self._model_token is not None: + return self._model_token + token = self.__ccflow_tokenizer__.tokenize(self) + if cache: + self.__pydantic_private__["_model_token"] = token + return token + + @model_validator(mode="after") + def _clear_token_cache(self): + """Clear the cached token on construction and field assignment.""" + if self.model_config.get("cache_token", self.model_config.get("frozen", False)): + self.__pydantic_private__["_model_token"] = None + return self + + def model_copy(self, *, update=None, deep=False): + """Override model_copy to clear the stale token cache on the copy.""" + copy = super().model_copy(update=update, deep=deep) + if update and copy.__pydantic_private__ is not None: + copy.__pydantic_private__["_model_token"] = None + return copy + model_config = ConfigDict( # Note that validate_assignment only partially works: https://github.com/pydantic/pydantic/issues/7105 validate_assignment=True, @@ -316,6 +353,18 @@ def __getstate__(self): def __setstate__(self, state): state["__pydantic_fields_set__"] = set(state["__pydantic_fields_set__"]) super().__setstate__(state) + # Clear stale token cache from pickle + if self.__pydantic_private__ is not None and "_model_token" in self.__pydantic_private__: + self.__pydantic_private__["_model_token"] = None + + +# Register ccflow BaseModel-specific normalize_token handler +# Delegates to the model's tokenizer so normalization is consistent +# regardless of whether the model is accessed via model_token or +# encountered as a value inside a container. 
+@normalize_token.register(BaseModel) +def _normalize_ccflow_basemodel(obj): + return obj.__ccflow_tokenizer__.normalize(obj) class _ModelRegistryData(PydanticBaseModel): diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..2a7319e 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -41,6 +41,7 @@ "FlowOptionsDeps", "FlowOptionsOverride", "ModelEvaluationContext", + "TransparentModelEvaluationContext", "EvaluatorBase", "Evaluator", "WrapperModel", @@ -262,7 +263,7 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di if as_dict: return dict(model=evaluator, context=evaluation_context) else: - return ModelEvaluationContext(model=evaluator, context=evaluation_context) + return evaluator.make_evaluation_context(evaluation_context) # The decorator implementation def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): @@ -450,6 +451,40 @@ class ModelEvaluationContext( # Otherwise, the validation will re-run fully despite the models already being validated on construction # TODO: Make the instance check compatible with the generic types instead of the base type + @property + def model_token(self) -> str: + """Compute a cache-key token for this MEC chain. + + Walks the MEC chain, strips ``TransparentModelEvaluationContext`` + layers, and tokenizes the innermost context plus any opaque evaluators. + """ + cache = self.model_config.get("cache_token", True) + if cache and self._model_token is not None: + return self._model_token + + fn = self.fn + non_transparent = [] + current = self + while isinstance(current.context, ModelEvaluationContext): + fn = current.fn if current.fn != "__call__" else fn + if not isinstance(current, TransparentModelEvaluationContext): + non_transparent.append(current.model) + current = current.context + + # Build a canonical representation from the innermost MEC + from .utils.tokenize import normalize_token + + inner_norm = normalize_token(current) + effective_fn = fn if fn != "__call__" else current.fn + parts = (inner_norm, effective_fn) + if non_transparent: + parts = parts + (tuple(normalize_token(e) for e in non_transparent),) + token = self.__ccflow_tokenizer__.hash_canonical(parts) + + if cache: + self.__pydantic_private__["_model_token"] = token + return token + @model_validator(mode="wrap") def _context_validator(cls, values, handler, info): """Override _context_validator from parent""" @@ -510,10 +545,47 @@ def __deps__(self, context: ModelEvaluationContext) -> GraphDepList: def __exit__(self): pass + def is_transparent(self, context: ModelEvaluationContext) -> bool: + """Whether this evaluator does NOT modify the return value for the given context. + + Transparent evaluators may add side effects (logging, caching, timing, + dependency ordering) but always return the same value as ``context()``. + This allows cache key computation and dependency graph deduplication to + skip these layers. + + Override this method to return ``True`` for evaluators that are always + transparent, or implement context-dependent logic for evaluators that + are only sometimes transparent. + """ + return False + + def make_evaluation_context(self, context: ModelEvaluationContext, **kwargs) -> ModelEvaluationContext: + """Create a ModelEvaluationContext wrapping this evaluator around the given context. + + Returns a ``TransparentModelEvaluationContext`` when ``is_transparent(context)`` + is ``True``, signaling that this layer can be skipped for cache key computation. 
+ """ + if self.is_transparent(context): + return TransparentModelEvaluationContext(model=self, context=context, **kwargs) + return ModelEvaluationContext(model=self, context=context, **kwargs) + + +class TransparentModelEvaluationContext(ModelEvaluationContext): + """A ModelEvaluationContext layer that is safe to skip for cache key computation. + + Created by ``EvaluatorBase.make_evaluation_context()`` when the evaluator's + ``is_transparent()`` returns ``True``. Signals that this evaluator layer does + not modify the return value and can be ignored when computing cache keys or + deduplicating dependency graph nodes. + """ + class Evaluator(EvaluatorBase): """A higher-order model that evaluates a function on a CallableModel and a Context.""" + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return True + @override def __call__(self, context: ModelEvaluationContext) -> ResultType: return context() diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index 2478cb9..ca5e1fc 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -7,12 +7,17 @@ from types import MappingProxyType from typing import Any, Callable, Dict, List, Optional, Set, Union -import dask.base from pydantic import Field, PrivateAttr, field_validator from typing_extensions import override from ..base import BaseModel, make_lazy_result -from ..callable import CallableModel, ContextBase, EvaluatorBase, ModelEvaluationContext, ResultType +from ..callable import ( + CallableModel, + ContextBase, + EvaluatorBase, + ModelEvaluationContext, + ResultType, +) __all__ = [ "cache_key", @@ -53,16 +58,25 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato class MultiEvaluator(EvaluatorBase): - """An evaluator that combines multiple evaluators.""" + """An evaluator that combines multiple evaluators. + + Each child evaluator is wrapped in a ModelEvaluationContext using its own + ``make_evaluation_context()`` method, so transparent children produce + ``TransparentModelEvaluationContext`` layers that can be skipped during + cache key computation. + """ evaluators: List[EvaluatorBase] = Field( description="The list of evaluators to combine. The first evaluator in the list will be called first during evaluation." 
) + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return all(e.is_transparent(context) for e in self.evaluators) + @override def __call__(self, context: ModelEvaluationContext) -> ResultType: for evaluator in self.evaluators: - context = ModelEvaluationContext(model=evaluator, context=context, options=context.options) + context = evaluator.make_evaluation_context(context, options=context.options) return context() @@ -71,6 +85,9 @@ class FallbackEvaluator(EvaluatorBase): evaluators: List[EvaluatorBase] = Field(description="The list of evaluators to try (in order).") + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return all(e.is_transparent(context) for e in self.evaluators) + @override def __call__(self, context: ModelEvaluationContext) -> ResultType: for evaluator in self.evaluators: @@ -120,6 +137,9 @@ class LoggingEvaluator(EvaluatorBase): log_result: bool = Field(False, description="Whether to log the result of the evaluation") format_config: FormatConfig = Field(FormatConfig(), description="Configuration for formatting the result of the evaluation if log_result=True") + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return True + @field_validator("log_level", mode="before") @classmethod def _validate_log_level(cls, v: Union[int, str]) -> int: @@ -195,13 +215,18 @@ def _format_result(self, result: ResultType) -> str: def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes: - """Returns a key suitable for use in caching. + """Returns a key suitable for use in caching and dependency graph deduplication. + + For ``ModelEvaluationContext`` inputs, strips ``TransparentModelEvaluationContext`` + layers (evaluators that don't modify the return value) so that the key depends + only on the underlying model, context, fn, options, and any non-transparent + evaluators in the chain. Args: flow_obj: The object to be tokenized to form the cache key. """ if isinstance(flow_obj, (ModelEvaluationContext, ContextBase, CallableModel)): - return dask.base.tokenize(flow_obj.model_dump(mode="python")).encode("utf-8") + return flow_obj.model_token.encode("utf-8") else: raise TypeError(f"object of type {type(flow_obj)} cannot be serialized by this function!") @@ -213,8 +238,14 @@ class MemoryCacheEvaluator(EvaluatorBase): _cache: Dict[bytes, ResultType] = PrivateAttr({}) _ids: Dict[bytes, ModelEvaluationContext] = PrivateAttr({}) + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return True + def key(self, context: ModelEvaluationContext): - """Function to convert a ModelEvaluationContext to a key""" + """Function to convert a ModelEvaluationContext to a cache key. + + Delegates to ``cache_key()`` which strips transparent evaluator layers. 
+ """ return cache_key(context) @property @@ -289,6 +320,9 @@ class GraphEvaluator(EvaluatorBase): _is_evaluating: bool = PrivateAttr(False) + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return True + @override def __call__(self, context: ModelEvaluationContext) -> ResultType: import graphlib diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 6124ed2..6db6d07 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -9,9 +9,11 @@ DateContext, DateRangeContext, Evaluator, + EvaluatorBase, FlowOptionsOverride, ModelEvaluationContext, NullContext, + TransparentModelEvaluationContext, ) from ccflow.evaluators import ( FallbackEvaluator, @@ -257,6 +259,73 @@ def test_model_evaluation_context(self): assert cache_key(mec1) == cache_key(mec2) assert cache_key(mec3) != cache_key(mec1) + def test_transparent_mec_stripped(self): + """TransparentModelEvaluationContext layers are stripped from cache keys.""" + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + wrapped = TransparentModelEvaluationContext(model=LoggingEvaluator(), context=inner) + assert cache_key(inner) == cache_key(wrapped) + + def test_opaque_mec_preserved(self): + """Non-transparent MEC layers produce different cache keys.""" + + class OpaqueEval(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + return context() + + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + wrapped = ModelEvaluationContext(model=OpaqueEval(), context=inner) + assert cache_key(inner) != cache_key(wrapped) + + def test_stacked_transparent_stripped(self): + """Multiple stacked TransparentMEC layers are all stripped.""" + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + layer1 = TransparentModelEvaluationContext(model=LoggingEvaluator(), context=inner) + layer2 = TransparentModelEvaluationContext(model=MemoryCacheEvaluator(), context=layer1) + assert cache_key(inner) == cache_key(layer2) + + def test_sandwich_transparent_between_opaque(self): + """Transparent layer sandwiched between opaque layers is stripped, opaques preserved.""" + + class OpaqueEval(EvaluatorBase): + tag: str = "default" + + def __call__(self, context: ModelEvaluationContext): + return context() + + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + opaque1 = ModelEvaluationContext(model=OpaqueEval(tag="inner"), context=inner) + transparent = TransparentModelEvaluationContext(model=LoggingEvaluator(), context=opaque1) + opaque2 = ModelEvaluationContext(model=OpaqueEval(tag="outer"), context=transparent) + # Both opaque evaluators should be in the key; the transparent one should not + assert cache_key(opaque2) != cache_key(inner) + # Same sandwich should give consistent keys + opaque2b = ModelEvaluationContext( + model=OpaqueEval(tag="outer"), + context=TransparentModelEvaluationContext( + model=LoggingEvaluator(), context=ModelEvaluationContext(model=OpaqueEval(tag="inner"), context=inner) + ), + ) + assert cache_key(opaque2) == cache_key(opaque2b) + + def test_fn_deps_preserved_through_transparent(self): + """fn='__deps__' is preserved when walking through transparent layers.""" + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 
1)) + inner = ModelEvaluationContext(model=m, context=ctx, fn="__deps__") + wrapped = TransparentModelEvaluationContext(model=LoggingEvaluator(), context=inner) + # Both should produce the same key, and it should differ from __call__ + assert cache_key(inner) == cache_key(wrapped) + call_inner = ModelEvaluationContext(model=m, context=ctx, fn="__call__") + assert cache_key(inner) != cache_key(call_inner) + class TestMemoryCacheEvaluator(TestCase): def test_basic(self): @@ -355,6 +424,74 @@ def test_decorator_volatile(self): self.assertGreater(out2, out1) self.assertEqual(len(captured.records), 2) + def test_cache_key_stable_across_evaluators(self): + """Cache keys should not change when wrapping with non-caching evaluators (e.g. LoggingEvaluator).""" + m1 = MyDateCallable(offset=1) + cache = MemoryCacheEvaluator() + ctx = DateContext(date=date(2022, 1, 1)) + + # First call: cache evaluator only + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + out1 = m1(ctx) + self.assertEqual(len(cache.cache), 1) + + # Second call: LoggingEvaluator + same cache evaluator + wrapped = combine_evaluators(LoggingEvaluator(), cache) + with FlowOptionsOverride(options={"evaluator": wrapped, "cacheable": True}): + out2 = m1(ctx) + # Should still be only 1 cache entry (same key, cache hit) + self.assertEqual(len(cache.cache), 1) + self.assertEqual(out1, out2) + + def test_cache_key_differs_with_nontransparent_evaluator(self): + """Cache keys should differ when a non-transparent evaluator is in the chain.""" + + class OpaqueEvaluator(EvaluatorBase): + """A dummy evaluator that is NOT transparent.""" + + def __call__(self, context: ModelEvaluationContext): + return context() + + m1 = MyDateCallable(offset=1) + cache = MemoryCacheEvaluator() + ctx = DateContext(date=date(2022, 1, 1)) + + # First call: cache evaluator only + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + m1(ctx) + self.assertEqual(len(cache.cache), 1) + + # Second call: OpaqueEvaluator + same cache evaluator + wrapped = combine_evaluators(OpaqueEvaluator(), cache) + with FlowOptionsOverride(options={"evaluator": wrapped, "cacheable": True}): + m1(ctx) + # OpaqueEvaluator is not transparent, so cache key should differ + self.assertEqual(len(cache.cache), 2) + + def test_cache_key_differs_with_fallback_opaque_child(self): + """FallbackEvaluator with opaque child should produce different cache key.""" + + class OpaqueEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + return context() + + m1 = MyDateCallable(offset=1) + cache = MemoryCacheEvaluator() + ctx = DateContext(date=date(2022, 1, 1)) + + # First call: cache evaluator only + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + m1(ctx) + self.assertEqual(len(cache.cache), 1) + + # Second call: FallbackEvaluator(OpaqueEvaluator) + cache + fallback = FallbackEvaluator(evaluators=[OpaqueEvaluator()]) + wrapped = combine_evaluators(fallback, cache) + with FlowOptionsOverride(options={"evaluator": wrapped, "cacheable": True}): + m1(ctx) + # FallbackEvaluator is not transparent, so cache key should differ + self.assertEqual(len(cache.cache), 2) + class TestGraphDeps(TestCase): def test_graph_deps_diamond(self): diff --git a/ccflow/tests/test_base_serialize.py b/ccflow/tests/test_base_serialize.py index fbec5c2..36c4782 100644 --- a/ccflow/tests/test_base_serialize.py +++ b/ccflow/tests/test_base_serialize.py @@ -259,12 +259,13 @@ def test_pickle_consistency(self): # (as it would normally in 
pydantic because of https://github.com/pydantic/pydantic/issues/11603) # This is generated on Linux/Python 3.11 - might need to have version specific values if it changes. target = ( - b"\x80\x04\x95\xdf\x00\x00\x00\x00\x00\x00\x00\x8c ccflow.tests.test_base_seri" + b"\x80\x04\x95\xf0\x00\x00\x00\x00\x00\x00\x00\x8c ccflow.tests.test_base_seri" b"alize\x94\x8c\x13MultiAttributeModel\x94\x93\x94)\x81\x94}\x94(\x8c\x08__" b"dict__\x94}\x94(\x8c\x01z\x94K\x01\x8c\x01y\x94\x8c\x04test\x94\x8c" b"\x01x\x94G@\t\x1e\xb8Q\xeb\x85\x1f\x8c\x01w\x94\x88u\x8c\x12__pydantic_extra" b"__\x94N\x8c\x17__pydantic_fields_set__\x94]\x94(h\x0bh\nh\x08h\x07e\x8c\x14" - b"__pydantic_private__\x94}\x94\x8c\x0e_registrations\x94]\x94sub." + b"__pydantic_private__\x94}\x94(\x8c\x0e_registrations\x94]\x94\x8c\x0c_model_" + b"token\x94Nuub." ) self.assertEqual(serialized, target) deserialized = pickle.loads(serialized) diff --git a/ccflow/tests/utils/test_tokenize.py b/ccflow/tests/utils/test_tokenize.py new file mode 100644 index 0000000..feb0b27 --- /dev/null +++ b/ccflow/tests/utils/test_tokenize.py @@ -0,0 +1,2179 @@ +"""Tests for the tokenization engine (ccflow.utils.tokenize) and BaseModel integration.""" + +import enum +import pickle +import re +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import date, datetime, time, timedelta +from pathlib import Path, PurePosixPath +from typing import Any, Dict, List, Literal, Optional, Union +from uuid import UUID + +import numpy as np +import pandas as pd +import pytest +from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, computed_field, model_validator + +from ccflow import BaseModel, ContextBase +from ccflow.utils.tokenize import ( + ASTSourceTokenizer, + BytecodeSourceTokenizer, + DefaultTokenizer, + OwnMethodCollector, + SourceTokenizer, + compute_behavior_token, + normalize_token, +) + +# --------------------------------------------------------------------------- +# Test models +# --------------------------------------------------------------------------- + + +class SimpleModel(BaseModel): + x: int = 1 + y: str = "hello" + + +class NestedModel(BaseModel): + child: SimpleModel = SimpleModel() + name: str = "parent" + + +class ExcludedFieldModel(BaseModel): + important: int = 42 + debug_info: str = Field(default="debug", exclude=True) + + +class FrozenModel(ContextBase): + a: int = 1 + b: str = "frozen" + + +class NoCacheModel(BaseModel): + model_config = ConfigDict(cache_token=False) + value: int = 0 + + +class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +class ModelWithCollections(BaseModel): + tags: List[str] = [] + metadata: dict = {} + coords: tuple = () + + +class ModelWithOptional(BaseModel): + name: str = "test" + extra: Optional[int] = None + + +class SubModel(SimpleModel): + z: float = 3.14 + + +# --------------------------------------------------------------------------- +# Module-level factory for behavior-hashing tests +# --------------------------------------------------------------------------- + +_AST_TOKENIZER = DefaultTokenizer.with_ast() + + +def _make_ast_model(name="DynModel", *, base=BaseModel, deps=None, **attrs): + """Build a BaseModel subclass with behavior hashing for testing.""" + cls_attrs = { + "x": 1, + "__annotations__": {"x": int}, + "__ccflow_tokenizer__": _AST_TOKENIZER, + } + if deps is not None: + cls_attrs["__ccflow_tokenizer_deps__"] = deps + cls_attrs.update(attrs) + return type(name, (base,), cls_attrs) + + +# 
--------------------------------------------------------------------------- +# normalize_token tests +# --------------------------------------------------------------------------- + + +class TestNormalizeToken: + @pytest.mark.parametrize( + "value,expected", + [ + (None, None), + (True, True), + (False, False), + (42, 42), + (3.14, 3.14), + ("hello", "hello"), + (b"data", b"data"), + ], + ) + def test_primitives(self, value, expected): + assert normalize_token(value) == expected + + @pytest.mark.parametrize( + "value,expected", + [ + (date(2024, 1, 15), ("date", "2024-01-15")), + (datetime(2024, 1, 15, 10, 30, 0), ("datetime", "2024-01-15T10:30:00")), + (time(10, 30, 0), ("time", "10:30:00")), + (timedelta(hours=1, minutes=30), ("timedelta", 5400.0)), + ], + ) + def test_datetime_types(self, value, expected): + assert normalize_token(value) == expected + + @pytest.mark.parametrize( + "value,expected", + [ + ((1, "a", True), ("tuple", (1, "a", True))), + ([1, 2, 3], ("list", (1, 2, 3))), + ({3, 1, 2}, ("set", (1, 2, 3))), + (frozenset({3, 1, 2}), ("frozenset", (1, 2, 3))), + ({"b": 2, "a": 1}, ("dict", (("a", 1), ("b", 2)))), + ], + ) + def test_collections(self, value, expected): + assert normalize_token(value) == expected + + def test_uuid(self): + u = UUID("12345678-1234-5678-1234-567812345678") + assert normalize_token(u) == ("uuid", "12345678-1234-5678-1234-567812345678") + + def test_path(self): + p = Path("/tmp/test.txt") + assert normalize_token(p) == ("path", "/tmp/test.txt") + + def test_pure_path(self): + p = PurePosixPath("/tmp/test.txt") + assert normalize_token(p) == ("path", "/tmp/test.txt") + + def test_enum(self): + result = normalize_token(Color.RED) + assert result == ("enum", "Color", "RED") + + def test_nested_collections(self): + data = {"key": [1, (2, 3)]} + result = normalize_token(data) + assert result == ("dict", (("key", ("list", (1, ("tuple", (2, 3))))),)) + + def test_numpy_ndarray(self): + arr = np.array([1, 2, 3], dtype=np.int64) + result = normalize_token(arr) + assert result[0] == "ndarray" + assert result[1] == "int64" + assert result[2] == (3,) + arr2 = np.array([1, 2, 3], dtype=np.int64) + assert normalize_token(arr) == normalize_token(arr2) + + def test_numpy_different_data(self): + arr1 = np.array([1, 2, 3]) + arr2 = np.array([1, 2, 4]) + assert normalize_token(arr1) != normalize_token(arr2) + + def test_numpy_scalar(self): + s = np.int64(42) + result = normalize_token(s) + assert result == ("np_scalar", "int64", 42) + + def test_pandas_dataframe(self): + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = normalize_token(df) + assert result is not None + # Same data → same token + df2 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + assert normalize_token(df) == normalize_token(df2) + + def test_pandas_different_data(self): + df1 = pd.DataFrame({"a": [1, 2]}) + df2 = pd.DataFrame({"a": [1, 3]}) + assert normalize_token(df1) != normalize_token(df2) + + def test_pandas_series(self): + s = pd.Series([1, 2, 3], name="test") + result = normalize_token(s) + assert result is not None + + def test_pandas_timestamp(self): + ts = pd.Timestamp("2024-01-15") + result = normalize_token(ts) + assert result == ("pd_timestamp", ts.isoformat()) + + def test_function(self): + def my_func(x): + return x + 1 + + result = normalize_token(my_func) + assert result[0] == "func" + assert "my_func" in result[1] + assert len(result) == 3 # (func, qualname, hash) + + def test_function_deterministic(self): + def my_func(x): + return x + 1 + + r1 = normalize_token(my_func) + r2 = 
normalize_token(my_func) + assert r1 == r2 + + def test_type(self): + result = normalize_token(int) + assert result == ("type", "builtins.int") + + def test_custom_type(self): + result = normalize_token(SimpleModel) + assert result[0] == "type" + assert "SimpleModel" in result[1] + + def test_pydantic_basemodel(self): + class PlainPydantic(PydanticBaseModel): + x: int = 1 + + obj = PlainPydantic(x=5) + result = normalize_token(obj) + assert result[0] == "pydantic" + assert "PlainPydantic" in result[1] + + def test_ccflow_basemodel(self): + obj = SimpleModel(x=5, y="world") + result = normalize_token(obj) + # Delegates to tokenizer — same output as normalize() + assert result == obj.__ccflow_tokenizer__.normalize(obj) + assert "SimpleModel" in result[0] + + def test_custom_hook(self): + class MyObj: + def __ccflow_tokenize__(self): + return ("custom", 42) + + obj = MyObj() + assert normalize_token(obj) == ("custom", 42) + + +# --------------------------------------------------------------------------- +# DefaultTokenizer tests +# --------------------------------------------------------------------------- + + +class TestDefaultTokenizer: + def test_normalize_basic(self): + t = DefaultTokenizer() + model = SimpleModel(x=5, y="world") + result = t.normalize(model) + assert isinstance(result, tuple) + assert len(result) == 3 # (type_path, behavior, fields) + assert result[1] is None # No behavior token by default + + def test_tokenize_deterministic(self): + t = DefaultTokenizer() + model = SimpleModel(x=5, y="world") + token1 = t.tokenize(model) + token2 = t.tokenize(model) + assert token1 == token2 + + def test_tokenize_different_values(self): + t = DefaultTokenizer() + m1 = SimpleModel(x=1, y="a") + m2 = SimpleModel(x=2, y="a") + assert t.tokenize(m1) != t.tokenize(m2) + + def test_tokenize_produces_sha256(self): + t = DefaultTokenizer() + model = SimpleModel(x=1) + token = t.tokenize(model) + assert len(token) == 64 # SHA256 hex + + def test_excluded_fields_not_in_normalize(self): + t = DefaultTokenizer() + m = ExcludedFieldModel(important=10, debug_info="x") + result = t.normalize(m) + field_names = [f[0] for f in result[2]] + assert "important" in field_names + assert "debug_info" not in field_names + + def test_excluded_fields_dont_affect_token(self): + t = DefaultTokenizer() + m1 = ExcludedFieldModel(important=10, debug_info="debug1") + m2 = ExcludedFieldModel(important=10, debug_info="debug2") + assert t.tokenize(m1) == t.tokenize(m2) + + def test_nested_model(self): + t = DefaultTokenizer() + m = NestedModel(child=SimpleModel(x=5), name="test") + token = t.tokenize(m) + assert isinstance(token, str) + + def test_nested_different_child(self): + t = DefaultTokenizer() + m1 = NestedModel(child=SimpleModel(x=1), name="test") + m2 = NestedModel(child=SimpleModel(x=2), name="test") + assert t.tokenize(m1) != t.tokenize(m2) + + def test_cycle_detection(self): + """Cycle detection prevents infinite recursion.""" + t = DefaultTokenizer() + # Create a model — since ccflow BaseModel doesn't allow arbitrary attrs, + # we test cycle detection via the normalize method directly + m = SimpleModel(x=1) + visited = {id(m)} + result = t.normalize(m, _visited=visited) + assert result[0] == "__cycle__" + + +# --------------------------------------------------------------------------- +# BaseModel.model_token integration tests +# --------------------------------------------------------------------------- + + +class TestModelToken: + def test_basic(self): + m = SimpleModel(x=1, y="hello") + token = 
m.model_token + assert isinstance(token, str) + assert len(token) == 64 # SHA256 hex + + def test_deterministic(self): + m = SimpleModel(x=1, y="hello") + assert m.model_token == m.model_token + + def test_same_values_same_token(self): + m1 = SimpleModel(x=1, y="hello") + m2 = SimpleModel(x=1, y="hello") + assert m1.model_token == m2.model_token + + def test_different_values_different_token(self): + m1 = SimpleModel(x=1, y="hello") + m2 = SimpleModel(x=2, y="hello") + assert m1.model_token != m2.model_token + + def test_different_types_different_token(self): + """Parent and subclass with same field values get different tokens.""" + m1 = SimpleModel(x=1, y="hello") + m2 = SubModel(x=1, y="hello") + assert m1.model_token != m2.model_token + + def test_mutable_no_cache(self): + """Mutable models do not cache tokens by default.""" + m = SimpleModel(x=1, y="hello") + token1 = m.model_token + assert m._model_token is None # Not cached (mutable) + assert m.model_token == token1 # Still deterministic + + def test_mutable_reflects_mutation(self): + """Mutable model token reflects field assignment immediately.""" + m = SimpleModel(x=1, y="hello") + token1 = m.model_token + m.x = 2 + token2 = m.model_token + assert token2 != token1 + + def test_no_cache_mode(self): + """With cache_token=False, token is always computed fresh.""" + m = NoCacheModel(value=42) + token1 = m.model_token + assert m._model_token is None # Never cached + token2 = m.model_token + assert token1 == token2 # Still deterministic + + def test_frozen_model(self): + """Frozen models cache the token.""" + m = FrozenModel(a=1, b="test") + token = m.model_token + assert m._model_token is not None + assert m.model_token == token + + def test_excluded_field_no_effect(self): + """Fields with exclude=True don't affect the token.""" + m1 = ExcludedFieldModel(important=10, debug_info="x") + m2 = ExcludedFieldModel(important=10, debug_info="y") + assert m1.model_token == m2.model_token + + def test_nested_model_token(self): + m = NestedModel(child=SimpleModel(x=5), name="parent") + assert isinstance(m.model_token, str) + + def test_optional_none_vs_value(self): + m1 = ModelWithOptional(name="test", extra=None) + m2 = ModelWithOptional(name="test", extra=42) + assert m1.model_token != m2.model_token + + def test_collections_in_model(self): + m = ModelWithCollections(tags=["a", "b"], metadata={"k": "v"}, coords=(1, 2)) + assert isinstance(m.model_token, str) + + def test_model_copy_gets_fresh_token(self): + """model_copy(update=...) 
produces correct (different) token.""" + m1 = SimpleModel(x=1, y="hello") + _ = m1.model_token + m2 = m1.model_copy(update={"x": 2}) + assert m1.model_token != m2.model_token + + def test_custom_tokenizer(self): + """Models can use a custom tokenizer via __ccflow_tokenizer__.""" + + class CustomModel(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer.with_bytecode() + value: int = 0 + + m = CustomModel(value=42) + assert len(m.model_token) == 64 # SHA256 + + def test_pickle_preserves_token_cache(self): + """Pickling a model preserves the token cache.""" + m = SimpleModel(x=1, y="hello") + _ = m.model_token + m2 = pickle.loads(pickle.dumps(m)) + assert m2.model_token == m.model_token + + +# --------------------------------------------------------------------------- +# Component-level tests: SourceTokenizer and FunctionCollector +# --------------------------------------------------------------------------- + + +class TestASTSourceTokenizer: + def test_returns_hex_digest(self): + def f(x): + return x + 1 + + result = ASTSourceTokenizer().tokenize(f) + assert result is not None + assert isinstance(result, str) + assert len(result) == 64 # sha256 + + def test_deterministic(self): + def f(x): + return x + 1 + + t = ASTSourceTokenizer() + assert t.tokenize(f) == t.tokenize(f) + + def test_different_bodies_differ(self): + def f1(x): + return x + 1 + + def f2(x): + return x * 2 + + t = ASTSourceTokenizer() + assert t.tokenize(f1) != t.tokenize(f2) + + def test_docstring_stripped(self): + def f(x): + """A docstring.""" + return x + 1 + + tok_with = ASTSourceTokenizer().tokenize(f) + + def f(x): # noqa: F811 + return x + 1 + + tok_without = ASTSourceTokenizer().tokenize(f) + assert tok_with == tok_without + + def test_variable_rename_changes_hash(self): + """AST preserves variable names, so renaming changes the hash.""" + + def f1(x): + return x + 1 + + def f2(y): + return y + 1 + + t = ASTSourceTokenizer() + assert t.tokenize(f1) != t.tokenize(f2) + + def test_comment_changes_ignored(self): + """Comments are stripped by AST parsing.""" + from ccflow.utils.tokenize import _normalize_source_ast + + s1 = "def f(x):\n # comment\n return x + 1" + s2 = "def f(x):\n return x + 1" + assert _normalize_source_ast(s1) == _normalize_source_ast(s2) + + def test_whitespace_changes_ignored(self): + from ccflow.utils.tokenize import _normalize_source_ast + + s1 = "def f(x):\n return x+1" + s2 = "def f( x ):\n return x + 1" + assert _normalize_source_ast(s1) == _normalize_source_ast(s2) + + def test_fallback_to_bytecode_when_no_source(self): + """Built-in functions have no source; should fall back to bytecode or return None.""" + t = ASTSourceTokenizer() + # Built-in like len has no source and no __code__ + result = t.tokenize(len) + assert result is None + + def test_lambda(self): + def f(x): + return x + 1 + + t = ASTSourceTokenizer() + assert t.tokenize(f) is not None + + def test_classmethod_unwrapped(self): + """Classmethods need __func__ unwrapping before tokenizing.""" + + class C: + @classmethod + def m(cls): + return 1 + + t = ASTSourceTokenizer() + # __func__ should be unwrapped by the collector, not the tokenizer + assert t.tokenize(C.m.__func__) is not None + + def test_staticmethod_unwrapped(self): + class C: + @staticmethod + def m(): + return 1 + + t = ASTSourceTokenizer() + assert t.tokenize(C.m) is not None + + +class TestBytecodeSourceTokenizer: + def test_returns_hex_digest(self): + def f(x): + return x + 1 + + result = BytecodeSourceTokenizer().tokenize(f) + assert result is not None + assert 
len(result) == 64 # sha256 + + def test_deterministic(self): + def f(x): + return x + 1 + + t = BytecodeSourceTokenizer() + assert t.tokenize(f) == t.tokenize(f) + + def test_different_bodies_differ(self): + def f1(x): + return x + 1 + + def f2(x): + return x * 2 + + t = BytecodeSourceTokenizer() + assert t.tokenize(f1) != t.tokenize(f2) + + def test_docstring_stripped(self): + def f(x): + """A docstring.""" + return x + 1 + + tok_with = BytecodeSourceTokenizer().tokenize(f) + + def f(x): # noqa: F811 + return x + 1 + + tok_without = BytecodeSourceTokenizer().tokenize(f) + assert tok_with == tok_without + + def test_variable_rename_same_hash(self): + """Bytecode is immune to variable renames (names in co_varnames, not co_code).""" + + def f1(x): + return x + 1 + + def f2(y): + return y + 1 + + t = BytecodeSourceTokenizer() + assert t.tokenize(f1) == t.tokenize(f2) + + def test_no_code_returns_none(self): + """Objects without __code__ return None.""" + t = BytecodeSourceTokenizer() + assert t.tokenize(len) is None + + def test_lambda(self): + def f(x): + return x + 1 + + t = BytecodeSourceTokenizer() + assert t.tokenize(f) is not None + + def test_classmethod_vs_regular_same_hash(self): + """Bytecode doesn't distinguish classmethod from regular method.""" + + class C1: + @classmethod + def m(cls): + return 1 + + class C2: + def m(self): + return 1 + + t = BytecodeSourceTokenizer() + assert t.tokenize(C1.m.__func__) == t.tokenize(C2.m) + + +class TestOwnMethodCollector: + def test_collects_regular_methods(self): + class C: + def foo(self): + pass + + def bar(self): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + assert "foo" in names + assert "bar" in names + + def test_sorted_by_name(self): + class C: + def z(self): + pass + + def a(self): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + # Should be sorted + assert names.index("a") < names.index("z") + + def test_collects_classmethod(self): + class C: + @classmethod + def m(cls): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + assert "m" in names + # Should be unwrapped + func = dict(methods)["m"] + assert callable(func) + assert not isinstance(func, classmethod) + + def test_collects_staticmethod(self): + class C: + @staticmethod + def m(): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + assert "m" in names + func = dict(methods)["m"] + assert callable(func) + + def test_skips_non_callable(self): + class C: + x = 42 + + def m(self): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + assert "x" not in names + assert "m" in names + + def test_does_not_collect_inherited(self): + class Parent: + def parent_method(self): + pass + + class Child(Parent): + def child_method(self): + pass + + methods = OwnMethodCollector().collect(Child) + names = [name for name, _ in methods] + assert "child_method" in names + assert "parent_method" not in names + + def test_empty_class(self): + class Empty: + pass + + methods = OwnMethodCollector().collect(Empty) + # May have __init__ or other dunders from object, but no user methods + # The key thing is it doesn't crash + assert isinstance(methods, list) + + def test_deps_included(self): + def helper(): + return 42 + + class C: + __ccflow_tokenizer_deps__ = [helper] + + def m(self): + pass + + methods = OwnMethodCollector().collect(C) + names = [name for name, _ in methods] + assert 
any("__dep__" in n for n in names) + + def test_deps_not_inherited(self): + def helper(): + return 42 + + class Parent: + __ccflow_tokenizer_deps__ = [helper] + + class Child(Parent): + pass + + methods = OwnMethodCollector().collect(Child) + names = [name for name, _ in methods] + assert not any("__dep__" in n for n in names) + + +class TestSourceTokenizerContrast: + """Tests documenting known differences between AST and bytecode tokenizers.""" + + def test_variable_rename_ast_differs_bytecode_same(self): + """AST is sensitive to renames; bytecode is not.""" + + def f1(x): + return x + 1 + + def f2(y): + return y + 1 + + assert ASTSourceTokenizer().tokenize(f1) != ASTSourceTokenizer().tokenize(f2) + assert BytecodeSourceTokenizer().tokenize(f1) == BytecodeSourceTokenizer().tokenize(f2) + + def test_both_strip_docstrings(self): + def f(x): + """doc""" + return x + 1 + + tok_ast_with = ASTSourceTokenizer().tokenize(f) + tok_bc_with = BytecodeSourceTokenizer().tokenize(f) + + def f(x): # noqa: F811 + return x + 1 + + tok_ast_without = ASTSourceTokenizer().tokenize(f) + tok_bc_without = BytecodeSourceTokenizer().tokenize(f) + + assert tok_ast_with == tok_ast_without + assert tok_bc_with == tok_bc_without + + +# --------------------------------------------------------------------------- +# Behavior token tests +# --------------------------------------------------------------------------- + + +class TestBehaviorToken: + @pytest.mark.parametrize( + "cls_factory,expect_none", + [ + pytest.param(lambda: type("C", (), {"__call__": lambda s, x: x + 1}), False, id="with-call"), + pytest.param(lambda: type("C", (), {}), True, id="without-call"), + pytest.param(lambda: type("C", (), {"__call__": classmethod(lambda c, x: x + 1)}), False, id="classmethod-call"), + pytest.param(lambda: type("C", (), {"__call__": staticmethod(lambda x: x + 1)}), False, id="staticmethod-call"), + ], + ) + def test_behavior_token_presence(self, cls_factory, expect_none): + token = compute_behavior_token(cls_factory()) + assert (token is None) == expect_none + + def test_deterministic(self): + class MyCallable: + def __call__(self, x): + return x + 1 + + assert compute_behavior_token(MyCallable) == compute_behavior_token(MyCallable) + + def test_cached_on_class(self): + class MyCallable: + def __call__(self, x): + return x * 2 + + token = compute_behavior_token(MyCallable) + assert hasattr(MyCallable, "__ccflow_behavior_token__") + assert token in MyCallable.__ccflow_behavior_token__.values() + + def test_different_implementations(self): + class Call1: + def __call__(self, x): + return x + 1 + + class Call2: + def __call__(self, x): + return x * 2 + + assert compute_behavior_token(Call1) != compute_behavior_token(Call2) + + def test_docstring_ignored_with_ast(self): + """AST-normalized hashing should ignore docstrings.""" + + class WithDoc: + def __call__(self, x): + """This is a docstring.""" + return x + 1 + + class WithoutDoc: + def __call__(self, x): + return x + 1 + + assert compute_behavior_token(WithDoc) == compute_behavior_token(WithoutDoc) + + def test_behavior_token_not_inherited_from_parent(self): + class Parent: + def __call__(self, x): + return x + 1 + + class Child(Parent): + pass + + assert compute_behavior_token(Parent) is not None + assert compute_behavior_token(Child) is None + + def test_behavior_token_deterministic(self): + class MyCallable: + def __call__(self, x): + return x + 1 + + t1 = compute_behavior_token(MyCallable) + t2 = compute_behavior_token(MyCallable) + assert t1 == t2 + assert len(t1) == 64 
# sha256 + + def test_classmethod_vs_regular_differ(self): + class AsClassmethod: + @classmethod + def __call__(cls, x): + return x + 1 + + class AsRegular: + def __call__(self, x): + return x + 1 + + t1 = compute_behavior_token(AsClassmethod, source_tokenizer=ASTSourceTokenizer()) + t2 = compute_behavior_token(AsRegular, source_tokenizer=ASTSourceTokenizer()) + assert t1 != t2 + + def test_include_behavior_in_tokenizer(self): + M1 = _make_ast_model(__call__=lambda self: self.x + 1) + M2 = _make_ast_model(__call__=lambda self: self.x * 2) + assert M1(x=1).model_token != M2(x=1).model_token + + +# --------------------------------------------------------------------------- +# Composition tests (replaces TestTokenizerConfig) +# --------------------------------------------------------------------------- + + +class TestComposition: + def test_default_has_no_collector_or_source_tokenizer(self): + t = DefaultTokenizer() + assert t.collector is None + assert t.source_tokenizer is None + + def test_with_ast_creates_correct_components(self): + t = DefaultTokenizer.with_ast() + assert isinstance(t.collector, OwnMethodCollector) + assert isinstance(t.source_tokenizer, ASTSourceTokenizer) + + def test_with_bytecode_creates_correct_components(self): + t = DefaultTokenizer.with_bytecode() + assert isinstance(t.collector, OwnMethodCollector) + assert isinstance(t.source_tokenizer, BytecodeSourceTokenizer) + + def test_custom_composition(self): + collector = OwnMethodCollector() + source_tokenizer = ASTSourceTokenizer() + t = DefaultTokenizer(collector=collector, source_tokenizer=source_tokenizer) + assert t.collector is collector + assert t.source_tokenizer is source_tokenizer + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_model(self): + class Empty(BaseModel): + pass + + assert isinstance(Empty().model_token, str) + + def test_model_with_none_values(self): + assert isinstance(ModelWithOptional(name="test", extra=None).model_token, str) + + def test_model_with_numpy_field(self): + class NumpyModel(BaseModel): + data: Any = None # type: ignore + + assert isinstance(NumpyModel(data=np.array([1, 2, 3])).model_token, str) + + def test_model_with_dataframe_field(self): + class DFModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + data: Any = None + + assert isinstance(DFModel(data=pd.DataFrame({"a": [1, 2]})).model_token, str) + + def test_deeply_nested(self): + class Node(BaseModel): + value: int = 0 + child: Optional["Node"] = None + + Node.model_rebuild() + current = Node(value=50) + for i in range(49, 0, -1): + current = Node(value=i, child=current) + assert isinstance(current.model_token, str) + + def test_model_with_enum_field(self): + class EnumModel(BaseModel): + color: Color = Color.RED + + m1 = EnumModel(color=Color.GREEN) + m2 = EnumModel(color=Color.RED) + assert isinstance(m1.model_token, str) + assert m1.model_token != m2.model_token + + +# --------------------------------------------------------------------------- +# Diamond pattern tests +# --------------------------------------------------------------------------- + + +class TestDiamondPatterns: + """Tests for shared child references (diamond DAGs) in model graphs.""" + + def test_shared_child_alias_is_not_treated_as_cycle(self): + """Parent has child1 and child2 pointing to the SAME SimpleModel instance. 
+ The tokenizer should treat this as a diamond, not a cycle, and produce + the same token as a parent with two distinct-but-equal children.""" + + class TwoChildren(BaseModel): + child1: SimpleModel = SimpleModel() + child2: SimpleModel = SimpleModel() + + shared = SimpleModel(x=42, y="shared") + parent_shared = TwoChildren(child1=shared, child2=shared) + parent_distinct = TwoChildren( + child1=SimpleModel(x=42, y="shared"), + child2=SimpleModel(x=42, y="shared"), + ) + # Diamond and distinct-but-equal should produce the same token + assert parent_shared.model_token == parent_distinct.model_token + + def test_shared_mutable_child_two_paths_deterministic(self): + """Same shared child referenced twice; repeated tokenization gives identical result.""" + + class TwoChildren(BaseModel): + child1: SimpleModel = SimpleModel() + child2: SimpleModel = SimpleModel() + + shared = SimpleModel(x=7, y="s") + parent = TwoChildren(child1=shared, child2=shared) + t1 = parent.model_token + t2 = parent.model_token + assert t1 == t2 + + def test_shared_frozen_child_same_parent_token_with_and_without_warmed_cache(self): + """Same frozen child reused in two parents; compare parent token before and + after child _model_token cache is 'warmed'. Token should be stable.""" + + class TwoFrozen(BaseModel): + a: FrozenModel = FrozenModel() + b: FrozenModel = FrozenModel() + + child = FrozenModel(a=99, b="warm") + # Parent token BEFORE child cache is warmed + parent1 = TwoFrozen(a=child, b=child) + token_cold = parent1.model_token + + _ = child.model_token + assert child._model_token is not None + + parent2 = TwoFrozen(a=child, b=child) + token_warm = parent2.model_token + assert parent2.model_token == token_warm + # Frozen children always use ("__child__", value.model_token), so + # cold and warm parent tokens are now consistent. 
+ assert token_cold == token_warm + + +# --------------------------------------------------------------------------- +# Cycle tests +# --------------------------------------------------------------------------- + + +class SelfRefModel(BaseModel): + """Model that can reference itself.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + value: int = 0 + child: Optional["SelfRefModel"] = None + + +SelfRefModel.model_rebuild() + + +class TestCycles: + """Tests for cycle detection in model graphs.""" + + def test_self_referential_model_token_is_deterministic(self): + """node.child = node; tokenization should not recurse forever.""" + node = SelfRefModel(value=1) + # Bypass pydantic's validate_assignment (it rejects self-referential cycles) + node.__dict__["child"] = node + token = node.model_token + assert isinstance(token, str) + assert node.model_token == token + + def test_indirect_cycle_a_to_b_to_a(self): + """Two-node cycle A -> B -> A.""" + a = SelfRefModel(value=1) + b = SelfRefModel(value=2) + a.__dict__["child"] = b + b.__dict__["child"] = a + token_a = a.model_token + assert isinstance(token_a, str) + + def test_cycle_marker_differs_from_none(self): + """Token of child=None must differ from child=self (cycle).""" + acyclic = SelfRefModel(value=1, child=None) + cyclic = SelfRefModel(value=1) + cyclic.__dict__["child"] = cyclic + + token_acyclic = acyclic.model_token + token_cyclic = cyclic.model_token + assert token_acyclic != token_cyclic + + def test_cycle_inside_list_field(self): + """Model contains items=[self]; should handle gracefully.""" + + class ListModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + value: int = 0 + items: List[Any] = [] + + m = ListModel(value=1) + m.items = [m] + token = m.model_token + assert isinstance(token, str) + + def test_cycle_inside_dict_field(self): + """Model contains refs={'self': self}; should handle gracefully.""" + + class DictModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + value: int = 0 + refs: Dict[str, Any] = {} + + m = DictModel(value=1) + m.refs = {"self": m} + token = m.model_token + assert isinstance(token, str) + + +# --------------------------------------------------------------------------- +# Unpicklable / unstable fallback objects +# --------------------------------------------------------------------------- + + +class TestUnpicklableFallback: + """Tests for graceful handling of unpicklable or unstable objects.""" + + def test_lock_field_tokenization(self): + """Model with threading.Lock field — should either produce a stable + token or raise a clear TypeError.""" + + class LockModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + lock: Any = None + + m = LockModel(lock=threading.Lock()) + # Locks are cloudpickleable in some versions — so either we get a + # token (possibly unstable) or a clear TypeError + try: + token = m.model_token + assert isinstance(token, str) + except TypeError as e: + assert "tokenize" in str(e).lower() or "Cannot" in str(e) + + def test_generator_field_does_not_silently_tokenize(self): + """Generators encode execution state; tokenization should be handled + carefully or rejected.""" + + def gen(): + yield 1 + + class GenModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + g: Any = None + + m = GenModel(g=gen()) + try: + token = m.model_token + assert isinstance(token, str) + except TypeError: + pass # Acceptable: clean rejection + + def test_compiled_regex_same_pattern_same_token(self): + 
"""re.compile with same pattern/flags should produce the same token.""" + r1 = re.compile("abc", re.IGNORECASE) + r2 = re.compile("abc", re.IGNORECASE) + t1 = normalize_token(r1) + t2 = normalize_token(r2) + assert t1 == t2 + + def test_cloudpickle_failure_becomes_typeerror(self): + """Object whose pickling deliberately raises should produce TypeError.""" + + class Unpicklable: + def __reduce__(self): + raise RuntimeError("cannot pickle") + + with pytest.raises(TypeError, match="Cannot tokenize"): + normalize_token(Unpicklable()) + + +# --------------------------------------------------------------------------- +# Graceful failure +# --------------------------------------------------------------------------- + + +class TestGracefulFailure: + def test_untokenizable_type_raises(self): + """Types that can't be cloudpickled raise TypeError.""" + + class Opaque: + def __reduce__(self): + raise TypeError("nope") + + with pytest.raises(TypeError): + normalize_token(Opaque()) + + +# --------------------------------------------------------------------------- +# Model inheritance edge cases +# --------------------------------------------------------------------------- + + +class TestModelInheritance: + def test_multiple_inheritance_tokenizer_resolution(self): + """When two bases define different tokenizers, child uses MRO resolution.""" + + class Base1(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer() + x: int = 1 + + class Base2(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer.with_bytecode() + y: int = 2 + + class Child(Base1, Base2): + z: int = 3 + + m = Child() + # MRO: Child -> Base1 -> Base2. Should use Base1's tokenizer (data-only) + assert m.__ccflow_tokenizer__ is Base1.__ccflow_tokenizer__ + assert len(m.model_token) == 64 + + +# --------------------------------------------------------------------------- +# Pydantic-specific edge cases +# --------------------------------------------------------------------------- + + +class TestPydanticEdgeCases: + def test_model_construct_token_works_without_validation(self): + """model_construct() skips validators. 
Token should still work.""" + m = SimpleModel.model_construct(x=10, y="constructed") + token = m.model_token + assert isinstance(token, str) + m2 = SimpleModel(x=10, y="constructed") + assert token == m2.model_token + + def test_validators_normalize_inputs_before_tokening(self): + """Validator transforms raw input; token reflects validated state.""" + + class NormalizedModel(BaseModel): + name: str = "" + + @model_validator(mode="after") + def _normalize_name(self): + object.__setattr__(self, "name", self.name.strip().lower()) + # Also clear token cache since we modified a field + if self.__pydantic_private__ is not None: + self.__pydantic_private__["_model_token"] = None + return self + + m1 = NormalizedModel(name=" HELLO ") + m2 = NormalizedModel(name="hello") + assert m1.model_token == m2.model_token + + def test_computed_field_does_not_affect_token(self): + """@computed_field should not be included in the token by default, + since it's derived from other fields.""" + + class WithComputed(BaseModel): + x: int = 1 + y: int = 2 + + @computed_field + @property + def total(self) -> int: + return self.x + self.y + + m1 = WithComputed(x=1, y=2) + m2 = WithComputed(x=1, y=2) + assert m1.model_token == m2.model_token + assert "total" not in type(m1).model_fields + + def test_discriminated_union_variant_changes_token(self): + """Different union branches with same outer model → different tokens.""" + + class Cat(BaseModel): + kind: Literal["cat"] = "cat" + lives: int = 9 + + class Dog(BaseModel): + kind: Literal["dog"] = "dog" + lives: int = 1 + + class Owner(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + pet: Union[Cat, Dog] = Cat() + + owner_cat = Owner(pet=Cat(lives=9)) + owner_dog = Owner(pet=Dog(lives=9)) + assert owner_cat.model_token != owner_dog.model_token + + def test_model_validate_from_dict_matches_constructor_token(self): + """model_validate({...}) vs normal construction → same token.""" + m1 = SimpleModel(x=42, y="test") + m2 = SimpleModel.model_validate({"x": 42, "y": "test"}) + assert m1.model_token == m2.model_token + + +# --------------------------------------------------------------------------- +# Float / numeric edge cases +# --------------------------------------------------------------------------- + + +class TestFloatEdgeCases: + def test_nan_token_is_deterministic(self): + """Two separate NaN values should produce the same token.""" + nan1 = float("nan") + assert normalize_token(nan1) is nan1 # Primitive passthrough + + # Even though nan != nan, repr is the same → same hash + class NanModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + v: Any = None + + m1 = NanModel(v=float("nan")) + m2 = NanModel(v=float("nan")) + assert m1.model_token == m2.model_token + + def test_negative_zero_behavior_is_pinned(self): + """-0.0 vs 0.0: document which way the tokenizer goes.""" + t_pos = normalize_token(0.0) + t_neg = normalize_token(-0.0) + # repr(0.0)='0.0', repr(-0.0)='-0.0' → different + # This is a design choice: bit-pattern identity, not numeric equality + # Just pin the behavior so changes are intentional + if repr(0.0) != repr(-0.0): + assert t_pos != t_neg or t_pos == t_neg # Always true; real check below + + # The tokens in a model: + class ZeroModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + v: Any = None + + m_pos = ZeroModel(v=0.0) + m_neg = ZeroModel(v=-0.0) + # Pin: these should differ (repr-based tokenization) + assert m_pos.model_token != m_neg.model_token + + def 
test_inf_and_negative_inf_distinct(self): + """inf vs -inf should produce different tokens.""" + assert normalize_token(float("inf")) != normalize_token(float("-inf")) + + def test_complex_number_tokenizes(self): + """Complex numbers should be tokenizable (likely via cloudpickle fallback).""" + c = 1 + 2j + token = normalize_token(c) + assert normalize_token(1 + 2j) == token + assert normalize_token(3 + 4j) != token + + +# --------------------------------------------------------------------------- +# Container edge cases +# --------------------------------------------------------------------------- + + +class TestContainerEdgeCases: + def test_dict_with_mixed_key_types(self): + """Dict like {1: 'a', '1': 'b'} — canonicalization must handle + heterogeneous keys without crashing.""" + d = {1: "a", "1": "b"} + result = normalize_token(d) + assert result[0] == "dict" + assert len(result[1]) == 2 + + def test_list_of_models_normalizes_structurally(self): + """List[ChildModel] field should normalize models structurally.""" + + class Parent(BaseModel): + children: List[SimpleModel] = [] + + m1 = Parent(children=[SimpleModel(x=1), SimpleModel(x=2)]) + m2 = Parent(children=[SimpleModel(x=1), SimpleModel(x=2)]) + assert m1.model_token == m2.model_token + + m3 = Parent(children=[SimpleModel(x=1), SimpleModel(x=3)]) + assert m1.model_token != m3.model_token + + def test_dict_of_models_normalizes_structurally(self): + """Dict[str, ChildModel] field should normalize models structurally.""" + + class Parent(BaseModel): + children: Dict[str, SimpleModel] = {} + + m1 = Parent(children={"a": SimpleModel(x=1), "b": SimpleModel(x=2)}) + m2 = Parent(children={"a": SimpleModel(x=1), "b": SimpleModel(x=2)}) + assert m1.model_token == m2.model_token + + def test_nested_containers_with_none_values(self): + """Mixed nested None in lists/dicts should tokenize cleanly.""" + + class MixedModel(BaseModel): + data: Any = None + + m1 = MixedModel(data=[None, {"key": None}, [None, 1]]) + m2 = MixedModel(data=[None, {"key": None}, [None, 1]]) + assert m1.model_token == m2.model_token + + m3 = MixedModel(data=[None, {"key": 1}, [None, 1]]) + assert m1.model_token != m3.model_token + + +# --------------------------------------------------------------------------- +# Merkle tree correctness +# --------------------------------------------------------------------------- + + +class TestMerkleCorrectness: + def test_frozen_child_merkle_shortcut_matches_full_normalize(self): + """Parent token with cached frozen child vs parent token when child + is fully re-normalized. The Merkle shortcut should produce a + consistent result.""" + + class Parent(BaseModel): + child: FrozenModel = FrozenModel() + name: str = "p" + + child = FrozenModel(a=42, b="merkle") + + # Token WITHOUT Merkle shortcut (child cache not warmed) + p1 = Parent(child=child, name="test") + token_full = p1.model_token + + # Warm the child cache + _ = child.model_token + assert child._model_token is not None + + # Token WITH Merkle shortcut (child cache IS warmed) + p2 = Parent(child=child, name="test") + # Clear parent cache to force recomputation + p2.__pydantic_private__["_model_token"] = None + token_merkle = p2.model_token + + assert isinstance(token_full, str) + assert isinstance(token_merkle, str) + # Frozen children always use ("__child__", value.model_token), so + # full normalize and Merkle shortcut now produce the same result. 
+ assert token_full == token_merkle + + def test_frozen_child_cached_token_is_reused(self): + """Frozen child's cached _model_token is used as a Merkle leaf, + producing the same result as a full normalize.""" + + class Parent(BaseModel): + child: FrozenModel = FrozenModel() + name: str = "p" + + child = FrozenModel(a=1, b="test") + # Warm child cache + _ = child.model_token + assert child._model_token is not None + + p = Parent(child=child, name="test") + token_cached = p.model_token + + # Clear child cache and tokenize again + child.__pydantic_private__["_model_token"] = None + p2 = Parent(child=child, name="test") + token_fresh = p2.model_token + + assert token_cached == token_fresh + + def test_nonfrozen_child_not_cached(self): + """Non-frozen children don't cache tokens, so they always reflect current state.""" + + class MutableChild(BaseModel): + value: int = 0 + + class Parent(BaseModel): + child: MutableChild = MutableChild() + + child = MutableChild(value=1) + _ = child.model_token + # Mutable models don't cache by default + assert child._model_token is None + + child.value = 2 + p = Parent(child=child) + token1 = p.model_token + + p2 = Parent(child=MutableChild(value=2)) + token2 = p2.model_token + + assert token1 == token2 + + +# --------------------------------------------------------------------------- +# Concurrency +# --------------------------------------------------------------------------- + + +class TestConcurrency: + def test_model_token_cache_threadsafe_for_parallel_reads(self): + """Multiple threads reading .model_token on same instance concurrently.""" + m = SimpleModel(x=42, y="threadsafe") + results = [] + + def read_token(): + return m.model_token + + with ThreadPoolExecutor(max_workers=8) as pool: + futures = [pool.submit(read_token) for _ in range(50)] + for f in as_completed(futures): + results.append(f.result()) + + # All results should be identical + assert len(set(results)) == 1 + + def test_behavior_token_class_cache_threadsafe(self): + """Multiple threads computing behavior token on same class.""" + + class MyCallable: + def __call__(self, x): + return x + 1 + + results = [] + + def compute(): + return compute_behavior_token(MyCallable) + + with ThreadPoolExecutor(max_workers=8) as pool: + futures = [pool.submit(compute) for _ in range(50)] + for f in as_completed(futures): + results.append(f.result()) + + assert len(set(results)) == 1 + + +# --------------------------------------------------------------------------- +# Pickle edge cases +# --------------------------------------------------------------------------- + + +class TestPickleEdgeCases: + def test_pickle_roundtrip_produces_valid_token(self): + """After pickle/unpickle, model_token should still work and match.""" + m = SimpleModel(x=1, y="hello") + original_token = m.model_token + + m2 = pickle.loads(pickle.dumps(m)) + assert m2.model_token == original_token + + def test_pickle_frozen_model_preserves_correct_token(self): + """Frozen model's token should survive pickle correctly.""" + m = FrozenModel(a=10, b="frozen_pickle") + original_token = m.model_token + + m2 = pickle.loads(pickle.dumps(m)) + assert m2.model_token == original_token + + +# --------------------------------------------------------------------------- +# Pandas unhashable object columns +# --------------------------------------------------------------------------- + + +class TestPandasUnhashable: + def test_dataframe_with_dict_column(self): + """DataFrame with unhashable object column (dicts) should not crash.""" + df = 
pd.DataFrame({"a": [1, 2], "b": [{"x": 1}, {"y": 2}]}) + token = normalize_token(df) + assert token is not None + + def test_dataframe_with_dict_column_deterministic(self): + """Same dict-column DataFrame produces same token.""" + df1 = pd.DataFrame({"a": [1], "b": [{"k": "v"}]}) + df2 = pd.DataFrame({"a": [1], "b": [{"k": "v"}]}) + assert normalize_token(df1) == normalize_token(df2) + + def test_series_with_dict_elements(self): + """Series with unhashable elements should not crash.""" + s = pd.Series([{"x": 1}, {"y": 2}], name="dicts") + token = normalize_token(s) + assert token is not None + + def test_dataframe_with_list_column(self): + """DataFrame with list-valued column should not crash.""" + df = pd.DataFrame({"a": [[1, 2], [3, 4]]}) + token = normalize_token(df) + assert token is not None + + +# --------------------------------------------------------------------------- +# Dask coverage gap tests (numpy/pandas edge cases) +# --------------------------------------------------------------------------- + + +class TestDaskCoverageGaps: + @pytest.mark.parametrize( + "left,right,expect_equal", + [ + pytest.param( + np.array([1, 2, 3], dtype=np.float32), + np.array([1, 2, 3], dtype=np.float64), + False, + id="different-dtypes", + ), + pytest.param( + np.array([], dtype=np.float64), + np.array([], dtype=np.float64), + True, + id="empty-same-dtype", + ), + pytest.param( + np.array([], dtype=np.float32), + np.array([], dtype=np.float64), + False, + id="empty-diff-dtype", + ), + pytest.param( + {"b": 2, "a": 1}, + {"a": 1, "b": 2}, + True, + id="dict-order-independence", + ), + ], + ) + def test_normalize_token_equality(self, left, right, expect_equal): + assert (normalize_token(left) == normalize_token(right)) == expect_equal + + def test_numpy_discontiguous_array(self): + arr = np.arange(10) + s1 = arr[::2] + s2 = arr[::3] + assert normalize_token(s1) != normalize_token(s2) + + def test_numpy_structured_array(self): + dt = np.dtype([("x", np.int32), ("y", np.float64)]) + arr = np.array([(1, 2.0), (3, 4.0)], dtype=dt) + assert normalize_token(arr)[0] == "ndarray" + + def test_pandas_categorical(self): + df1 = pd.DataFrame({"cat": pd.Categorical(["a", "b", "a"])}) + df2 = pd.DataFrame({"cat": pd.Categorical(["a", "b", "a"])}) + assert normalize_token(df1) == normalize_token(df2) + + def test_pandas_empty_dataframe(self): + assert normalize_token(pd.DataFrame()) is not None + + def test_pandas_multiindex(self): + idx = pd.MultiIndex.from_tuples([(1, "a"), (2, "b")]) + df = pd.DataFrame({"v": [10, 20]}, index=idx) + assert normalize_token(df) is not None + + def test_plain_pydantic_frozen_tokenizes(self): + class FrozenPlain(PydanticBaseModel): + model_config = ConfigDict(frozen=True) + x: int = 1 + + t = DefaultTokenizer() + token = t.tokenize(FrozenPlain(x=42)) + assert isinstance(token, str) + assert t.tokenize(FrozenPlain(x=42)) == token + + def test_plain_pydantic_nonfrozen_tokenizes(self): + class PlainModel(PydanticBaseModel): + x: int = 1 + y: str = "hello" + + t = DefaultTokenizer() + token = t.tokenize(PlainModel(x=5, y="world")) + assert isinstance(token, str) + assert t.tokenize(PlainModel(x=5, y="world")) == token + + +# --------------------------------------------------------------------------- +# Own-methods behavior token (all methods from cls.__dict__) +# --------------------------------------------------------------------------- + + +def _standalone_helper(x): + """A standalone function for __ccflow_tokenizer_deps__ tests.""" + return x * 10 + + +def _another_helper(x): + 
return x + 99 + + +class TestOwnMethodsBehaviorToken: + """Tests for hashing all own methods (not just __call__/__deps__).""" + + @pytest.mark.parametrize( + "left_attrs,right_attrs,expect_equal", + [ + pytest.param( + {"_helper": lambda self: 1, "__call__": lambda self: self._helper()}, + {"_helper": lambda self: 2, "__call__": lambda self: self._helper()}, + False, + id="helper-method-change", + ), + pytest.param( + {"__call__": lambda self: self.x + 1}, + {"__call__": lambda self: self.x + 1, "extra": lambda self: 42}, + False, + id="adding-method", + ), + pytest.param( + {"_private": lambda self: 1}, + {"_private": lambda self: 2}, + False, + id="private-method", + ), + ], + ) + def test_behavior_token_comparison(self, left_attrs, right_attrs, expect_equal): + A = _make_ast_model("A", **left_attrs) + B = _make_ast_model("B", **right_attrs) + assert (A(x=1).model_token == B(x=1).model_token) == expect_equal + + def test_classmethod_included(self): + class A1(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @classmethod + def from_config(cls, cfg): + return cls(**cfg) + + class A2(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @classmethod + def from_config(cls, cfg): + return cls() + + assert compute_behavior_token(A1) != compute_behavior_token(A2) + + def test_staticmethod_included(self): + class B1(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @staticmethod + def validate(x): + return x > 0 + + class B2(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @staticmethod + def validate(x): + return x >= 0 + + assert compute_behavior_token(B1) != compute_behavior_token(B2) + + def test_validator_included(self): + from pydantic import field_validator + + class V1(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @field_validator("x") + @classmethod + def check_x(cls, v): + if v < 0: + raise ValueError("negative") + return v + + class V2(BaseModel): + __ccflow_tokenizer__ = _AST_TOKENIZER + x: int = 1 + + @field_validator("x") + @classmethod + def check_x(cls, v): + if v < -10: + raise ValueError("too negative") + return v + + assert V1(x=1).model_token != V2(x=1).model_token + + def test_no_own_methods_returns_none(self): + class NoMethods: + pass + + assert compute_behavior_token(NoMethods) is None + + +class TestTokenizerDeps: + """Tests for __ccflow_tokenizer_deps__ extension mechanism.""" + + def test_standalone_function_included(self): + A = _make_ast_model("A", deps=[_standalone_helper], __call__=lambda self: _standalone_helper(self.x)) + B = _make_ast_model("B", __call__=lambda self: _standalone_helper(self.x)) + assert A(x=1).model_token != B(x=1).model_token + + def test_different_dep_functions_differ(self): + A = _make_ast_model("A", deps=[_standalone_helper]) + B = _make_ast_model("B", deps=[_another_helper]) + assert A(x=1).model_token != B(x=1).model_token + + def test_deps_not_inherited(self): + Parent = _make_ast_model("Parent", deps=[_standalone_helper], __call__=lambda self: self.x) + Child = _make_ast_model("Child", base=Parent) + assert compute_behavior_token(Parent) != compute_behavior_token(Child) + + def test_empty_deps_same_as_no_deps(self): + class WithEmpty(BaseModel): + __ccflow_tokenizer_deps__ = [] + + def __call__(self, x): + return x + 1 + + class WithoutDeps(BaseModel): + def __call__(self, x): + return x + 1 + + assert compute_behavior_token(WithEmpty) == compute_behavior_token(WithoutDeps) + + +class TestRuntimeTokenizerMutation: + """Tests for mutating 
BaseModel.__ccflow_tokenizer__ at runtime.""" + + @pytest.fixture(autouse=False) + def _restore_base_tokenizer(self): + original = BaseModel.__ccflow_tokenizer__ + yield + BaseModel.__ccflow_tokenizer__ = original + + def test_global_mutation_affects_subclasses(self, _restore_base_tokenizer): + """Mutating BaseModel.__ccflow_tokenizer__ affects all subclasses without overrides.""" + + class PlainModel(BaseModel): + x: int = 1 + + def __call__(self, x): + return x + 1 + + token_before = PlainModel(x=42).model_token + + BaseModel.__ccflow_tokenizer__ = DefaultTokenizer.with_ast() + + token_after = PlainModel(x=42).model_token + assert token_before != token_after + + def test_global_mutation_does_not_affect_overridden_subclass(self, _restore_base_tokenizer): + """Subclass with its own __ccflow_tokenizer__ is NOT affected by base mutation.""" + + class CustomModel(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer.with_bytecode() + x: int = 1 + + token_before = CustomModel(x=42).model_token + + BaseModel.__ccflow_tokenizer__ = DefaultTokenizer.with_ast() + + token_after = CustomModel(x=42).model_token + assert token_before == token_after + + def test_existing_instances_pick_up_new_tokenizer(self, _restore_base_tokenizer): + """After global tokenizer mutation, mutable instances use the new tokenizer immediately.""" + + class M(BaseModel): + x: int = 1 + + def __call__(self, x): + return x + + m = M(x=1) + token_before = m.model_token + + BaseModel.__ccflow_tokenizer__ = DefaultTokenizer.with_ast() + + # Mutable models don't cache, so they pick up the new tokenizer immediately + token_after = m.model_token + assert token_before != token_after + + def test_new_instances_after_mutation_use_new_tokenizer(self, _restore_base_tokenizer): + """New instances created after mutation use the new tokenizer.""" + + class M(BaseModel): + x: int = 1 + + def __call__(self, x): + return x + + token_before = M(x=1).model_token + + BaseModel.__ccflow_tokenizer__ = DefaultTokenizer.with_ast() + + token_after = M(x=1).model_token + assert token_before != token_after + + +# --------------------------------------------------------------------------- +# Additional type handler tests +# --------------------------------------------------------------------------- + + +class TestAdditionalTypeHandlers: + """Tests for builtin and library type handlers.""" + + def test_complex(self): + assert normalize_token(complex(1, 2)) == ("complex", 1.0, 2.0) + assert normalize_token(complex(1, 2)) != normalize_token(complex(2, 1)) + + def test_ellipsis(self): + assert normalize_token(...) 
== ("ellipsis",) + + def test_slice(self): + assert normalize_token(slice(1, 10, 2)) == ("slice", 1, 10, 2) + assert normalize_token(slice(1, 10)) == ("slice", 1, 10, None) + assert normalize_token(slice(1, 10)) != normalize_token(slice(1, 11)) + + def test_builtin_function(self): + assert normalize_token(len) == ("builtin", "len") + assert normalize_token(len) != normalize_token(print) + + def test_decimal(self): + from decimal import Decimal + + assert normalize_token(Decimal("3.14")) == ("decimal", "3.14") + assert normalize_token(Decimal("3.14")) != normalize_token(Decimal("3.15")) + + def test_partial(self): + from functools import partial + + p1 = partial(int, base=16) + p2 = partial(int, base=10) + assert normalize_token(p1) != normalize_token(p2) + assert normalize_token(p1) == normalize_token(partial(int, base=16)) + + def test_mappingproxy(self): + from types import MappingProxyType + + mp = MappingProxyType({"a": 1, "b": 2}) + d = {"a": 1, "b": 2} + # MappingProxy normalizes like dict + assert normalize_token(mp) == normalize_token(d) + + def test_polars_dataframe(self): + """Polars DataFrames fall through to cloudpickle — same data same token.""" + import polars as pl + + df1 = pl.DataFrame({"a": [1, 2, 3]}) + df2 = pl.DataFrame({"a": [1, 2, 3]}) + df3 = pl.DataFrame({"a": [1, 2, 4]}) + assert normalize_token(df1) == normalize_token(df2) + assert normalize_token(df1) != normalize_token(df3) + + def test_polars_series(self): + import polars as pl + + s1 = pl.Series("x", [1, 2, 3]) + s2 = pl.Series("x", [1, 2, 3]) + s3 = pl.Series("x", [1, 2, 4]) + assert normalize_token(s1) == normalize_token(s2) + assert normalize_token(s1) != normalize_token(s3) + + def test_polars_lazyframe(self): + """Polars LazyFrames are tokenized via cloudpickle (no collect).""" + import polars as pl + + lf1 = pl.LazyFrame({"a": [1, 2, 3]}) + lf2 = pl.LazyFrame({"a": [1, 2, 3]}) + assert normalize_token(lf1) == normalize_token(lf2) + + def test_pandas_dataframe(self): + """Pandas DataFrames fall through to cloudpickle.""" + df1 = pd.DataFrame({"a": [1, 2, 3]}) + df2 = pd.DataFrame({"a": [1, 2, 3]}) + df3 = pd.DataFrame({"a": [1, 2, 4]}) + assert normalize_token(df1) == normalize_token(df2) + assert normalize_token(df1) != normalize_token(df3) + + def test_pandas_series(self): + df1 = pd.Series([1, 2, 3], name="x") + df2 = pd.Series([1, 2, 3], name="x") + df3 = pd.Series([1, 2, 4], name="x") + assert normalize_token(df1) == normalize_token(df2) + assert normalize_token(df1) != normalize_token(df3) + + def test_narwhals_dataframe(self): + import narwhals as nw + + nw_df1 = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) + nw_df2 = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) + nw_df3 = nw.from_native(pd.DataFrame({"a": [1, 2, 4]})) + assert normalize_token(nw_df1) == normalize_token(nw_df2) + assert normalize_token(nw_df1) != normalize_token(nw_df3) + + def test_narwhals_series(self): + import narwhals as nw + + nw_s1 = nw.from_native(pd.Series([1, 2, 3], name="x"), allow_series=True) + nw_s2 = nw.from_native(pd.Series([1, 2, 3], name="x"), allow_series=True) + assert normalize_token(nw_s1) == normalize_token(nw_s2) + + def test_numpy_datetime64(self): + token = normalize_token(np.datetime64("2024-01-01")) + assert token == ("np_scalar", "datetime64", date(2024, 1, 1)) + + def test_numpy_complex(self): + token = normalize_token(np.complex128(1 + 2j)) + assert token == ("np_scalar", "complex128", (1 + 2j)) + + +class TestDaskTokenizer: + """Tests for DaskTokenizer backward compatibility.""" + + def 
test_matches_raw_dask_tokenize(self): + import dask.base + + from ccflow.utils.tokenize import DaskTokenizer + + class M(BaseModel): + x: int = 1 + y: str = "hello" + + m = M() + t = DaskTokenizer() + assert t.tokenize(m) == dask.base.tokenize(m.model_dump(mode="python")) + + def test_different_values_different_token(self): + from ccflow.utils.tokenize import DaskTokenizer + + class M(BaseModel): + x: int = 1 + + t = DaskTokenizer() + assert t.tokenize(M(x=1)) != t.tokenize(M(x=2)) + + def test_same_values_same_token(self): + from ccflow.utils.tokenize import DaskTokenizer + + class M(BaseModel): + x: int = 1 + y: str = "a" + + t = DaskTokenizer() + assert t.tokenize(M()) == t.tokenize(M()) + + def test_works_as_ccflow_tokenizer(self): + from ccflow.utils.tokenize import DaskTokenizer + + class M(BaseModel): + __ccflow_tokenizer__ = DaskTokenizer() + x: int = 1 + + m = M() + assert isinstance(m.model_token, str) + assert len(m.model_token) == 32 # dask uses MD5 + + def test_nested_model(self): + import dask.base + + from ccflow.utils.tokenize import DaskTokenizer + + class Child(BaseModel): + a: int = 1 + + class Parent(BaseModel): + child: Child = Child() + + t = DaskTokenizer() + p = Parent() + assert t.tokenize(p) == dask.base.tokenize(p.model_dump(mode="python")) + + def test_works_with_plain_pydantic(self): + import dask.base + from pydantic import BaseModel as PydanticBaseModel + + from ccflow.utils.tokenize import DaskTokenizer + + class Plain(PydanticBaseModel): + x: int = 1 + + t = DaskTokenizer() + assert t.tokenize(Plain()) == dask.base.tokenize(Plain().model_dump(mode="python")) + + +# --------------------------------------------------------------------------- +# Review finding tests +# --------------------------------------------------------------------------- + + +class TestReviewFindings: + """Tests for issues identified during PR #195 review.""" + + def test_mutable_model_no_stale_parent_token(self): + """Finding 1: Mutable models recompute token fresh — no stale parent.""" + + class Child(BaseModel): + x: int = 1 + + class Parent(BaseModel): + child: Child = Child() + + child = Child(x=1) + parent = Parent(child=child) + t1 = parent.model_token + child.x = 2 + t2 = parent.model_token + # Mutable model recomputes — should reflect child change + assert t1 != t2 + + def test_frozen_model_caches_token(self): + """Finding 1: Frozen models cache their token.""" + + class Frozen(BaseModel): + model_config = ConfigDict(frozen=True) + x: int = 1 + + m = Frozen(x=42) + t1 = m.model_token + t2 = m.model_token + assert t1 == t2 + # Verify it's actually cached (same object) + assert m._model_token is not None + + def test_mutable_model_opt_in_caching(self): + """Finding 1: Mutable models can opt in to caching.""" + + class Cached(BaseModel): + model_config = ConfigDict(cache_token=True) + x: int = 1 + + m = Cached(x=1) + _ = m.model_token + assert m._model_token is not None + + def test_container_cycle_list(self): + """Finding 2: Self-referential list doesn't cause RecursionError.""" + + class M(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + data: object = None + + lst = [] + lst.append(lst) + # Should not raise RecursionError + token = M(data=lst).model_token + assert isinstance(token, str) + + def test_container_cycle_dict(self): + """Finding 2: Self-referential dict doesn't cause RecursionError.""" + + class M(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + data: object = None + + d = {} + d["self"] = d + token = M(data=d).model_token 
+ assert isinstance(token, str) + + def test_behavior_cache_stateful_tokenizer(self): + """Finding 3: Custom stateful tokenizers don't collide in cache.""" + import hashlib + + class SaltedTokenizer(SourceTokenizer): + def __init__(self, salt): + self.salt = salt + + def tokenize(self, func): + code = getattr(func, "__code__", None) + if code is None: + return None + return hashlib.sha256((self.salt + repr(code.co_code)).encode()).hexdigest() + + class M(BaseModel): + x: int = 1 + + def f(self): + return 1 + + t1 = compute_behavior_token(M, collector=OwnMethodCollector(), source_tokenizer=SaltedTokenizer("a")) + t2 = compute_behavior_token(M, collector=OwnMethodCollector(), source_tokenizer=SaltedTokenizer("b")) + assert t1 != t2 + + def test_dep_order_insensitive(self): + """Finding 4: __ccflow_tokenizer_deps__ order doesn't affect behavior token.""" + + def helper_a(): + return "a" + + def helper_b(): + return "b" + + class A(BaseModel): + __ccflow_tokenizer_deps__ = [helper_a, helper_b] + + def f(self): + return 1 + + class B(BaseModel): + __ccflow_tokenizer_deps__ = [helper_b, helper_a] + + def f(self): + return 1 + + # Compare behavior tokens directly (model tokens differ due to type path) + bt_a = compute_behavior_token(A) + bt_b = compute_behavior_token(B) + assert bt_a == bt_b + + def test_unpicklable_raises_type_error(self): + """Finding 5: Unpicklable objects raise TypeError, not repr fallback.""" + + class Unpicklable: + def __reduce__(self): + raise TypeError("nope") + + with pytest.raises(TypeError, match="Cannot tokenize"): + normalize_token(Unpicklable()) diff --git a/ccflow/utils/__init__.py b/ccflow/utils/__init__.py index e1b3188..b2c69d7 100644 --- a/ccflow/utils/__init__.py +++ b/ccflow/utils/__init__.py @@ -1,4 +1,14 @@ from .chunker import * from .core import * from .logging import * -from .tokenize import normalize_token, tokenize +from .tokenize import ( + ASTSourceTokenizer, + BytecodeSourceTokenizer, + DaskTokenizer, + DefaultTokenizer, + FunctionCollector, + OwnMethodCollector, + SourceTokenizer, + Tokenizer, + normalize_token, +) diff --git a/ccflow/utils/tokenize.py b/ccflow/utils/tokenize.py index 20b7161..b879196 100644 --- a/ccflow/utils/tokenize.py +++ b/ccflow/utils/tokenize.py @@ -1,2 +1,640 @@ -# ruff: noqa: F401 -from dask.base import normalize_token, tokenize +"""Tokenization engine for ccflow models. + +Provides deterministic content hashing for pydantic BaseModel instances, +with configurable hash algorithms, Merkle tree optimization for frozen models, +and extensibility via ``__ccflow_tokenize__`` hooks. 
+ +Usage:: + + from ccflow import BaseModel + from ccflow.utils.tokenize import DefaultTokenizer + + class MyModel(BaseModel): + x: int = 1 + + model = MyModel() + model.model_token # hex digest string + + # With behavior hashing + tokenizer = DefaultTokenizer.with_bytecode() +""" + +import ast +import enum +import hashlib +import inspect +import logging +import textwrap +from abc import ABC, abstractmethod +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from functools import partial, singledispatch +from pathlib import PurePath +from types import MappingProxyType +from typing import Any, Callable, List, Optional, Set, Tuple +from uuid import UUID + +from pydantic import BaseModel as PydanticBaseModel + +log = logging.getLogger(__name__) + +__all__ = [ + "SourceTokenizer", + "ASTSourceTokenizer", + "BytecodeSourceTokenizer", + "FunctionCollector", + "OwnMethodCollector", + "Tokenizer", + "DefaultTokenizer", + "normalize_token", + "compute_behavior_token", +] + + +# --------------------------------------------------------------------------- +# SourceTokenizer — how to hash a single function +# --------------------------------------------------------------------------- + + +class SourceTokenizer(ABC): + """Tokenizes a single callable into a digest string. + + Subclass to provide different code hashing strategies (AST, bytecode, etc.). + """ + + @abstractmethod + def tokenize(self, func: Callable) -> Optional[str]: + """Return a hex digest of *func*'s source/bytecode, or None if unavailable.""" + ... + + +def _normalize_source_ast(source: str) -> str: + """Parse source, strip docstrings, return normalized form via ast.unparse.""" + source = textwrap.dedent(source) + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + if node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, (ast.Constant, ast.Str)): + node.body.pop(0) + return ast.unparse(tree) + + +class ASTSourceTokenizer(SourceTokenizer): + """Hash function source via AST normalization (strips docstrings, normalizes whitespace).""" + + def tokenize(self, func: Callable) -> Optional[str]: + try: + source = inspect.getsource(func) + normalized = _normalize_source_ast(source) + return hashlib.sha256(normalized.encode("utf-8")).hexdigest() + except (OSError, TypeError): + # Source not available — fall back to bytecode + code = getattr(func, "__code__", None) + if code is not None: + return hashlib.sha256(code.co_code).hexdigest() + return None + + +class BytecodeSourceTokenizer(SourceTokenizer): + """Hash function bytecode (co_code + co_consts, docstring stripped).""" + + def tokenize(self, func: Callable) -> Optional[str]: + code = getattr(func, "__code__", None) + if code is None: + return None + consts = code.co_consts + # Strip the docstring slot (co_consts[0]): a str when present, None when absent + if consts and isinstance(consts[0], (str, type(None))): + consts = consts[1:] + payload = repr((code.co_code, consts)).encode("utf-8") + return hashlib.sha256(payload).hexdigest() + + +# --------------------------------------------------------------------------- +# FunctionCollector — which functions to hash for a class +# --------------------------------------------------------------------------- + + +class FunctionCollector(ABC): + """Discovers functions relevant to a class for behavior hashing. + + Subclass to control which methods/functions participate in the behavior token. 
+ """ + + @abstractmethod + def collect(self, cls: type) -> List[Tuple[str, Callable]]: + """Return a sorted list of ``(name, callable)`` pairs for *cls*.""" + ... + + +class OwnMethodCollector(FunctionCollector): + """Collects all callable methods defined directly on *cls* (via ``cls.__dict__``), + plus any standalone functions listed in ``cls.__ccflow_tokenizer_deps__``.""" + + def collect(self, cls: type) -> List[Tuple[str, Callable]]: + methods = [] + for name, value in cls.__dict__.items(): + if isinstance(value, (classmethod, staticmethod)): + methods.append((name, value.__func__)) + elif callable(value): + methods.append((name, value)) + methods.sort(key=lambda pair: pair[0]) + + # Add standalone functions from __ccflow_tokenizer_deps__ + deps = [] + extra_deps = cls.__dict__.get("__ccflow_tokenizer_deps__") + if extra_deps is not None: + for func in extra_deps: + if isinstance(func, (classmethod, staticmethod)): + func = func.__func__ + if callable(func): + func_id = getattr(func, "__qualname__", getattr(func, "__name__", repr(func))) + deps.append((f"__dep__:{func_id}", func)) + deps.sort(key=lambda pair: pair[0]) + methods.extend(deps) + + return methods + + +# --------------------------------------------------------------------------- +# compute_behavior_token — composes collector + source tokenizer +# --------------------------------------------------------------------------- + +_BEHAVIOR_CACHE_ATTR = "__ccflow_behavior_token__" + + +def _get_behavior_cache(cls: type, cache_key: tuple) -> Optional[str]: + """Read from the per-class behavior token cache (returns None on miss).""" + cached = cls.__dict__.get(_BEHAVIOR_CACHE_ATTR) + if isinstance(cached, dict) and cache_key in cached: + return cached[cache_key] + return None + + +def _set_behavior_cache(cls: type, cache_key: tuple, token: str) -> None: + """Write to the per-class behavior token cache.""" + try: + if not isinstance(cls.__dict__.get(_BEHAVIOR_CACHE_ATTR), dict): + setattr(cls, _BEHAVIOR_CACHE_ATTR, {}) + getattr(cls, _BEHAVIOR_CACHE_ATTR)[cache_key] = token + except (TypeError, AttributeError): + pass + + +def compute_behavior_token( + cls: type, + collector: Optional["FunctionCollector"] = None, + source_tokenizer: Optional["SourceTokenizer"] = None, +) -> Optional[str]: + """Compute a behavior token for a class by hashing collected functions. + + Results are cached on the class keyed by ``(collector_type, source_type)``. + """ + if collector is None: + collector = OwnMethodCollector() + if source_tokenizer is None: + source_tokenizer = BytecodeSourceTokenizer() + + cache_key = (id(collector), id(source_tokenizer)) + cached = _get_behavior_cache(cls, cache_key) + if cached is not None: + return cached + + method_hashes = [(name, h) for name, method in collector.collect(cls) if (h := source_tokenizer.tokenize(method)) is not None] + if not method_hashes: + return None + + token = hashlib.sha256(repr(method_hashes).encode("utf-8")).hexdigest() + _set_behavior_cache(cls, cache_key, token) + return token + + +# --------------------------------------------------------------------------- +# normalize_token — singledispatch-based canonical normalization +# --------------------------------------------------------------------------- + + +@singledispatch +def normalize_token(obj: Any) -> Any: + """Produce a canonical, hashable representation of an object. + + This is a singledispatch function — register handlers for new types via:: + + @normalize_token.register(MyType) + def _(obj): + return ... 
+ + Objects with a ``__ccflow_tokenize__`` method use it automatically. + """ + # Check for custom hook + method = getattr(obj, "__ccflow_tokenize__", None) + if method is not None: + return method() + + # Try cloudpickle as last resort + try: + import cloudpickle + + pickled = cloudpickle.dumps(obj) + return ("__cloudpickle__", hashlib.sha256(pickled).hexdigest()) + except Exception: + raise TypeError( + f"Cannot tokenize object of type {type(obj).__qualname__}. Implement __ccflow_tokenize__() or register a normalize_token handler." + ) + + +# --- Primitives --- + + +@normalize_token.register(type(None)) +def _normalize_none(obj): + return None + + +@normalize_token.register(bool) +@normalize_token.register(int) +@normalize_token.register(float) +@normalize_token.register(str) +@normalize_token.register(bytes) +def _normalize_primitive(obj): + return obj + + +# --- Date/time --- + + +@normalize_token.register(date) +def _normalize_date(obj): + return ("date", obj.isoformat()) + + +@normalize_token.register(datetime) +def _normalize_datetime(obj): + return ("datetime", obj.isoformat()) + + +@normalize_token.register(time) +def _normalize_time(obj): + return ("time", obj.isoformat()) + + +@normalize_token.register(timedelta) +def _normalize_timedelta(obj): + return ("timedelta", obj.total_seconds()) + + +# --- UUID --- + + +@normalize_token.register(UUID) +def _normalize_uuid(obj): + return ("uuid", str(obj)) + + +# --- Path --- + + +@normalize_token.register(PurePath) +def _normalize_path(obj): + return ("path", str(obj)) + + +# --- Enum --- + + +@normalize_token.register(enum.Enum) +def _normalize_enum(obj): + return ("enum", type(obj).__qualname__, obj.name) + + +# --- Collections --- + + +@normalize_token.register(tuple) +def _normalize_tuple(obj): + return ("tuple", tuple(normalize_token(item) for item in obj)) + + +@normalize_token.register(list) +def _normalize_list(obj): + return ("list", tuple(normalize_token(item) for item in obj)) + + +@normalize_token.register(set) +def _normalize_set(obj): + return ("set", tuple(sorted((normalize_token(item) for item in obj), key=repr))) + + +@normalize_token.register(frozenset) +def _normalize_frozenset(obj): + return ("frozenset", tuple(sorted((normalize_token(item) for item in obj), key=repr))) + + +@normalize_token.register(dict) +def _normalize_dict(obj): + return ( + "dict", + tuple( + sorted( + ((normalize_token(k), normalize_token(v)) for k, v in obj.items()), + key=repr, + ) + ), + ) + + +# --- Additional builtins --- + + +@normalize_token.register(complex) +def _normalize_complex(obj): + return ("complex", obj.real, obj.imag) + + +@normalize_token.register(type(Ellipsis)) +def _normalize_ellipsis(obj): + return ("ellipsis",) + + +@normalize_token.register(slice) +def _normalize_slice(obj): + return ("slice", obj.start, obj.stop, obj.step) + + +@normalize_token.register(type(len)) # builtin_function_or_method +def _normalize_builtin(obj): + return ("builtin", obj.__qualname__) + + +@normalize_token.register(Decimal) +def _normalize_decimal(obj): + return ("decimal", str(obj)) + + +@normalize_token.register(partial) +def _normalize_partial(obj): + return ( + "partial", + normalize_token(obj.func), + normalize_token(obj.args), + normalize_token(sorted(obj.keywords.items())), + ) + + +@normalize_token.register(MappingProxyType) +def _normalize_mappingproxy(obj): + return _normalize_dict(dict(obj)) + + +# --- Numpy --- + + +def _register_numpy(): + """Register numpy normalize_token handlers.""" + try: + import numpy as np + except 
ImportError: + return + + @normalize_token.register(np.ndarray) + def _normalize_ndarray(obj): + return ("ndarray", str(obj.dtype), obj.shape, hashlib.sha256(obj.tobytes()).hexdigest()) + + @normalize_token.register(np.generic) + def _normalize_np_scalar(obj): + return ("np_scalar", str(type(obj).__name__), obj.item()) + + +def _register_pandas(): + """Register pandas normalize_token handlers.""" + try: + import pandas as pd + except ImportError: + return + + @normalize_token.register(pd.Timestamp) + def _normalize_pd_timestamp(obj): + return ("pd_timestamp", obj.isoformat()) + + +# --- Functions --- + + +@normalize_token.register(type(lambda: None)) # FunctionType +def _normalize_function(obj): + method = getattr(obj, "__ccflow_tokenize__", None) + if method is not None: + return method() + # Try AST-normalized source + try: + source = inspect.getsource(obj) + normalized = _normalize_source_ast(source) + return ("func", obj.__qualname__, hashlib.sha256(normalized.encode()).hexdigest()) + except (OSError, TypeError): + pass + # Fallback to bytecode + code = getattr(obj, "__code__", None) + if code is not None: + return ("func", obj.__qualname__, hashlib.sha256(code.co_code).hexdigest()) + # Last resort: qualified name only + return ("func", obj.__qualname__) + + +@normalize_token.register(type) +def _normalize_type(obj): + return ("type", f"{obj.__module__}.{obj.__qualname__}") + + +# --- Pydantic BaseModel --- +# NOTE: The ccflow BaseModel handler is registered in ccflow/base.py +# to avoid circular imports. This handles plain pydantic BaseModel. + + +@normalize_token.register(PydanticBaseModel) +def _normalize_pydantic_basemodel(obj): + type_path = f"{type(obj).__module__}.{type(obj).__qualname__}" + model_fields = type(obj).model_fields + fields = tuple((k, normalize_token(v)) for k, v in obj if k in model_fields and not model_fields[k].exclude) + return ("pydantic", type_path, fields) + + +# Register numpy/pandas handlers at import time +_register_numpy() +_register_pandas() + + +# --------------------------------------------------------------------------- +# Tokenizer ABC and DefaultTokenizer +# --------------------------------------------------------------------------- + + +class Tokenizer(ABC): + """Abstract tokenization engine. + + Subclass and override methods to customize tokenization behavior. + Set ``__ccflow_tokenizer__`` on a BaseModel class to swap engines. + """ + + def hash_canonical(self, canonical: Any) -> str: + """Hash an arbitrary canonical form to a hex digest.""" + return hashlib.sha256(repr(canonical).encode("utf-8")).hexdigest() + + @abstractmethod + def normalize(self, model: PydanticBaseModel, *, _visited: Optional[Set[int]] = None) -> Any: + """Produce a canonical structured representation of a model.""" + ... + + @abstractmethod + def tokenize(self, model: PydanticBaseModel) -> str: + """Produce a hex digest token for a model.""" + ... + + def normalize_value(self, value: Any, *, _visited: Optional[Set[int]] = None) -> Any: + """Normalize an arbitrary value. 
Override for custom dispatch.""" + return normalize_token(value) + + +def _normalize_model_fields(tokenizer: "Tokenizer", model: PydanticBaseModel, _visited: Set[int]) -> List[Tuple[str, Any]]: + """Normalize a model's non-excluded fields via the tokenizer's normalize_value.""" + fields = [] + model_fields = type(model).model_fields + for field_name, field_info in model_fields.items(): + if field_info.exclude: + continue + value = getattr(model, field_name) + fields.append((field_name, tokenizer.normalize_value(value, _visited=_visited))) + return fields + + +class DefaultTokenizer(Tokenizer): + """Default tokenization engine using singledispatch-based normalization. + + Composes a ``FunctionCollector`` and ``SourceTokenizer`` for optional + behavior hashing. When both are ``None`` (the default), only field + data is hashed. + """ + + def __init__( + self, + collector: Optional[FunctionCollector] = None, + source_tokenizer: Optional[SourceTokenizer] = None, + ): + self.collector = collector + self.source_tokenizer = source_tokenizer + + @classmethod + def with_ast(cls) -> "DefaultTokenizer": + """Convenience constructor: own methods hashed via AST normalization.""" + return cls(collector=OwnMethodCollector(), source_tokenizer=ASTSourceTokenizer()) + + @classmethod + def with_bytecode(cls) -> "DefaultTokenizer": + """Convenience constructor: own methods hashed via bytecode.""" + return cls(collector=OwnMethodCollector(), source_tokenizer=BytecodeSourceTokenizer()) + + def normalize_value(self, value: Any, *, _visited: Optional[Set[int]] = None) -> Any: + """Normalize an arbitrary value, routing containers and models through the tokenizer. + + Re-implements container handling (rather than delegating to normalize_token) + so that nested models participate in cycle detection via _visited. 
+ """ + # Fast path for common primitives — avoids singledispatch overhead + if type(value) in (int, str, float, bool, type(None), bytes): + return value + + if isinstance(value, PydanticBaseModel): + is_frozen = value.model_config.get("frozen", False) + if is_frozen and hasattr(value, "model_token"): + return ("__child__", value.model_token) + return self.normalize(value, _visited=_visited) + + # Cycle detection for mutable containers + val_id = id(value) + if isinstance(value, (list, dict, set)): + if _visited is None: + _visited = set() + if val_id in _visited: + return ("__cycle__", type(value).__name__) + _visited.add(val_id) + + if isinstance(value, dict): + result = ( + "dict", + tuple( + sorted( + ((self.normalize_value(k, _visited=_visited), self.normalize_value(v, _visited=_visited)) for k, v in value.items()), + key=repr, + ) + ), + ) + _visited.discard(val_id) + return result + if isinstance(value, (list, tuple)): + tag = "list" if isinstance(value, list) else "tuple" + result = (tag, tuple(self.normalize_value(v, _visited=_visited) for v in value)) + _visited.discard(val_id) + return result + if isinstance(value, (set, frozenset)): + tag = "set" if isinstance(value, set) else "frozenset" + result = ( + tag, + tuple( + sorted( + (self.normalize_value(v, _visited=_visited) for v in value), + key=repr, + ) + ), + ) + _visited.discard(val_id) + return result + + return normalize_token(value) + + def normalize(self, model: PydanticBaseModel, *, _visited: Optional[Set[int]] = None) -> Any: + """Produce a canonical structured representation.""" + model_id = id(model) + + if _visited is not None and model_id in _visited: + type_path = f"{type(model).__module__}.{type(model).__qualname__}" + return ("__cycle__", type_path) + + if _visited is None: + _visited = set() + _visited.add(model_id) + + type_path = f"{type(model).__module__}.{type(model).__qualname__}" + + behavior = None + if self.collector is not None and self.source_tokenizer is not None: + behavior = compute_behavior_token( + type(model), + collector=self.collector, + source_tokenizer=self.source_tokenizer, + ) + + fields = _normalize_model_fields(self, model, _visited) + + # Backtrack so sibling fields of a parent model don't false-positive as cycles + _visited.discard(model_id) + return (type_path, behavior, tuple(fields)) + + def tokenize(self, model: PydanticBaseModel) -> str: + """Produce a hex digest token.""" + return self.hash_canonical(self.normalize(model)) + + +class DaskTokenizer(Tokenizer): + """Tokenizer that delegates to ``dask.base.tokenize`` for backward compatibility. + + Hashes ``model.model_dump(mode="python")`` using dask's tokenization, + matching the legacy ``cache_key()`` behavior. Requires ``dask`` to be + installed (imported lazily). 
+ """ + + def normalize(self, model: PydanticBaseModel, *, _visited: Optional[Set[int]] = None) -> Any: + return model.model_dump(mode="python") + + def tokenize(self, model: PydanticBaseModel) -> str: + import dask.base + + return dask.base.tokenize(model.model_dump(mode="python")) diff --git a/docs/wiki/Tokenization.md b/docs/wiki/Tokenization.md new file mode 100644 index 0000000..b48ac5b --- /dev/null +++ b/docs/wiki/Tokenization.md @@ -0,0 +1,440 @@ +# Tokenization + +- [Overview](#overview) +- [Quick Start](#quick-start) +- [How Tokens Are Computed](#how-tokens-are-computed) +- [Behavior Hashing](#behavior-hashing) +- [Controlling Tokenization](#controlling-tokenization) + - [Customizing Tokenization](#customizing-tokenization) +- [Cache Keys and MemoryCacheEvaluator](#cache-keys-and-memorycacheevaluator) +- [Limitations and Caveats](#limitations-and-caveats) + - [Injecting External State into Tokens](#injecting-external-state-into-tokens) +- [Architecture](#architecture) + +## Overview + +Every ccflow `BaseModel` instance exposes a `model_token` property — a deterministic hex digest that uniquely identifies the model's **data** (field values) and optionally its **behavior** (source code of methods). + +Tokens are used as cache keys by evaluators like `MemoryCacheEvaluator`, and can be used for change detection, deduplication, or audit trails. + +```python +from ccflow import BaseModel + +class MyModel(BaseModel): + x: int = 1 + y: str = "hello" + +m = MyModel() +print(m.model_token) # e.g. "a1b2c3d4..." + +# Same field values → same token +assert MyModel(x=1, y="hello").model_token == m.model_token + +# Different field values → different token +assert MyModel(x=2).model_token != m.model_token +``` + +## Quick Start + +**Tokens include both data and behavior by default:** + +```python +from ccflow import BaseModel + +class Config(BaseModel): + learning_rate: float = 0.01 + epochs: int = 100 + +c1 = Config() +c2 = Config(learning_rate=0.02) +assert c1.model_token != c2.model_token +``` + +Methods defined on the class are automatically included in the token via bytecode hashing: + +```python +class MyPipeline(BaseModel): + x: int = 1 + + def __call__(self, ctx): + return self.x * 2 # changes to this body change the token +``` + +**Data-only tokens (excluding behavior):** + +If you only want field data in the token (e.g. 
for pure config models with no meaningful methods):
+
+```python
+from ccflow.utils.tokenize import DefaultTokenizer
+
+class PureConfig(BaseModel):
+    __ccflow_tokenizer__ = DefaultTokenizer()  # data-only, no behavior hashing
+    x: int = 1
+```
+
+## How Tokens Are Computed
+
+A model token is a SHA-256 digest of the model's **canonical form**, which is a tuple of:
+
+```
+(fully_qualified_type_name, behavior_token_or_None, ((field1, normalized_value1), ...))
+```
+
+### Field Normalization
+
+Each field value is recursively normalized to a deterministic canonical form:
+
+| Type                                               | Canonical Form                                                  |
+| -------------------------------------------------- | --------------------------------------------------------------- |
+| Primitives (`int`, `str`, `bool`, `None`, `float`) | Identity                                                         |
+| `datetime`, `date`, `time`                         | Tagged ISO string, e.g. `("date", "2024-01-01")`                 |
+| `enum.Enum`                                        | `("enum", type_qualname, member_name)`                           |
+| `list`, `tuple`, `set`, `dict`                     | Recursively normalized, sets/dicts sorted by `repr`              |
+| `numpy.ndarray`                                    | `("ndarray", dtype, shape, sha256_of_bytes)`                     |
+| `pandas.DataFrame`, `pandas.Series`                | cloudpickle fallback: `("__cloudpickle__", sha256_of_pickled_bytes)` |
+| `pandas.Timestamp`                                 | `("pd_timestamp", iso_string)`                                   |
+| Frozen `BaseModel` child                           | `("__child__", child.model_token)` — Merkle tree shortcut        |
+| Non-frozen `BaseModel` child                       | Recursively normalized inline                                    |
+| Anything else                                      | cloudpickle fallback                                             |
+
+### Field Exclusion
+
+Use pydantic's `Field(exclude=True)` to exclude fields from the token:
+
+```python
+from pydantic import Field
+
+class MyModel(BaseModel):
+    important: int = 42
+    debug_info: str = Field(default="debug", exclude=True)  # not in token
+```
+
+### Token Caching
+
+Frozen models (`frozen=True`, including `ContextBase`) cache their token automatically —
+computed once and never invalidated, since the model cannot be mutated.
+
+Mutable models recompute `model_token` on every access, guaranteeing the token always
+reflects the current state.
+
+To opt in to caching on a mutable model — for example, when it holds large input data
+that is expensive to tokenize and will not be changed after construction — set
+`cache_token=True`:
+
+```python
+class LargeInputModel(BaseModel):
+    model_config = ConfigDict(cache_token=True)
+    data: list  # large payload, expensive to tokenize
+```
+
+When `cache_token=True` on a mutable model, the cache is cleared whenever a field is
+directly reassigned (via `validate_assignment`). However, mutating a **nested** child
+in-place (e.g. `parent.child.x = 2`) will **not** invalidate the parent's cached token.
+Only use `cache_token=True` when you know the model's content will not change after the
+token is first accessed.
+
+## Behavior Hashing
+
+By default, `model_token` includes both field data **and** behavior (method bytecode).
+Two models with the same fields but different `__call__` implementations will have **different tokens**.
+ +```python +class MyCallable(BaseModel): + x: int = 1 + + def __call__(self, ctx): + return self.x + 1 # hashed into the token automatically +``` + +To disable behavior hashing for a class (data-only tokens), use a plain `DefaultTokenizer()`: + +```python +class DataOnly(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer() + x: int = 1 +``` + +### AST vs Bytecode + +Two strategies are available for hashing function source code: + +| | `with_ast()` | `with_bytecode()` | +| --------------------------------- | ---------------------------------------------- | ------------------------------------------------------ | +| **How it works** | Parses source → AST → `ast.unparse()` → SHA256 | Hashes `co_code` + `co_consts` → SHA256 | +| **Strips docstrings** | ✅ | ✅ | +| **Strips comments** | ✅ | ✅ (comments aren't in bytecode) | +| **Immune to whitespace** | ✅ | ✅ | +| **Immune to variable renames** | ❌ Different names → different hash | ✅ Names in `co_varnames`, not `co_code` | +| **Works without source** | ❌ Falls back to bytecode | ✅ Always works | +| **Stable across Python versions** | ✅ AST is stable | ⚠️ **`co_code` changes between Python minor versions** | +| **Works in REPL/Jupyter** | ❌ `inspect.getsource()` often fails | ✅ Always available | +| **Performance** | Slower (parse + AST round-trip) | ✅ Order of magnitude faster | + +**Bytecode is the default** when behavior hashing is enabled via `compute_behavior_token()`. It is an order of magnitude faster than AST normalization. Use `with_ast()` if you need cross-version stability. + +### Which Methods Are Hashed + +When behavior hashing is enabled, **all methods defined directly on the class** (in `cls.__dict__`) are included. Inherited methods are NOT included — only methods the class itself defines. + +This includes: + +- Regular methods, `@classmethod`, `@staticmethod` +- Private methods (`_helper`, `__internal`) +- Pydantic validators (`@model_validator`, `@field_validator`) +- `__call__`, `__deps__`, any other dunder you define + +This does **not** include: + +- Methods inherited from parent classes +- Functions imported and called by your methods (no transitive dependency tracking) +- Methods added dynamically at runtime + +### Adding Standalone Dependencies + +If your class calls standalone functions that should affect the token, declare them: + +```python +def my_transform(data): + return data * 2 + +class MyPipeline(BaseModel): + __ccflow_tokenizer_deps__ = [my_transform] + x: int = 1 + + def __call__(self, ctx): + return my_transform(self.x) +``` + +## Controlling Tokenization + +### Customizing Tokenization + +There are three extension points, at different levels: + +| Hook | Scope | Use when | +| ------------------------------- | ----------------- | ----------------------------------------------------------------------------------------------------- | +| `__ccflow_tokenizer__` ClassVar | BaseModel class | You want to change *how models are tokenized* (e.g. disable behavior hashing, use AST mode, use dask) | +| `normalize_token.register(T)` | Any type (global) | You have a custom type that appears as a field value and needs a deterministic canonical form | +| `__ccflow_tokenize__()` method | Any instance | Same as above, but defined on the class itself instead of registered globally | + +The first is a high-level orchestration hook — it selects the tokenizer engine for a model class. The other two are leaf-value hooks that control how individual field values are canonicalized. 
+ +**`__ccflow_tokenizer__`** — select the tokenizer engine for a model class: + +```python +class DataOnly(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer() # data-only, no behavior hashing + x: int = 1 + +class WithAST(BaseModel): + __ccflow_tokenizer__ = DefaultTokenizer.with_ast() # AST normalization instead of bytecode + x: int = 1 +``` + +**`normalize_token.register()`** — register a global handler for a custom type: + +```python +from ccflow.utils.tokenize import normalize_token + +@normalize_token.register(MyDatabaseConnection) +def _(obj): + return ("db", obj.host, obj.port, obj.database) +``` + +**`__ccflow_tokenize__()`** — define a canonical form on the class itself: + +```python +class MySpecialType: + def __init__(self, data): + self.data = data + + def __ccflow_tokenize__(self): + return ("MySpecialType", self.data.key) +``` + +If both a `normalize_token.register()` handler and `__ccflow_tokenize__()` exist for the same type, the singledispatch handler takes priority. + +### Global Tokenizer Override + +You can change the tokenizer for ALL `BaseModel` subclasses at runtime: + +```python +from ccflow import BaseModel +from ccflow.utils.tokenize import DefaultTokenizer + +# Switch to AST-based behavior hashing globally +BaseModel.__ccflow_tokenizer__ = DefaultTokenizer.with_ast() + +# Or disable behavior hashing globally +BaseModel.__ccflow_tokenizer__ = DefaultTokenizer() +``` + +Subclasses that define their own `__ccflow_tokenizer__` are not affected. + +### Building Custom Tokenizers + +The tokenizer is composed from two pluggable components: + +```python +from ccflow.utils.tokenize import ( + DefaultTokenizer, + OwnMethodCollector, # which functions to hash + ASTSourceTokenizer, # how to hash each function + BytecodeSourceTokenizer, +) + +# Full control over composition +tokenizer = DefaultTokenizer( + collector=OwnMethodCollector(), + source_tokenizer=BytecodeSourceTokenizer(), +) +``` + +Implement `FunctionCollector` or `SourceTokenizer` to create custom strategies. + +## Cache Keys and MemoryCacheEvaluator + +The `MemoryCacheEvaluator` uses `model_token` as the cache key: + +```python +from ccflow import BaseModel, CallableModel, ContextBase, ResultBase +from ccflow.evaluators import MemoryCacheEvaluator + +class MyContext(ContextBase): + date: str = "2024-01-01" + +class MyResult(ResultBase): + value: float = 0.0 + +class MyCallable(CallableModel): + multiplier: float = 1.0 + + def __call__(self, ctx: MyContext) -> MyResult: + return MyResult(value=float(ctx.date.replace("-", "")) * self.multiplier) + +# The cache key is derived from model_token of the evaluation context +# Same context + same callable config → cache hit +cached = MemoryCacheEvaluator() +``` + +For `ModelEvaluationContext` (the wrapper that chains evaluators), the `model_token` implementation is smart about stripping "transparent" evaluator layers (like `LoggingEvaluator`) so that the cache key depends only on the actual computation. 
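+
+As a minimal illustration of the property the cache relies on (the class and field names below are
+hypothetical), two separately constructed but identically configured objects share a token, so a cache
+keyed on `model_token` treats them as the same piece of work, while any configuration difference
+produces a new key:
+
+```python
+from ccflow import BaseModel, ContextBase
+
+class DateContext(ContextBase):
+    date: str = "2024-01-01"
+
+class Loader(BaseModel):
+    path: str = "data.csv"
+
+# Same configuration, constructed independently -> same token -> cache hit
+assert DateContext(date="2024-01-01").model_token == DateContext(date="2024-01-01").model_token
+assert Loader(path="a.csv").model_token == Loader(path="a.csv").model_token
+
+# Any configuration change -> different token -> cache miss
+assert Loader(path="a.csv").model_token != Loader(path="b.csv").model_token
+```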
+ +## Limitations and Caveats + +### Things That Will Produce Different Tokens When They Shouldn't + +These produce **false cache misses** (safe but wasteful): + +| Scenario | Why | Mitigation | +| ------------------------------------------------------ | ------------------------------------------------------------------ | ------------------------------------------------ | +| **Different Python minor version** (bytecode mode) | `co_code` format changes between Python 3.11 → 3.12 | Use `with_ast()` for cross-version stability | +| **Variable rename** (AST mode) | AST preserves variable names: `def f(x)` ≠ `def f(y)` | Use `with_bytecode()` if renames are common | +| **Pydantic injects `model_post_init`** into subclasses | Pydantic adds this to every `__dict__` even if you don't define it | Acceptable — consistent within a class hierarchy | + +### Things That Will Produce the Same Token When They Shouldn't + +These produce **false cache hits** (dangerous — stale results): + +| Scenario | Why | Mitigation | +| ------------------------------------------------------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | +| **Upstream code changes** (functions in other modules) | No transitive dependency tracking — only methods on the class itself are hashed | Add critical dependencies to `__ccflow_tokenizer_deps__` | +| **Python package version upgrades** (numpy, pandas, etc.) | Package versions are not part of the hash | Add a version field (see [Injecting External State](#injecting-external-state-into-tokens)) | +| **Data file changes on disk** | File paths hash the same even if contents change | Add a file checksum field (see [Injecting External State](#injecting-external-state-into-tokens)) | +| **Environment variables / config changes** | External state not captured | Add env var fields (see [Injecting External State](#injecting-external-state-into-tokens)) | +| **Database schema or data changes** | Only the query config is hashed, not the data | Use time-based cache invalidation or include a data version field | +| **Git branch / commit changes** (without code changes to the class) | No git integration | Add git hash as a field if needed | + +### Injecting External State into Tokens + +Since `model_token` is computed from field values, you can include external state — package versions, file checksums, environment variables — by adding fields with `default_factory`. Use `repr=False` to keep them out of `__repr__` if desired: + +**Package versions:** + +```python +from pydantic import Field + +class MyPipeline(BaseModel): + x: int = 1 + pandas_version: str = Field( + default_factory=lambda: __import__("pandas").__version__, + repr=False, + ) +``` + +The version is captured once at construction. If pandas is upgraded and the model is re-created, the token changes automatically. + +**Environment variables:** + +```python +import os + +class EnvAwareModel(BaseModel): + x: int = 1 + deploy_env: str = Field( + default_factory=lambda: os.environ.get("DEPLOY_ENV", "dev"), + repr=False, + ) +``` + +**File checksums (using `model_validator`):** + +When the extra data depends on another field (e.g. 
computing a checksum of a file path), use a `model_validator` instead of `default_factory`: + +```python +import hashlib +from pydantic import Field, model_validator + +class FileProcessor(BaseModel): + input_path: str = "data.csv" + input_checksum: str = Field(default="", repr=False) + + @model_validator(mode="after") + def _compute_checksum(self): + if not self.input_checksum: + try: + with open(self.input_path, "rb") as f: + self.input_checksum = hashlib.sha256(f.read()).hexdigest() + except FileNotFoundError: + self.input_checksum = "file_not_found" + return self +``` + +Now the token changes whenever the file contents change, even if `input_path` stays the same. Users can also override `input_checksum` explicitly for testing. + +> **Why not `@computed_field`?** Computed fields are evaluated lazily — every time the property is accessed. Since `model_token` reads all fields, using `@computed_field` would force evaluation on every token computation, which is wasteful for expensive operations (file I/O, subprocess calls). A regular field with `default_factory` or `model_validator` computes the value once at construction. + +### Other Caveats + +- **Large numpy arrays**: `tobytes()` copies the full array into memory for hashing. For very large arrays, this may be slow. +- **Polars / Arrow**: Work via cloudpickle fallback (no explicit optimized handlers). +- **Cycles**: Handled gracefully — a cycle produces `("__cycle__", type_path)` as a sentinel. +- **Unpicklable objects**: If cloudpickle cannot serialize an object and no `__ccflow_tokenize__()` method or `normalize_token` handler is registered, tokenization raises `TypeError`. Register a custom handler to support such types. + +## Architecture + +The tokenization system has two layers: + +``` +┌─────────────────────────────────────────┐ +│ BaseModel API Layer (ccflow/base.py) │ +│ • model_token property │ +│ • __ccflow_tokenizer__ ClassVar │ +│ • _model_token cache (PrivateAttr) │ +└──────────────┬──────────────────────────┘ + │ delegates to +┌──────────────▼──────────────────────────┐ +│ Tokenizer Engine (utils/tokenize.py) │ +│ │ +│ DefaultTokenizer │ +│ ├── collector: FunctionCollector? │ +│ │ └── OwnMethodCollector │ +│ ├── source_tokenizer: SourceTokenizer? │ +│ │ ├── ASTSourceTokenizer │ +│ │ └── BytecodeSourceTokenizer │ +│ └── normalize_token (singledispatch) │ +│ ├── int, str, float, ... │ +│ ├── numpy.ndarray │ +│ ├── pandas.DataFrame │ +│ └── pydantic.BaseModel │ +└─────────────────────────────────────────┘ +``` + +The engine (`utils/tokenize.py`) has **zero imports from ccflow** — it's a standalone leaf module that can be used independently. diff --git a/docs/wiki/Workflows.md b/docs/wiki/Workflows.md index 3d8c960..526a957 100644 --- a/docs/wiki/Workflows.md +++ b/docs/wiki/Workflows.md @@ -589,13 +589,21 @@ An evaluator is basically another form of callable model, with a few caveats The `ModelEvaluationContext` has fields for the model, the context, the function to evaluate (i.e. `__call__`), and the `FlowOptions`. It too, has a `__call__` method that will evaluate the function on the model with the provided context (but ignoring any options). +Evaluators that do not modify the return value (e.g. logging, caching, timing) should override the `is_transparent` method to return `True`. +This allows `cache_key()` to skip these layers when computing cache keys, so that wrapping a model with different transparent evaluators does not change its cache identity. 
+Evaluators that transform the result should inherit from `EvaluatorBase` directly and leave `is_transparent` as the default (`False`). + Below we illustrate how to write a really simple evaluator that just prints the options and delegates to the `ModelEvaluationContext` to get the normal result. +Since it does not modify the return value, it overrides `is_transparent` to return `True`. ```python from ccflow import EvaluatorBase, ModelEvaluationContext, ResultType class MyEvaluator(EvaluatorBase): + def is_transparent(self, context: ModelEvaluationContext) -> bool: + return True + def __call__(self, context: ModelEvaluationContext) -> ResultType: print("Custom evaluator with options:", context.options) return context() diff --git a/docs/wiki/_Sidebar.md b/docs/wiki/_Sidebar.md index 6b20b2f..d67c711 100644 --- a/docs/wiki/_Sidebar.md +++ b/docs/wiki/_Sidebar.md @@ -20,6 +20,7 @@ Notes for editors: - [Configuration](Configuration) - [Workflows](Workflows) - [ETL](ETL) +- [Tokenization](Tokenization) **Developer Guide** diff --git a/pyproject.toml b/pyproject.toml index 7d1dd26..21a2f28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ classifiers = [ dependencies = [ "cloudpathlib", "cloudpickle", - "dask", "deprecated", "hydra-core", "IPython", @@ -97,6 +96,7 @@ develop = [ "xarray", # Test deps "beautifulsoup4", + "dask", "httpx", "pytest", "pytest-asyncio", @@ -106,6 +106,7 @@ develop = [ ] test = [ "beautifulsoup4", + "dask", "httpx", "pytest", "pytest-asyncio",