Skip to content

Commit 46ab8d0

Browse files
committed
add PytatoParallelPyOpenCLArrayContext
1 parent ce30702 commit 46ab8d0

7 files changed

Lines changed: 1383 additions & 9 deletions

File tree

arraycontext/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@
8080
from .impl.jax import EagerJAXArrayContext
8181
from .impl.numpy import NumpyArrayContext
8282
from .impl.pyopencl import PyOpenCLArrayContext
83-
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
83+
from .impl.pytato import (
84+
PytatoJAXArrayContext,
85+
PytatoParallelPyOpenCLArrayContext,
86+
PytatoPyOpenCLArrayContext,
87+
)
8488
from .loopy import make_loopy_program
8589
from .pytest import (
8690
PytestArrayContextFactory,
@@ -140,6 +144,7 @@
140144
"NumpyArrayContext",
141145
"PyOpenCLArrayContext",
142146
"PytatoJAXArrayContext",
147+
"PytatoParallelPyOpenCLArrayContext",
143148
"PytatoPyOpenCLArrayContext",
144149
"PytestArrayContextFactory",
145150
"PytestPyOpenCLArrayContextFactory",

arraycontext/impl/pytato/__init__.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
The following :mod:`pytato`-based array contexts are provided:
1414
1515
.. autoclass:: PytatoPyOpenCLArrayContext
16+
.. autoclass:: PytatoParallelPyOpenCLArrayContext
1617
.. autoclass:: PytatoJAXArrayContext
1718
1819
@@ -28,7 +29,8 @@
2829
.. automodule:: arraycontext.impl.pytato.utils
2930
"""
3031
__copyright__ = """
31-
Copyright (C) 2020-1 University of Illinois Board of Trustees
32+
Copyright (C) 2020-6 University of Illinois Board of Trustees
33+
Copyright (C) 2022-3 Kaushik Kulkarni
3234
"""
3335

3436
__license__ = """
@@ -827,9 +829,15 @@ def compile(self,
827829
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
828830
) -> pytato.AbstractResultWithNamedArrays:
829831
import pytato as pt
832+
833+
dag = pt.transform.deduplicate_data_wrappers(dag)
834+
830835
dag = pt.tag_all_calls_to_be_inlined(dag)
831836
dag = pt.inline_calls(dag)
832-
return pt.transform.materialize_with_mpms(dag)
837+
838+
dag = pt.transform.materialize_with_mpms(dag)
839+
840+
return dag
833841

834842
@override
835843
def einsum(self, spec, *args, arg_names=None, tagged=()):
@@ -909,6 +917,106 @@ def clone(self):
909917
# }}}
910918

911919

920+
# {{{ PytatoParallelPyOpenCLArrayContext
921+
922+
class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
923+
"""
924+
Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device.
925+
926+
.. note::
927+
928+
Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
929+
details on the transformation algorithm provided by this array context.
930+
931+
.. automethod:: transform_dag
932+
.. automethod:: transform_loopy_program
933+
"""
934+
# FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
935+
# should be calling, or should it be left for more-concrete derived array
936+
# contexts? If the latter, where should it live?
937+
def _materialize_einsum_inputs_and_outputs(
938+
self, dag: pytato.AbstractResultWithNamedArrays
939+
) -> pytato.AbstractResultWithNamedArrays:
940+
import pytato as pt
941+
942+
from .utils import (
943+
get_inputs_and_outputs_of_einsum,
944+
get_inputs_and_outputs_of_reduction_nodes,
945+
)
946+
947+
einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
948+
redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
949+
reduction_inputs_outputs = (
950+
einsum_inputs | einsum_outputs | redn_inputs | redn_outputs)
951+
952+
def materialize(
953+
expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames:
954+
if expr in reduction_inputs_outputs:
955+
if isinstance(expr, pt.InputArgumentBase):
956+
return expr
957+
else:
958+
return expr.tagged(pt.tags.ImplStored())
959+
else:
960+
return expr
961+
962+
return pt.transform.map_and_copy(dag, materialize)
963+
964+
@override
965+
def transform_dag(
966+
self, dag: pytato.AbstractResultWithNamedArrays
967+
) -> pytato.AbstractResultWithNamedArrays:
968+
r"""
969+
Returns a transformed version of *dag*, where the applied transform is:
970+
971+
#. Materialize as per MPMS materialization heuristic.
972+
#. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs.
973+
"""
974+
import pytato as pt
975+
976+
dag = pt.transform.deduplicate_data_wrappers(dag)
977+
978+
dag = pt.tag_all_calls_to_be_inlined(dag)
979+
dag = pt.inline_calls(dag)
980+
981+
dag = pt.transform.materialize_with_mpms(dag)
982+
dag = self._materialize_einsum_inputs_and_outputs(dag)
983+
984+
return dag
985+
986+
def _parallelize_across_device(
987+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
988+
from .parallelize import (
989+
alias_global_temporaries,
990+
parallelize_disjoint_loop_sets,
991+
)
992+
993+
t_unit = parallelize_disjoint_loop_sets(
994+
t_unit, self.queue.device.max_compute_units)
995+
996+
# FIXME: Is this something that this abstract-ish
997+
# PytatoParallelPyOpenCLArrayContext class should be calling, or should it
998+
# be left for more-concrete derived array contexts? If the latter, where
999+
# should it live?
1000+
t_unit = alias_global_temporaries(t_unit)
1001+
1002+
return t_unit
1003+
1004+
def transform_loopy_program(
1005+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
1006+
r"""
1007+
Returns a transformed version of *t_unit*, where the applied transform is:
1008+
1009+
#. An execution grid size :math:`G` is selected based on *self*'s
1010+
OpenCL-device.
1011+
#. The iteration domain for each statement in the *t_unit* is divided to
1012+
equally among the work-items in :math:`G`.
1013+
#. Kernel boundaries are drawn between every set of disjoint loops.
1014+
#. Once the kernel boundaries are inferred, global temporaries are aliased
1015+
to reduce the memory peak memory used by the transformed program.
1016+
"""
1017+
return self._parallelize_across_device(t_unit)
1018+
1019+
9121020
# {{{ PytatoJAXArrayContext
9131021

9141022
class PytatoJAXArrayContext(_BasePytatoArrayContext):

0 commit comments

Comments
 (0)