3030from typing import Any , Callable , Dict , FrozenSet , Mapping , Optional , Tuple , Union
3131
3232import numpy as np
33- from immutables import Map
33+ from immutabledict import immutabledict
3434
3535import pytato as pt
3636from pytools .tag import Tag
4343
4444def _get_arg_id_to_arg (args : Tuple [Any , ...],
4545 kwargs : Mapping [str , Any ]
46- ) -> Map [Tuple [Any , ...], Any ]:
46+ ) -> immutabledict [Tuple [Any , ...], Any ]:
4747 """
4848 Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
4949 to argument values. See
@@ -77,7 +77,7 @@ def id_collector(keys, ary):
7777 " either a scalar, pt.Array or an array container. Got"
7878 f" '{ arg } '." )
7979
80- return Map (arg_id_to_arg )
80+ return immutabledict (arg_id_to_arg )
8181
8282
8383def _get_input_arg_id_str (
@@ -95,14 +95,15 @@ def _get_output_arg_id_str(arg_id: Tuple[Any, ...]) -> str:
9595
9696def _get_arg_id_to_placeholder (
9797 arg_id_to_arg : Mapping [Tuple [Any , ...], Any ],
98- prefix : Optional [str ] = None ) -> Map [Tuple [Any , ...], pt .Placeholder ]:
98+ prefix : Optional [str ] = None
99+ ) -> immutabledict [Tuple [Any , ...], pt .Placeholder ]:
99100 """
100101 Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
101102 for each argument in *arg_id_to_arg*. See
102103 :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
103104 representation.
104105 """
105- return Map ({
106+ return immutabledict ({
106107 arg_id : pt .make_placeholder (
107108 _get_input_arg_id_str (arg_id , prefix = prefix ),
108109 arg .shape ,
@@ -244,7 +245,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
244245 func_def = pt .function .FunctionDefinition (
245246 parameters = frozenset (call_bindings .keys ()),
246247 return_type = ret_type ,
247- returns = Map (unpacked_output ),
248+ returns = immutabledict (unpacked_output ),
248249 tags = self .tags ,
249250 )
250251
0 commit comments