Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions ccflow/evaluators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Set, Union

import dask.base
from pydantic import Field, PrivateAttr, field_validator
from typing_extensions import override

Expand Down Expand Up @@ -59,6 +58,17 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato
return MultiEvaluator(evaluators=[first, second])


def _flatten_cache_key_context(flow_obj: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[EvaluatorBase]]:
    """Descend through a chain of nested ModelEvaluationContexts to the innermost one.

    Walking from the outermost context inward (excluding the innermost), the last
    context whose ``fn`` is not ``"__call__"`` determines the effective function
    name; if every such ``fn`` is ``"__call__"``, the innermost context's own
    ``fn`` is used. Along the way, the model of every context that is *not* a
    TransparentModelEvaluationContext is collected, in outer-to-inner order.

    Returns:
        A triple of (innermost context, effective fn name, non-transparent models).
    """
    effective_fn = flow_obj.fn
    collected: List[EvaluatorBase] = []
    current = flow_obj
    while isinstance(current.context, ModelEvaluationContext):
        # A named fn overrides the generic "__call__"; otherwise keep what we had.
        if current.fn != "__call__":
            effective_fn = current.fn
        # Transparent wrappers are excluded from the cache key entirely.
        if not isinstance(current, TransparentModelEvaluationContext):
            collected.append(current.model)
        current = current.context
    if effective_fn == "__call__":
        effective_fn = current.fn
    return current, effective_fn, collected


class MultiEvaluator(EvaluatorBase):
"""An evaluator that combines multiple evaluators.

Expand Down Expand Up @@ -224,24 +234,28 @@ def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel
only on the underlying model, context, fn, options, and any non-transparent
evaluators in the chain.

When the underlying model has callable methods, a behavior token (SHA-256 of
method bytecode) is included so that code changes invalidate the cache.

Args:
flow_obj: The object to be tokenized to form the cache key.
"""
from ..utils.tokenize import compute_cache_token

if isinstance(flow_obj, ModelEvaluationContext):
fn = flow_obj.fn
non_transparent = []
while isinstance(flow_obj.context, ModelEvaluationContext):
fn = flow_obj.fn if flow_obj.fn != "__call__" else fn
if not isinstance(flow_obj, TransparentModelEvaluationContext):
non_transparent.append(flow_obj.model)
flow_obj = flow_obj.context
d = flow_obj.model_dump(mode="python")
d["fn"] = fn if fn != "__call__" else flow_obj.fn
if non_transparent:
d["_evaluators"] = [e.model_dump(mode="python") for e in non_transparent]
return dask.base.tokenize(d).encode("utf-8")
flow_obj, fn, non_transparent = _flatten_cache_key_context(flow_obj)
return compute_cache_token(
data_values=[
{**flow_obj.model_dump(mode="python"), "fn": fn},
*(evaluator.model_dump(mode="python") for evaluator in non_transparent),
],
behavior_classes=[type(flow_obj.model), *(type(evaluator) for evaluator in non_transparent)],
).encode("utf-8")
elif isinstance(flow_obj, (ContextBase, CallableModel)):
return dask.base.tokenize(flow_obj.model_dump(mode="python")).encode("utf-8")
return compute_cache_token(
data_values=[flow_obj.model_dump(mode="python")],
behavior_classes=[type(flow_obj)],
).encode("utf-8")
else:
raise TypeError(f"object of type {type(flow_obj)} cannot be serialized by this function!")

Expand Down
Loading
Loading