Skip to content

Commit 1bc0a86

Browse files
author
Nijat Khanbabayev
committed
Merge branch 'main' into nk/auto_deps_auto_callable_model
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
2 parents 22c86f1 + a16f19b commit 1bc0a86

11 files changed

Lines changed: 763 additions & 252 deletions

File tree

ccflow/callable.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"FlowOptionsDeps",
6464
"FlowOptionsOverride",
6565
"ModelEvaluationContext",
66+
"TransparentModelEvaluationContext",
6667
"EvaluatorBase",
6768
"Evaluator",
6869
"WrapperModel",
@@ -215,6 +216,19 @@ def __deps__(
215216
Implementations should be decorated with Flow.call.
216217
"""
217218

219+
def _evaluation_identity_payload(self, context: Any, child_evaluation_key: Callable[[Any, Any], bytes]) -> Optional[Any]:
    """Optionally describe this model's invocation with a narrower identity payload.

    Internal hook used when deriving cache/graph keys. The default
    implementation opts out by returning ``None``, which leaves the model on
    the structural key path. Only models whose effective invocation can be
    described declaratively should override it.

    Args:
        context: The context the model is being evaluated with.
        child_evaluation_key: Callable taking two arguments and returning the
            evaluation key (bytes) for them; intended for keying child
            evaluations.

    Returns:
        ``None`` in this default implementation; overrides may return a payload.
    """
    return None
231+
218232

219233
# Type variable constrained (via ``bound``) to _CallableModel and its subclasses.
CallableModelType = TypeVar("CallableModelType", bound=_CallableModel)
220234

@@ -303,7 +317,7 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di
303317
if as_dict:
304318
return dict(model=evaluator, context=evaluation_context)
305319
else:
306-
return ModelEvaluationContext(model=evaluator, context=evaluation_context)
320+
return evaluator.make_evaluation_context(evaluation_context)
307321

308322
# The decorator implementation
309323
def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs):
@@ -588,11 +602,11 @@ def load_prices(
588602
return flow_model(*args, **kwargs)
589603

590604
@staticmethod
def context_transform(*args, **kwargs):
    """Decorate a top-level function so it becomes a serializable with_context() transform factory.

    Thin pass-through: all positional and keyword arguments are forwarded
    unchanged to ``flow_model.flow_context_transform`` and its result is returned.
    """
    # Imported lazily inside the function, as in the original implementation.
    from .flow_model import flow_context_transform as _make_context_transform

    return _make_context_transform(*args, **kwargs)
596610

597611

598612
# *****************************************************************************
@@ -698,10 +712,47 @@ def __deps__(self, context: ModelEvaluationContext) -> GraphDepList:
698712
def __exit__(self, *exc_info):
    """No-op exit hook.

    Accepts the optional ``(exc_type, exc_value, traceback)`` arguments that
    the ``with`` statement passes to ``__exit__``; the previous zero-argument
    signature would raise ``TypeError`` if invoked through that protocol.
    Calling it with no arguments keeps working.

    Returns:
        ``None`` (falsy), so an in-flight exception is never suppressed.

    NOTE(review): no ``__enter__`` is visible in this chunk - confirm whether
    the evaluator is actually used as a context manager.
    """
    return None
700714

715+
def is_transparent(self, context: ModelEvaluationContext) -> bool:
    """Report whether this evaluator leaves the evaluated result untouched for ``context``.

    A transparent evaluator may still have side effects (logging, caching,
    timing, dependency ordering) but returns exactly what ``context()``
    returns. Such layers can be skipped when computing cache keys and when
    deduplicating dependency graph nodes.

    The base implementation answers ``False``. Subclasses that never alter
    the result should override this to return ``True``; subclasses that are
    only sometimes transparent can decide based on ``context``.
    """
    return False
728+
729+
def make_evaluation_context(self, context: ModelEvaluationContext, **kwargs) -> ModelEvaluationContext:
    """Wrap ``context`` in a new evaluation-context layer whose model is this evaluator.

    Args:
        context: The evaluation context to wrap.
        **kwargs: Extra keyword arguments forwarded to the wrapper's constructor.

    Returns:
        A ``TransparentModelEvaluationContext`` when ``is_transparent(context)``
        is true (marking the layer as skippable for cache key computation),
        otherwise a plain ``ModelEvaluationContext``.
    """
    # Only the wrapper class differs between the two cases; the arguments are the same.
    wrapper_cls = TransparentModelEvaluationContext if self.is_transparent(context) else ModelEvaluationContext
    return wrapper_cls(model=self, context=context, **kwargs)
738+
739+
740+
class TransparentModelEvaluationContext(ModelEvaluationContext):
    """Marker subclass for evaluation-context layers that may be skipped in cache keys.

    ``EvaluatorBase.make_evaluation_context()`` produces this type when the
    evaluator's ``is_transparent()`` returns ``True``. The class adds no fields
    or behavior; its type alone signals that the evaluator layer does not
    change the returned value, so it can be ignored when computing cache keys
    or deduplicating dependency graph nodes.
    """
748+
701749

702750
class Evaluator(EvaluatorBase):
    """Pass-through evaluator: evaluating it simply invokes the wrapped evaluation context."""

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Always ``True``: ``__call__`` returns ``context()`` unchanged."""
        return True

    @override
    def __call__(self, context: ModelEvaluationContext) -> ResultType:
        """Invoke ``context`` and hand back its result as-is."""
        result = context()
        return result

ccflow/evaluators/common.py

Lines changed: 142 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88
from typing import Any, Callable, Dict, List, Optional, Set, Union
99

1010
import dask.base
11-
from pydantic import Field, PrivateAttr, field_validator
11+
from pydantic import Field, PrivateAttr, ValidationError, field_validator
1212
from typing_extensions import override
1313

1414
from ..base import BaseModel, make_lazy_result
15-
from ..callable import CallableModel, ContextBase, EvaluatorBase, ModelEvaluationContext, ResultType
15+
from ..callable import (
16+
CallableModel,
17+
ContextBase,
18+
EvaluatorBase,
19+
ModelEvaluationContext,
20+
ResultType,
21+
TransparentModelEvaluationContext,
22+
)
1623

1724
__all__ = [
1825
"cache_key",
@@ -30,6 +37,13 @@
3037
log = logging.getLogger(__name__)
3138

3239

40+
class _EvaluationIdentityFallback(Exception):
    """Private control-flow exception: abandon the narrowed identity and use the structural evaluation key."""
42+
43+
44+
# Exception types that _EvaluationKeyBuilder converts into _EvaluationIdentityFallback
# instead of letting them propagate while deriving an identity.
_EXPECTED_IDENTITY_FAILURES = (TypeError, ValueError, ValidationError)
45+
46+
3347
def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[EvaluatorBase]) -> EvaluatorBase:
3448
"""Helper function to combine evaluators into a new evaluator.
3549
@@ -53,16 +67,25 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato
5367

5468

5569
class MultiEvaluator(EvaluatorBase):
    """Evaluator that chains several child evaluators.

    On evaluation, each child wraps the running evaluation context through its
    own ``make_evaluation_context()``. Children that report themselves as
    transparent therefore yield ``TransparentModelEvaluationContext`` layers,
    which can be skipped during cache key computation.
    """

    evaluators: List[EvaluatorBase] = Field(
        description="The list of evaluators to combine. The first evaluator in the list will be called first during evaluation."
    )

    def is_transparent(self, context: ModelEvaluationContext) -> bool:
        """Transparent only if every child evaluator is transparent for ``context``."""
        for child in self.evaluators:
            if not child.is_transparent(context):
                return False
        return True

    @override
    def __call__(self, context: ModelEvaluationContext) -> ResultType:
        """Wrap ``context`` with each child in list order, then invoke the resulting outermost layer."""
        wrapped = context
        for child in self.evaluators:
            # Each new layer receives the options of the layer it wraps.
            wrapped = child.make_evaluation_context(wrapped, options=wrapped.options)
        return wrapped()
6790

6891

@@ -71,6 +94,9 @@ class FallbackEvaluator(EvaluatorBase):
7194

7295
evaluators: List[EvaluatorBase] = Field(description="The list of evaluators to try (in order).")
7396

97+
def is_transparent(self, context: ModelEvaluationContext) -> bool:
    """Transparent only if every evaluator in ``self.evaluators`` is transparent for ``context``."""
    for candidate in self.evaluators:
        if not candidate.is_transparent(context):
            return False
    return True
99+
74100
@override
75101
def __call__(self, context: ModelEvaluationContext) -> ResultType:
76102
for evaluator in self.evaluators:
@@ -120,6 +146,9 @@ class LoggingEvaluator(EvaluatorBase):
120146
log_result: bool = Field(False, description="Whether to log the result of the evaluation")
121147
format_config: FormatConfig = Field(FormatConfig(), description="Configuration for formatting the result of the evaluation if log_result=True")
122148

149+
def is_transparent(self, context: ModelEvaluationContext) -> bool:
    """Always ``True``: this evaluator declares itself transparent for every context."""
    return True
151+
123152
@field_validator("log_level", mode="before")
124153
@classmethod
125154
def _validate_log_level(cls, v: Union[int, str]) -> int:
@@ -194,13 +223,105 @@ def _format_result(self, result: ResultType) -> str:
194223
return f"{msg_str}{pformat(result_dict, **self.format_config.pformat_config)}"
195224

196225

226+
def _unwrap_evaluation_context(flow_obj: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[CallableModel]]:
    """Peel nested evaluation-context layers down to the innermost one.

    Walks ``.context`` for as long as it is itself a ``ModelEvaluationContext``.

    Returns:
        A 3-tuple of:
        * the innermost evaluation context (the one whose ``context`` is not
          a ``ModelEvaluationContext``);
        * the effective ``fn`` name: the ``fn`` of the deepest wrapper layer
          whose ``fn`` is not ``"__call__"``, or the innermost layer's ``fn``
          when every wrapper uses ``"__call__"``;
        * the ``model`` of every wrapper layer that is not a
          ``TransparentModelEvaluationContext``, outermost first.
    """
    layer = flow_obj
    effective_fn = layer.fn
    opaque_models = []
    while isinstance(layer.context, ModelEvaluationContext):
        if layer.fn != "__call__":
            effective_fn = layer.fn
        if not isinstance(layer, TransparentModelEvaluationContext):
            opaque_models.append(layer.model)
        layer = layer.context
    # No wrapper named a specific fn: defer to the innermost layer's fn.
    if effective_fn == "__call__":
        effective_fn = layer.fn
    return layer, effective_fn, opaque_models
235+
236+
237+
def _structural_evaluation_key(flow_obj: ModelEvaluationContext) -> bytes:
    """Tokenize an evaluation context structurally, ignoring transparent evaluator layers.

    The key is built from the ``model_dump`` of the innermost evaluation
    context, with ``fn`` replaced by the effective fn name and, when any
    non-transparent evaluator layers exist, their dumps added under
    ``"_evaluators"``.

    Returns:
        The dask token of that dictionary, UTF-8 encoded.
    """
    innermost, fn_name, opaque_models = _unwrap_evaluation_context(flow_obj)
    payload = innermost.model_dump(mode="python")
    payload["fn"] = fn_name
    if opaque_models:
        payload["_evaluators"] = [evaluator.model_dump(mode="python") for evaluator in opaque_models]
    token = dask.base.tokenize(payload)
    return token.encode("utf-8")
244+
245+
246+
class _EvaluationKeyBuilder:
    """Derives evaluation keys, preferring a model-supplied identity payload.

    Whenever a narrowed identity cannot be derived, ``build`` falls back to
    ``_structural_evaluation_key``. Each instance keeps a memo of finished keys
    and a set of in-progress entries used to detect recursive cycles.
    """

    def __init__(self) -> None:
        # Finished keys, indexed by the token from _memo_token().
        self._memo: Dict[tuple[int, str], bytes] = {}
        # Tokens whose key derivation is currently on the call stack.
        self._active: set[tuple[int, str]] = set()

    def build(self, context: ModelEvaluationContext) -> bytes:
        """Return the evaluation key for ``context``, using the structural key on fallback."""
        try:
            narrowed = self._build(context)
        except _EvaluationIdentityFallback:
            return _structural_evaluation_key(context)
        return narrowed

    def _build(self, context: ModelEvaluationContext) -> bytes:
        """Derive the narrowed key or raise ``_EvaluationIdentityFallback``.

        Falls back when the effective fn is not ``"__call__"``, when any
        non-transparent evaluator layer is present, or when the innermost
        options carry ``validate_result=False``.
        """
        innermost, fn_name, opaque_models = _unwrap_evaluation_context(context)
        if fn_name != "__call__":
            raise _EvaluationIdentityFallback("Only __call__ evaluations support narrowed identity.")
        if opaque_models:
            raise _EvaluationIdentityFallback("Non-transparent evaluator layers stay on the structural key path.")
        validate_result = innermost.options.get("validate_result", True)
        if validate_result is False:
            raise _EvaluationIdentityFallback("validate_result=False stays on the structural key path.")
        return self._key_for_model(innermost.model, innermost.context)

    def _key_for_model(self, model: CallableModel, context: Any) -> bytes:
        """Return the memoized or freshly computed narrowed key for ``(model, context)``.

        Raises:
            _EvaluationIdentityFallback: if the pair is already being derived
                (recursive cycle) or the model supplies no payload.
        """
        token = self._memo_token(model, context)
        if token in self._memo:
            return self._memo[token]
        if token in self._active:
            raise _EvaluationIdentityFallback("Recursive cycle detected while deriving evaluation identity.")

        self._active.add(token)
        try:
            payload = self._identity_payload(model, context)
            if payload is None:
                raise _EvaluationIdentityFallback("Model did not provide a narrowed identity payload.")
            versioned = ("ccflow_evaluation_identity_v1", payload)
            key = dask.base.tokenize(versioned).encode("utf-8")
            self._memo[token] = key
            return key
        finally:
            # Always clear the in-progress marker, on success or failure.
            self._active.discard(token)

    def _identity_payload(self, model: CallableModel, context: Any) -> Optional[Any]:
        """Ask ``model`` for its identity payload, mapping expected failures to a fallback."""
        try:
            return model._evaluation_identity_payload(context, self._child_evaluation_key)
        except _EXPECTED_IDENTITY_FAILURES as exc:
            raise _EvaluationIdentityFallback(str(exc)) from exc

    def _child_evaluation_key(self, model: CallableModel, context: Any) -> bytes:
        """Build the evaluation key for a child ``(model, context)`` with this same builder."""
        try:
            child_evaluation = model.__call__.get_evaluation_context(model, context)
        except _EXPECTED_IDENTITY_FAILURES as exc:
            raise _EvaluationIdentityFallback(str(exc)) from exc
        return self.build(child_evaluation)

    def _memo_token(self, model: CallableModel, context: Any) -> tuple[int, str]:
        """Return the memo key: ``id(model)`` paired with a dask token of the context's type and value."""
        context_value = context.model_dump(mode="python") if hasattr(context, "model_dump") else context
        context_token = dask.base.tokenize((type(context), context_value))
        return (id(model), context_token)
305+
306+
307+
def _evaluation_key(flow_obj: ModelEvaluationContext) -> bytes:
    """Compute the evaluation key for ``flow_obj`` with a fresh ``_EvaluationKeyBuilder``."""
    builder = _EvaluationKeyBuilder()
    return builder.build(flow_obj)
309+
310+
197311
def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes:
    """Return a structural key for caching and dependency graph deduplication.

    A ``ModelEvaluationContext`` is keyed through ``_structural_evaluation_key``,
    which strips ``TransparentModelEvaluationContext`` layers so the key
    reflects the innermost evaluation context, the effective fn and any
    non-transparent evaluators in the chain. A ``ContextBase`` or
    ``CallableModel`` is keyed by tokenizing its ``model_dump``.

    Args:
        flow_obj: The object to be tokenized to form the cache key.

    Raises:
        TypeError: If ``flow_obj`` is none of the supported types.
    """
    # Checked first, as in the original branch order.
    if isinstance(flow_obj, ModelEvaluationContext):
        return _structural_evaluation_key(flow_obj)
    if isinstance(flow_obj, (ContextBase, CallableModel)):
        dumped = flow_obj.model_dump(mode="python")
        return dask.base.tokenize(dumped).encode("utf-8")
    raise TypeError(f"object of type {type(flow_obj)} cannot be serialized by this function!")
@@ -213,9 +334,17 @@ class MemoryCacheEvaluator(EvaluatorBase):
213334
_cache: Dict[bytes, ResultType] = PrivateAttr({})
214335
_ids: Dict[bytes, ModelEvaluationContext] = PrivateAttr({})
215336

337+
def is_transparent(self, context: ModelEvaluationContext) -> bool:
    """Always ``True``: this evaluator declares itself transparent for every context."""
    return True
339+
216340
def key(self, context: ModelEvaluationContext):
    """Convert a ModelEvaluationContext into its cache key.

    Delegates to ``_evaluation_key``, which uses a model-provided identity
    payload when one is available and otherwise falls back to the structural
    key that ``cache_key()`` produces for evaluation contexts.
    """
    cache_id = _evaluation_key(context)
    return cache_id
219348

220349
@property
221350
def cache(self):
@@ -252,7 +381,7 @@ class CallableModelGraph(BaseModel):
252381

253382

254383
def _build_dependency_graph(evaluation_context: ModelEvaluationContext, graph: CallableModelGraph, parent_key: Optional[bytes] = None):
255-
key = cache_key(evaluation_context)
384+
key = _evaluation_key(evaluation_context)
256385
if parent_key:
257386
graph.graph[parent_key].add(key)
258387
if key not in graph.ids:
@@ -275,7 +404,7 @@ def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> Callable
275404
Args:
276405
evaluation_context: The model and context to build the graph for.
277406
"""
278-
root_key = cache_key(evaluation_context)
407+
root_key = _evaluation_key(evaluation_context)
279408
graph = CallableModelGraph(ids={}, graph={}, root_id=root_key)
280409
_build_dependency_graph(evaluation_context, graph)
281410
return graph
@@ -289,6 +418,9 @@ class GraphEvaluator(EvaluatorBase):
289418

290419
_is_evaluating: bool = PrivateAttr(False)
291420

421+
def is_transparent(self, context: ModelEvaluationContext) -> bool:
    """Always ``True``: this evaluator declares itself transparent for every context."""
    return True
423+
292424
@override
293425
def __call__(self, context: ModelEvaluationContext) -> ResultType:
294426
import graphlib

0 commit comments

Comments
 (0)