Skip to content

Commit 62e3586

Browse files
committed
Better transform typing
1 parent befb2d3 commit 62e3586

4 files changed

Lines changed: 53 additions & 41 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(
155155
self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
156156
self._dag_transform_cache: dict[
157157
pt.DictOfNamedArrays,
158-
tuple[pt.DictOfNamedArrays, str]] = {}
158+
tuple[pt.AbstractResultWithNamedArrays, str]] = {}
159159

160160
if compile_trace_callback is None:
161161
def _compile_trace_callback(what, stage, ir):
@@ -177,8 +177,8 @@ def _frozen_array_types(self) -> tuple[type, ...]:
177177

178178
# {{{ compilation
179179

180-
def transform_dag(self, dag: pytato.DictOfNamedArrays
181-
) -> pytato.DictOfNamedArrays:
180+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
181+
) -> pytato.AbstractResultWithNamedArrays:
182182
"""
183183
Returns a transformed version of *dag*. Sub-classes are supposed to
184184
override this method to implement context-specific transformations on
@@ -593,18 +593,19 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
593593
rec_keyed_map_array_container(_to_frozen, array),
594594
actx=None)
595595

596-
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
596+
dag = pt.make_dict_of_named_arrays(
597597
key_to_pt_arrays)
598598

599-
pt_dict_of_named_arrays = pt.deduplicate(pt_dict_of_named_arrays)
599+
from pytato.transform import Deduplicator
600+
dag = Deduplicator()(dag)
600601

601602
# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
602-
pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined(
603-
pt_dict_of_named_arrays)
604-
pt_dict_of_named_arrays = pt.inline_calls(pt_dict_of_named_arrays)
603+
dag = pt.tag_all_calls_to_be_inlined(
604+
dag)
605+
dag = pt.inline_calls(dag)
605606

606607
normalized_expr, bound_arguments = _normalize_pt_expr(
607-
pt_dict_of_named_arrays)
608+
dag)
608609

609610
try:
610611
pt_prg = self._freeze_prg_cache[normalized_expr]
@@ -756,13 +757,13 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
756757
from .compile import LazilyPyOpenCLCompilingFunctionCaller
757758
return LazilyPyOpenCLCompilingFunctionCaller(self, f)
758759

759-
def transform_dag(self, dag: pytato.DictOfNamedArrays
760-
) -> pytato.DictOfNamedArrays:
760+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
761+
) -> pytato.AbstractResultWithNamedArrays:
761762
import pytato as pt
762-
dag = pt.tag_all_calls_to_be_inlined(dag)
763-
dag = pt.inline_calls(dag)
764-
dag = pt.transform.materialize_with_mpms(dag)
765-
return dag
763+
tdag = pt.tag_all_calls_to_be_inlined(dag)
764+
tdag = pt.inline_calls(tdag)
765+
tdag = pt.transform.materialize_with_mpms(tdag)
766+
return tdag
766767

767768
def einsum(self, spec, *args, arg_names=None, tagged=()):
768769
import pytato as pt
@@ -977,8 +978,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
977978
return LazilyJAXCompilingFunctionCaller(self, f)
978979

979980
@override
980-
def transform_dag(self, dag: pytato.DictOfNamedArrays
981-
) -> pytato.DictOfNamedArrays:
981+
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
982+
) -> pytato.AbstractResultWithNamedArrays:
982983
import pytato as pt
983984

984985
dag = pt.tag_all_calls_to_be_inlined(dag)

arraycontext/impl/pytato/compile.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
"""
88
from __future__ import annotations
99

10+
from pytato.array import AxesT
11+
from pytato.transform import Deduplicator
12+
1013

1114
__copyright__ = """
1215
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -189,7 +192,7 @@ def _to_input_for_compiled(
189192
"""
190193
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
191194
if isinstance(ary, pt.Array):
192-
dag = pt.deduplicate(pt.make_dict_of_named_arrays({"_actx_out": ary}))
195+
dag = Deduplicator()(pt.make_dict_of_named_arrays({"_actx_out": ary}))
193196
# Transform the DAG to give metadata inference a chance to do its job
194197
return actx.transform_dag(dag)["_actx_out"].expr
195198
elif isinstance(ary, TaggableCLArray):
@@ -406,12 +409,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
406409
self.actx._compile_trace_callback(
407410
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
408411

409-
name_in_program_to_tags = {
410-
name: out.tags
411-
for name, out in pt_dict_of_named_arrays._data.items()}
412-
name_in_program_to_axes = {
413-
name: out.axes
414-
for name, out in pt_dict_of_named_arrays._data.items()}
412+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
413+
name_in_program_to_axes: dict[str, AxesT] = {}
414+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
415+
name_in_program_to_tags.update({
416+
name: out.tags
417+
for name, out in pt_dict_of_named_arrays._data.items()})
418+
419+
name_in_program_to_axes.update({
420+
name: out.axes
421+
for name, out in pt_dict_of_named_arrays._data.items()})
415422

416423
self.actx._compile_trace_callback(
417424
prg_id, "pre_generate_loopy", pt_dict_of_named_arrays)
@@ -503,12 +510,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
503510
self.actx._compile_trace_callback(
504511
prg_id, "post_transform_dag", pt_dict_of_named_arrays)
505512

506-
name_in_program_to_tags = {
507-
name: out.tags
508-
for name, out in pt_dict_of_named_arrays._data.items()}
509-
name_in_program_to_axes = {
510-
name: out.axes
511-
for name, out in pt_dict_of_named_arrays._data.items()}
513+
name_in_program_to_tags: dict[str, frozenset[Tag]] = {}
514+
name_in_program_to_axes: dict[str, AxesT] = {}
515+
if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays):
516+
name_in_program_to_tags.update({
517+
name: out.tags
518+
for name, out in pt_dict_of_named_arrays._data.items()})
519+
520+
name_in_program_to_axes.update({
521+
name: out.axes
522+
for name, out in pt_dict_of_named_arrays._data.items()})
512523

513524
self.actx._compile_trace_callback(
514525
prg_id, "pre_generate_jax", pt_dict_of_named_arrays)

arraycontext/impl/pytato/outline.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from pytato.transform import Deduplicator
4+
35

46
__doc__ = """
57
.. autoclass:: OutlinedCall
@@ -212,7 +214,7 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
212214

213215
prefixed_output = _call_with_placeholders(
214216
self.f, args, kwargs, arg_id_to_prefixed_placeholder)
215-
unpacked_prefixed_output = pt.deduplicate(
217+
unpacked_prefixed_output = Deduplicator()(
216218
pt.make_dict_of_named_arrays(_unpack_output(prefixed_output)))
217219

218220
prefixed_placeholders = frozenset(
@@ -230,7 +232,7 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
230232
arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg)
231233

232234
output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder)
233-
unpacked_output = pt.deduplicate(
235+
unpacked_output = Deduplicator()(
234236
pt.make_dict_of_named_arrays(_unpack_output(output)))
235237
if len(unpacked_output) == 1 and "_" in unpacked_output:
236238
ret_type = pt.function.ReturnType.ARRAY
@@ -247,10 +249,6 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
247249
for arg_id, placeholder in arg_id_to_placeholder.items()
248250
if placeholder in used_placeholders}
249251

250-
# pylint-disable-reason: pylint has a hard time with kw_only fields in
251-
# dataclasses
252-
253-
# pylint: disable=unexpected-keyword-arg
254252
func_def = pt.function.FunctionDefinition(
255253
parameters=frozenset(call_bindings.keys()),
256254
return_type=ret_type,

arraycontext/impl/pytato/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from typing_extensions import override
4545

4646
import pytools
47+
from pytato import AbstractResultWithNamedArrays
4748
from pytato.analysis import get_num_call_sites
4849
from pytato.array import (
4950
Array,
@@ -59,8 +60,9 @@
5960
from pytato.transform import (
6061
ArrayOrNames,
6162
CopyMapper,
63+
Deduplicator,
64+
MappedT,
6265
TransformMapperCache,
63-
deduplicate,
6466
)
6567
from pytools import UniqueNameGenerator, memoize_method
6668

@@ -135,7 +137,7 @@ def map_function_definition(
135137
# FIXME: This strategy doesn't work if the DAG has functions, since function
136138
# definitions can't contain non-argument placeholders
137139
def _normalize_pt_expr(
138-
expr: DictOfNamedArrays
140+
expr: AbstractResultWithNamedArrays
139141
) -> tuple[DictOfNamedArrays, Mapping[str, Any]]:
140142
"""
141143
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
@@ -146,7 +148,7 @@ def _normalize_pt_expr(
146148
Deterministic naming of placeholders permits more effective caching of
147149
equivalent graphs.
148150
"""
149-
expr = deduplicate(expr)
151+
expr = Deduplicator()(expr)
150152

151153
if get_num_call_sites(expr):
152154
raise NotImplementedError(
@@ -246,15 +248,15 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
246248
non_equality_tags=expr.non_equality_tags)
247249

248250

249-
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
251+
def transfer_from_numpy(expr: MappedT, actx: ArrayContext) -> MappedT:
250252
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
251253
instances to be device arrays, using
252254
:meth:`~arraycontext.ArrayContext.from_numpy`.
253255
"""
254256
return TransferFromNumpyMapper(actx)(expr)
255257

256258

257-
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
259+
def transfer_to_numpy(expr: MappedT, actx: ArrayContext) -> MappedT:
258260
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
259261
instances to be :class:`numpy.ndarray` instances, using
260262
:meth:`~arraycontext.ArrayContext.to_numpy`.

0 commit comments

Comments
 (0)