@@ -155,7 +155,7 @@ def __init__(
155155 self ._freeze_prg_cache : dict [pt .DictOfNamedArrays , lp .TranslationUnit ] = {}
156156 self ._dag_transform_cache : dict [
157157 pt .DictOfNamedArrays ,
158- tuple [pt .DictOfNamedArrays , str ]] = {}
158+ tuple [pt .AbstractResultWithNamedArrays , str ]] = {}
159159
160160 if compile_trace_callback is None :
161161 def _compile_trace_callback (what , stage , ir ):
@@ -177,8 +177,8 @@ def _frozen_array_types(self) -> tuple[type, ...]:
177177
178178 # {{{ compilation
179179
180- def transform_dag (self , dag : pytato .DictOfNamedArrays
181- ) -> pytato .DictOfNamedArrays :
180+ def transform_dag (self , dag : pytato .AbstractResultWithNamedArrays
181+ ) -> pytato .AbstractResultWithNamedArrays :
182182 """
183183 Returns a transformed version of *dag*. Sub-classes are supposed to
184184 override this method to implement context-specific transformations on
@@ -593,18 +593,19 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
593593 rec_keyed_map_array_container (_to_frozen , array ),
594594 actx = None )
595595
596- pt_dict_of_named_arrays = pt .make_dict_of_named_arrays (
596+ dag = pt .make_dict_of_named_arrays (
597597 key_to_pt_arrays )
598598
599- pt_dict_of_named_arrays = pt .deduplicate (pt_dict_of_named_arrays )
599+ from pytato .transform import Deduplicator
600+ dag = Deduplicator ()(dag )
600601
601602 # FIXME: Remove this if/when _normalize_pt_expr gets support for functions
602- pt_dict_of_named_arrays = pt .tag_all_calls_to_be_inlined (
603- pt_dict_of_named_arrays )
604- pt_dict_of_named_arrays = pt .inline_calls (pt_dict_of_named_arrays )
603+ dag = pt .tag_all_calls_to_be_inlined (
604+ dag )
605+ dag = pt .inline_calls (dag )
605606
606607 normalized_expr , bound_arguments = _normalize_pt_expr (
607- pt_dict_of_named_arrays )
608+ dag )
608609
609610 try :
610611 pt_prg = self ._freeze_prg_cache [normalized_expr ]
@@ -756,13 +757,13 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
756757 from .compile import LazilyPyOpenCLCompilingFunctionCaller
757758 return LazilyPyOpenCLCompilingFunctionCaller (self , f )
758759
759- def transform_dag (self , dag : pytato .DictOfNamedArrays
760- ) -> pytato .DictOfNamedArrays :
760+ def transform_dag (self , dag : pytato .AbstractResultWithNamedArrays
761+ ) -> pytato .AbstractResultWithNamedArrays :
761762 import pytato as pt
762- dag = pt .tag_all_calls_to_be_inlined (dag )
763- dag = pt .inline_calls (dag )
764- dag = pt .transform .materialize_with_mpms (dag )
765- return dag
763+ tdag = pt .tag_all_calls_to_be_inlined (dag )
764+ tdag = pt .inline_calls (tdag )
765+ tdag = pt .transform .materialize_with_mpms (tdag )
766+ return tdag
766767
767768 def einsum (self , spec , * args , arg_names = None , tagged = ()):
768769 import pytato as pt
@@ -977,8 +978,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
977978 return LazilyJAXCompilingFunctionCaller (self , f )
978979
979980 @override
980- def transform_dag (self , dag : pytato .DictOfNamedArrays
981- ) -> pytato .DictOfNamedArrays :
981+ def transform_dag (self , dag : pytato .AbstractResultWithNamedArrays
982+ ) -> pytato .AbstractResultWithNamedArrays :
982983 import pytato as pt
983984
984985 dag = pt .tag_all_calls_to_be_inlined (dag )
0 commit comments