Skip to content

Commit 7bccbeb

Browse files
committed
feat: native tokenization engine for BaseModel

Replace dask-based tokenization with a standalone, composable engine.

- `model_token` property on BaseModel with automatic cache invalidation
- DefaultTokenizer with pluggable SourceTokenizer × FunctionCollector
- AST and bytecode source hashing (bytecode default)
- Singledispatch `normalize_token` with handlers for numpy, pandas, etc.
- `__ccflow_tokenizer__` ClassVar to swap tokenizer per class or globally
- `__ccflow_tokenizer_deps__` ClassVar for standalone function dependencies
- Simplified `cache_key()` to use `model_token` directly
- Removed dask dependency from tokenization path
- Comprehensive tests (185 cases) and wiki documentation

Signed-off-by: Pascal Tomecek <pascal.tomecek@cubistsystematic.com>
1 parent 5adf933 commit 7bccbeb

11 files changed

Lines changed: 3249 additions & 24 deletions

File tree

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
- '3.11'
3636
dependencies:
3737
- ''
38-
- '"pandas<2" "numpy<2" "xarray<2025.09.0" "dask<2024.7.0"'
38+
- '"pandas<2" "numpy<2" "xarray<2025.09.0"'
3939
- '"pandas<3"'
4040
- '"pandas<4"'
4141

ccflow/base.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from .exttypes.pyobjectpath import PyObjectPath
3333
from .local_persistence import register_ccflow_import_path, sync_to_module
34+
from .utils.tokenize import DefaultTokenizer, Tokenizer, normalize_token
3435

3536
log = logging.getLogger(__name__)
3637

@@ -195,6 +196,41 @@ def type_(self) -> PyObjectPath:
195196
# We want to track under what names a model has been registered
196197
_registrations: List[Tuple["ModelRegistry", str]] = PrivateAttr(default_factory=list)
197198

199+
# Tokenization support
200+
__ccflow_tokenizer__: ClassVar[Tokenizer] = DefaultTokenizer.with_bytecode()
201+
_model_token: Optional[str] = PrivateAttr(default=None)
202+
203+
@property
204+
def model_token(self) -> str:
205+
"""Return a deterministic content hash of this model.
206+
207+
The token is cached by default (controlled by ``cache_token`` in model_config).
208+
For frozen models, the token is computed once and never recomputed.
209+
For mutable models, the cache is cleared on field assignment (via ``validate_assignment``).
210+
Set ``cache_token=False`` in model_config to always compute fresh.
211+
"""
212+
cache = self.model_config.get("cache_token", True)
213+
if cache and self._model_token is not None:
214+
return self._model_token
215+
token = self.__ccflow_tokenizer__.tokenize(self)
216+
if cache:
217+
self.__pydantic_private__["_model_token"] = token
218+
return token
219+
220+
@model_validator(mode="after")
221+
def _clear_token_cache(self):
222+
"""Clear the cached token on construction and field assignment."""
223+
if self.model_config.get("cache_token", True):
224+
self.__pydantic_private__["_model_token"] = None
225+
return self
226+
227+
def model_copy(self, *, update=None, deep=False):
228+
"""Override model_copy to clear the stale token cache on the copy."""
229+
copy = super().model_copy(update=update, deep=deep)
230+
if update and copy.__pydantic_private__ is not None:
231+
copy.__pydantic_private__["_model_token"] = None
232+
return copy
233+
198234
model_config = ConfigDict(
199235
# Note that validate_assignment only partially works: https://github.com/pydantic/pydantic/issues/7105
200236
validate_assignment=True,
@@ -316,6 +352,18 @@ def __getstate__(self):
316352
def __setstate__(self, state):
317353
state["__pydantic_fields_set__"] = set(state["__pydantic_fields_set__"])
318354
super().__setstate__(state)
355+
# Clear stale token cache from pickle
356+
if self.__pydantic_private__ is not None and "_model_token" in self.__pydantic_private__:
357+
self.__pydantic_private__["_model_token"] = None
358+
359+
360+
# Register ccflow BaseModel-specific normalize_token handler
361+
# Delegates to the model's tokenizer so normalization is consistent
362+
# regardless of whether the model is accessed via model_token or
363+
# encountered as a value inside a container.
364+
@normalize_token.register(BaseModel)
365+
def _normalize_ccflow_basemodel(obj):
366+
return obj.__ccflow_tokenizer__.normalize(obj)
319367

320368

321369
class _ModelRegistryData(PydanticBaseModel):

ccflow/callable.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,40 @@ class ModelEvaluationContext(
451451
# Otherwise, the validation will re-run fully despite the models already being validated on construction
452452
# TODO: Make the instance check compatible with the generic types instead of the base type
453453

454+
@property
455+
def model_token(self) -> str:
456+
"""Compute a cache-key token for this MEC chain.
457+
458+
Walks the MEC chain, strips ``TransparentModelEvaluationContext``
459+
layers, and tokenizes the innermost context plus any opaque evaluators.
460+
"""
461+
cache = self.model_config.get("cache_token", True)
462+
if cache and self._model_token is not None:
463+
return self._model_token
464+
465+
fn = self.fn
466+
non_transparent = []
467+
current = self
468+
while isinstance(current.context, ModelEvaluationContext):
469+
fn = current.fn if current.fn != "__call__" else fn
470+
if not isinstance(current, TransparentModelEvaluationContext):
471+
non_transparent.append(current.model)
472+
current = current.context
473+
474+
# Build a canonical representation from the innermost MEC
475+
from .utils.tokenize import normalize_token
476+
477+
inner_norm = normalize_token(current)
478+
effective_fn = fn if fn != "__call__" else current.fn
479+
parts = (inner_norm, effective_fn)
480+
if non_transparent:
481+
parts = parts + (tuple(normalize_token(e) for e in non_transparent),)
482+
token = self.__ccflow_tokenizer__.hash_canonical(parts)
483+
484+
if cache:
485+
self.__pydantic_private__["_model_token"] = token
486+
return token
487+
454488
@model_validator(mode="wrap")
455489
def _context_validator(cls, values, handler, info):
456490
"""Override _context_validator from parent"""

ccflow/evaluators/common.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from types import MappingProxyType
88
from typing import Any, Callable, Dict, List, Optional, Set, Union
99

10-
import dask.base
1110
from pydantic import Field, PrivateAttr, field_validator
1211
from typing_extensions import override
1312

@@ -18,7 +17,6 @@
1817
EvaluatorBase,
1918
ModelEvaluationContext,
2019
ResultType,
21-
TransparentModelEvaluationContext,
2220
)
2321

2422
__all__ = [
@@ -227,21 +225,8 @@ def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel
227225
Args:
228226
flow_obj: The object to be tokenized to form the cache key.
229227
"""
230-
if isinstance(flow_obj, ModelEvaluationContext):
231-
fn = flow_obj.fn
232-
non_transparent = []
233-
while isinstance(flow_obj.context, ModelEvaluationContext):
234-
fn = flow_obj.fn if flow_obj.fn != "__call__" else fn
235-
if not isinstance(flow_obj, TransparentModelEvaluationContext):
236-
non_transparent.append(flow_obj.model)
237-
flow_obj = flow_obj.context
238-
d = flow_obj.model_dump(mode="python")
239-
d["fn"] = fn if fn != "__call__" else flow_obj.fn
240-
if non_transparent:
241-
d["_evaluators"] = [e.model_dump(mode="python") for e in non_transparent]
242-
return dask.base.tokenize(d).encode("utf-8")
243-
elif isinstance(flow_obj, (ContextBase, CallableModel)):
244-
return dask.base.tokenize(flow_obj.model_dump(mode="python")).encode("utf-8")
228+
if isinstance(flow_obj, (ModelEvaluationContext, ContextBase, CallableModel)):
229+
return flow_obj.model_token.encode("utf-8")
245230
else:
246231
raise TypeError(f"object of type {type(flow_obj)} cannot be serialized by this function!")
247232

ccflow/tests/test_base_serialize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,13 @@ def test_pickle_consistency(self):
259259
# (as it would normally in pydantic because of https://github.com/pydantic/pydantic/issues/11603)
260260
# This is generated on Linux/Python 3.11 - might need to have version specific values if it changes.
261261
target = (
262-
b"\x80\x04\x95\xdf\x00\x00\x00\x00\x00\x00\x00\x8c ccflow.tests.test_base_seri"
262+
b"\x80\x04\x95\xf0\x00\x00\x00\x00\x00\x00\x00\x8c ccflow.tests.test_base_seri"
263263
b"alize\x94\x8c\x13MultiAttributeModel\x94\x93\x94)\x81\x94}\x94(\x8c\x08__"
264264
b"dict__\x94}\x94(\x8c\x01z\x94K\x01\x8c\x01y\x94\x8c\x04test\x94\x8c"
265265
b"\x01x\x94G@\t\x1e\xb8Q\xeb\x85\x1f\x8c\x01w\x94\x88u\x8c\x12__pydantic_extra"
266266
b"__\x94N\x8c\x17__pydantic_fields_set__\x94]\x94(h\x0bh\nh\x08h\x07e\x8c\x14"
267-
b"__pydantic_private__\x94}\x94\x8c\x0e_registrations\x94]\x94sub."
267+
b"__pydantic_private__\x94}\x94(\x8c\x0e_registrations\x94]\x94\x8c\x0c_model_"
268+
b"token\x94Nuub."
268269
)
269270
self.assertEqual(serialized, target)
270271
deserialized = pickle.loads(serialized)

0 commit comments

Comments (0)