Skip to content

Commit c3a17d6

Browse files
committed
remove duplicates when creating FunctionDefinition
1 parent abf059e commit c3a17d6

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

arraycontext/impl/pytato/outline.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,14 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
204204
unpacked_prefixed_output = _unpack_output(prefixed_output)
205205
if isinstance(unpacked_prefixed_output, pt.Array):
206206
unpacked_prefixed_output = {"_": unpacked_prefixed_output}
207+
unpacked_prefixed_output = pt.transform.Deduplicator()(
208+
pt.make_dict_of_named_arrays(unpacked_prefixed_output))
207209

208210
prefixed_placeholders = frozenset(
209211
arg_id_to_prefixed_placeholder.values())
210212

211213
found_placeholders = frozenset({
212-
arg for arg in pt.transform.InputGatherer()(
213-
pt.make_dict_of_named_arrays(unpacked_prefixed_output))
214+
arg for arg in pt.transform.InputGatherer()(unpacked_prefixed_output)
214215
if isinstance(arg, pt.Placeholder)})
215216

216217
extra_placeholders = found_placeholders - prefixed_placeholders
@@ -227,10 +228,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
227228
ret_type = pt.function.ReturnType.ARRAY
228229
else:
229230
ret_type = pt.function.ReturnType.DICT_OF_ARRAYS
231+
unpacked_output = pt.transform.Deduplicator()(
232+
pt.make_dict_of_named_arrays(unpacked_output))
230233

231234
used_placeholders = frozenset({
232235
arg for arg in pt.transform.InputGatherer()(
233-
pt.make_dict_of_named_arrays(unpacked_output))
236+
unpacked_output)
234237
if isinstance(arg, pt.Placeholder)})
235238

236239
call_bindings = {
@@ -245,7 +248,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
245248
func_def = pt.function.FunctionDefinition(
246249
parameters=frozenset(call_bindings.keys()),
247250
return_type=ret_type,
248-
returns=immutabledict(unpacked_output),
251+
returns=immutabledict(unpacked_output._data),
249252
tags=self.tags,
250253
)
251254

0 commit comments

Comments
 (0)