6161from pytato .tags import ImplStored
6262from pytato .transform import (
6363 ArrayOrNames ,
64+ ArrayOrNamesOrFunctionDef ,
6465 ArrayOrNamesTc ,
6566 CachedWalkMapper ,
6667 CombineMapper ,
@@ -362,7 +363,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
362363
363364class ListOfDirectPredecessorsGetter (
364365 Mapper [
365- list [ArrayOrNames | FunctionDefinition ],
366+ list [ArrayOrNamesOrFunctionDef ],
366367 list [ArrayOrNames ],
367368 []]):
368369 """
@@ -445,8 +446,8 @@ def map_distributed_send_ref_holder(self,
445446 return [expr .send .data , expr .passthrough_data ]
446447
447448 def map_call (
448- self , expr : Call ) -> list [ArrayOrNames | FunctionDefinition ]:
449- result : list [ArrayOrNames | FunctionDefinition ] = []
449+ self , expr : Call ) -> list [ArrayOrNamesOrFunctionDef ]:
450+ result : list [ArrayOrNamesOrFunctionDef ] = []
450451 if self .include_functions :
451452 result .append (expr .function )
452453 result += list (expr .bindings .values ())
@@ -483,7 +484,7 @@ def __init__(self, *, include_functions: bool = False) -> None:
483484 @overload
484485 def __call__ (
485486 self , expr : ArrayOrNames
486- ) -> FrozenOrderedSet [ArrayOrNames | FunctionDefinition ]:
487+ ) -> FrozenOrderedSet [ArrayOrNamesOrFunctionDef ]:
487488 ...
488489
489490 @overload
@@ -492,9 +493,9 @@ def __call__(self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
492493
493494 def __call__ (
494495 self ,
495- expr : ArrayOrNames | FunctionDefinition ,
496+ expr : ArrayOrNamesOrFunctionDef ,
496497 ) -> (
497- FrozenOrderedSet [ArrayOrNames | FunctionDefinition ]
498+ FrozenOrderedSet [ArrayOrNamesOrFunctionDef ]
498499 | FrozenOrderedSet [ArrayOrNames ]):
499500 """Get the direct predecessors of *expr*."""
500501 return FrozenOrderedSet (self ._pred_getter (expr ))
@@ -543,7 +544,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self:
543544 _visited_functions = self ._visited_functions )
544545
545546 @override
546- def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
547+ def post_visit (self , expr : ArrayOrNamesOrFunctionDef ) -> None :
547548 if not isinstance (expr , DictOfNamedArrays ):
548549 self .expr_type_counts [type (expr )] += 1
549550
@@ -606,7 +607,7 @@ def __init__(self, _visited_functions: set[Any] | None = None) -> None:
606607 super ().__init__ (_visited_functions = _visited_functions )
607608
608609 self .expr_multiplicity_counts : \
609- dict [ArrayOrNames | FunctionDefinition , int ] = defaultdict (int )
610+ dict [ArrayOrNamesOrFunctionDef , int ] = defaultdict (int )
610611
611612 @override
612613 def get_cache_key (self , expr : ArrayOrNames ) -> int :
@@ -619,13 +620,13 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
619620 return id (expr )
620621
621622 @override
622- def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
623+ def post_visit (self , expr : ArrayOrNamesOrFunctionDef ) -> None :
623624 if not isinstance (expr , DictOfNamedArrays ):
624625 self .expr_multiplicity_counts [expr ] += 1
625626
626627
627628def get_node_multiplicities (
628- outputs : ArrayOrNames ) -> dict [ArrayOrNames | FunctionDefinition , int ]:
629+ outputs : ArrayOrNames ) -> dict [ArrayOrNamesOrFunctionDef , int ]:
629630 """
630631 Returns the multiplicity per `expr`.
631632 """
@@ -662,7 +663,7 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
662663 return id (expr )
663664
664665 @override
665- def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
666+ def post_visit (self , expr : ArrayOrNamesOrFunctionDef ) -> None :
666667 if isinstance (expr , Call ):
667668 self .count += 1
668669
@@ -884,7 +885,7 @@ def map_call(self, expr: Call) -> None:
884885 f"{ type (self ).__name__ } does not support functions." )
885886
886887 @override
887- def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
888+ def post_visit (self , expr : ArrayOrNamesOrFunctionDef ) -> None :
888889 if not is_materialized (expr ):
889890 return
890891 assert isinstance (expr , Array )
@@ -978,7 +979,7 @@ def map_call(self, expr: Call) -> None:
978979 f"{ type (self ).__name__ } does not support functions." )
979980
980981 @override
981- def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
982+ def post_visit (self , expr : ArrayOrNamesOrFunctionDef ) -> None :
982983 if not is_materialized (expr ) or not has_taggable_materialization (expr ):
983984 return
984985 assert isinstance (expr , Array )
0 commit comments