88from typing import Any , Callable , Dict , List , Optional , Set , Union
99
1010import dask .base
11- from pydantic import Field , PrivateAttr , field_validator
11+ from pydantic import Field , PrivateAttr , ValidationError , field_validator
1212from typing_extensions import override
1313
1414from ..base import BaseModel , make_lazy_result
15- from ..callable import CallableModel , ContextBase , EvaluatorBase , ModelEvaluationContext , ResultType
15+ from ..callable import (
16+ CallableModel ,
17+ ContextBase ,
18+ EvaluatorBase ,
19+ ModelEvaluationContext ,
20+ ResultType ,
21+ TransparentModelEvaluationContext ,
22+ )
1623
1724__all__ = [
1825 "cache_key" ,
3037log = logging .getLogger (__name__ )
3138
3239
40+ class _EvaluationIdentityFallback (Exception ):
41+ """Internal signal to stay on the structural evaluation-key path."""
42+
43+
44+ _EXPECTED_IDENTITY_FAILURES = (TypeError , ValueError , ValidationError )
45+
46+
3347def combine_evaluators (first : Optional [EvaluatorBase ], second : Optional [EvaluatorBase ]) -> EvaluatorBase :
3448 """Helper function to combine evaluators into a new evaluator.
3549
@@ -53,16 +67,25 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato
5367
5468
5569class MultiEvaluator (EvaluatorBase ):
56- """An evaluator that combines multiple evaluators."""
70+ """An evaluator that combines multiple evaluators.
71+
72+ Each child evaluator is wrapped in a ModelEvaluationContext using its own
73+ ``make_evaluation_context()`` method, so transparent children produce
74+ ``TransparentModelEvaluationContext`` layers that can be skipped during
75+ cache key computation.
76+ """
5777
5878 evaluators : List [EvaluatorBase ] = Field (
5979 description = "The list of evaluators to combine. The first evaluator in the list will be called first during evaluation."
6080 )
6181
82+ def is_transparent (self , context : ModelEvaluationContext ) -> bool :
83+ return all (e .is_transparent (context ) for e in self .evaluators )
84+
6285 @override
6386 def __call__ (self , context : ModelEvaluationContext ) -> ResultType :
6487 for evaluator in self .evaluators :
65- context = ModelEvaluationContext ( model = evaluator , context = context , options = context .options )
88+ context = evaluator . make_evaluation_context ( context , options = context .options )
6689 return context ()
6790
6891
@@ -71,6 +94,9 @@ class FallbackEvaluator(EvaluatorBase):
7194
7295 evaluators : List [EvaluatorBase ] = Field (description = "The list of evaluators to try (in order)." )
7396
97+ def is_transparent (self , context : ModelEvaluationContext ) -> bool :
98+ return all (e .is_transparent (context ) for e in self .evaluators )
99+
74100 @override
75101 def __call__ (self , context : ModelEvaluationContext ) -> ResultType :
76102 for evaluator in self .evaluators :
@@ -120,6 +146,9 @@ class LoggingEvaluator(EvaluatorBase):
120146 log_result : bool = Field (False , description = "Whether to log the result of the evaluation" )
121147 format_config : FormatConfig = Field (FormatConfig (), description = "Configuration for formatting the result of the evaluation if log_result=True" )
122148
149+ def is_transparent (self , context : ModelEvaluationContext ) -> bool :
150+ return True
151+
123152 @field_validator ("log_level" , mode = "before" )
124153 @classmethod
125154 def _validate_log_level (cls , v : Union [int , str ]) -> int :
@@ -194,13 +223,105 @@ def _format_result(self, result: ResultType) -> str:
194223 return f"{ msg_str } { pformat (result_dict , ** self .format_config .pformat_config )} "
195224
196225
226+ def _unwrap_evaluation_context (flow_obj : ModelEvaluationContext ) -> tuple [ModelEvaluationContext , str , List [CallableModel ]]:
227+ fn = flow_obj .fn
228+ non_transparent = []
229+ while isinstance (flow_obj .context , ModelEvaluationContext ):
230+ fn = flow_obj .fn if flow_obj .fn != "__call__" else fn
231+ if not isinstance (flow_obj , TransparentModelEvaluationContext ):
232+ non_transparent .append (flow_obj .model )
233+ flow_obj = flow_obj .context
234+ return flow_obj , fn if fn != "__call__" else flow_obj .fn , non_transparent
235+
236+
237+ def _structural_evaluation_key (flow_obj : ModelEvaluationContext ) -> bytes :
238+ flow_obj , fn , non_transparent = _unwrap_evaluation_context (flow_obj )
239+ d = flow_obj .model_dump (mode = "python" )
240+ d ["fn" ] = fn
241+ if non_transparent :
242+ d ["_evaluators" ] = [e .model_dump (mode = "python" ) for e in non_transparent ]
243+ return dask .base .tokenize (d ).encode ("utf-8" )
244+
245+
246+ class _EvaluationKeyBuilder :
247+ def __init__ (self ) -> None :
248+ self ._memo : Dict [tuple [int , str ], bytes ] = {}
249+ self ._active : set [tuple [int , str ]] = set ()
250+
251+ def build (self , context : ModelEvaluationContext ) -> bytes :
252+ try :
253+ return self ._build (context )
254+ except _EvaluationIdentityFallback :
255+ return _structural_evaluation_key (context )
256+
257+ def _build (self , context : ModelEvaluationContext ) -> bytes :
258+ inner , fn , non_transparent = _unwrap_evaluation_context (context )
259+ if fn != "__call__" :
260+ raise _EvaluationIdentityFallback ("Only __call__ evaluations support narrowed identity." )
261+ if non_transparent :
262+ raise _EvaluationIdentityFallback ("Non-transparent evaluator layers stay on the structural key path." )
263+ if inner .options .get ("validate_result" , True ) is False :
264+ raise _EvaluationIdentityFallback ("validate_result=False stays on the structural key path." )
265+ return self ._key_for_model (inner .model , inner .context )
266+
267+ def _key_for_model (self , model : CallableModel , context : Any ) -> bytes :
268+ memo_token = self ._memo_token (model , context )
269+ cached = self ._memo .get (memo_token )
270+ if cached is not None :
271+ return cached
272+ if memo_token in self ._active :
273+ raise _EvaluationIdentityFallback ("Recursive cycle detected while deriving evaluation identity." )
274+
275+ self ._active .add (memo_token )
276+ try :
277+ payload = self ._identity_payload (model , context )
278+ if payload is None :
279+ raise _EvaluationIdentityFallback ("Model did not provide a narrowed identity payload." )
280+ key = dask .base .tokenize (("ccflow_evaluation_identity_v1" , payload )).encode ("utf-8" )
281+ self ._memo [memo_token ] = key
282+ return key
283+ finally :
284+ self ._active .discard (memo_token )
285+
286+ def _identity_payload (self , model : CallableModel , context : Any ) -> Optional [Any ]:
287+ try :
288+ return model ._evaluation_identity_payload (context , self ._child_evaluation_key )
289+ except _EXPECTED_IDENTITY_FAILURES as exc :
290+ raise _EvaluationIdentityFallback (str (exc )) from exc
291+
292+ def _child_evaluation_key (self , model : CallableModel , context : Any ) -> bytes :
293+ try :
294+ evaluation = model .__call__ .get_evaluation_context (model , context )
295+ except _EXPECTED_IDENTITY_FAILURES as exc :
296+ raise _EvaluationIdentityFallback (str (exc )) from exc
297+ return self .build (evaluation )
298+
299+ def _memo_token (self , model : CallableModel , context : Any ) -> tuple [int , str ]:
300+ if hasattr (context , "model_dump" ):
301+ context_value = context .model_dump (mode = "python" )
302+ else :
303+ context_value = context
304+ return (id (model ), dask .base .tokenize ((type (context ), context_value )))
305+
306+
307+ def _evaluation_key (flow_obj : ModelEvaluationContext ) -> bytes :
308+ return _EvaluationKeyBuilder ().build (flow_obj )
309+
310+
197311def cache_key (flow_obj : Union [ModelEvaluationContext , ContextBase , CallableModel ]) -> bytes :
198- """Returns a key suitable for use in caching.
312+ """Returns a structural key suitable for caching and dependency graph deduplication.
313+
314+ For ``ModelEvaluationContext`` inputs, strips ``TransparentModelEvaluationContext``
315+ layers (evaluators that don't modify the return value) so that the key depends
316+ only on the underlying model, context, fn, options, and any non-transparent
317+ evaluators in the chain.
199318
200319 Args:
201320 flow_obj: The object to be tokenized to form the cache key.
202321 """
203- if isinstance (flow_obj , (ModelEvaluationContext , ContextBase , CallableModel )):
322+ if isinstance (flow_obj , ModelEvaluationContext ):
323+ return _structural_evaluation_key (flow_obj )
324+ elif isinstance (flow_obj , (ContextBase , CallableModel )):
204325 return dask .base .tokenize (flow_obj .model_dump (mode = "python" )).encode ("utf-8" )
205326 else :
206327 raise TypeError (f"object of type { type (flow_obj )} cannot be serialized by this function!" )
@@ -213,9 +334,17 @@ class MemoryCacheEvaluator(EvaluatorBase):
213334 _cache : Dict [bytes , ResultType ] = PrivateAttr ({})
214335 _ids : Dict [bytes , ModelEvaluationContext ] = PrivateAttr ({})
215336
337+ def is_transparent (self , context : ModelEvaluationContext ) -> bool :
338+ return True
339+
216340 def key (self , context : ModelEvaluationContext ):
217- """Function to convert a ModelEvaluationContext to a key"""
218- return cache_key (context )
341+ """Function to convert a ModelEvaluationContext to a cache key.
342+
343+ Delegates to the shared evaluation-key builder, which narrows generated
344+ ``@Flow.model`` identities when safe and otherwise falls back to
345+ ``cache_key()`` semantics.
346+ """
347+ return _evaluation_key (context )
219348
220349 @property
221350 def cache (self ):
@@ -252,7 +381,7 @@ class CallableModelGraph(BaseModel):
252381
253382
254383def _build_dependency_graph (evaluation_context : ModelEvaluationContext , graph : CallableModelGraph , parent_key : Optional [bytes ] = None ):
255- key = cache_key (evaluation_context )
384+ key = _evaluation_key (evaluation_context )
256385 if parent_key :
257386 graph .graph [parent_key ].add (key )
258387 if key not in graph .ids :
@@ -275,7 +404,7 @@ def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> Callable
275404 Args:
276405 evaluation_context: The model and context to build the graph for.
277406 """
278- root_key = cache_key (evaluation_context )
407+ root_key = _evaluation_key (evaluation_context )
279408 graph = CallableModelGraph (ids = {}, graph = {}, root_id = root_key )
280409 _build_dependency_graph (evaluation_context , graph )
281410 return graph
@@ -289,6 +418,9 @@ class GraphEvaluator(EvaluatorBase):
289418
290419 _is_evaluating : bool = PrivateAttr (False )
291420
421+ def is_transparent (self , context : ModelEvaluationContext ) -> bool :
422+ return True
423+
292424 @override
293425 def __call__ (self , context : ModelEvaluationContext ) -> ResultType :
294426 import graphlib
0 commit comments