Skip to content

Commit c843cb6

Browse files
committed
add PytatoParallelPyOpenCLArrayContext
1 parent 8762568 commit c843cb6

7 files changed

Lines changed: 1119 additions & 5 deletions

File tree

arraycontext/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@
8080
from .impl.jax import EagerJAXArrayContext
8181
from .impl.numpy import NumpyArrayContext
8282
from .impl.pyopencl import PyOpenCLArrayContext
83-
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
83+
from .impl.pytato import (
84+
PytatoJAXArrayContext,
85+
PytatoParallelPyOpenCLArrayContext,
86+
PytatoPyOpenCLArrayContext,
87+
)
8488
from .loopy import make_loopy_program
8589
from .pytest import (
8690
PytestArrayContextFactory,
@@ -140,6 +144,7 @@
140144
"NumpyArrayContext",
141145
"PyOpenCLArrayContext",
142146
"PytatoJAXArrayContext",
147+
"PytatoParallelPyOpenCLArrayContext",
143148
"PytatoPyOpenCLArrayContext",
144149
"PytestArrayContextFactory",
145150
"PytestPyOpenCLArrayContextFactory",

arraycontext/impl/pytato/__init__.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
The following :mod:`pytato`-based array contexts are provided:
1414
1515
.. autoclass:: PytatoPyOpenCLArrayContext
16+
.. autoclass:: PytatoParallelPyOpenCLArrayContext
1617
.. autoclass:: PytatoJAXArrayContext
1718
1819
@@ -28,7 +29,8 @@
2829
.. automodule:: arraycontext.impl.pytato.utils
2930
"""
3031
__copyright__ = """
31-
Copyright (C) 2020-1 University of Illinois Board of Trustees
32+
Copyright (C) 2020-6 University of Illinois Board of Trustees
33+
Copyright (C) 2022-3 Kaushik Kulkarni
3234
"""
3335

3436
__license__ = """
@@ -827,9 +829,15 @@ def compile(self,
827829
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
828830
) -> pytato.AbstractResultWithNamedArrays:
829831
import pytato as pt
832+
833+
dag = pt.transform.deduplicate_data_wrappers(dag)
834+
830835
dag = pt.tag_all_calls_to_be_inlined(dag)
831836
dag = pt.inline_calls(dag)
832-
return pt.transform.materialize_with_mpms(dag)
837+
838+
dag = pt.transform.materialize_with_mpms(dag)
839+
840+
return dag
833841

834842
@override
835843
def einsum(self, spec, *args, arg_names=None, tagged=()):
@@ -909,6 +917,110 @@ def clone(self):
909917
# }}}
910918

911919

920+
# {{{ PytatoParallelPyOpenCLArrayContext
921+
922+
class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
923+
"""
924+
Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device.
925+
926+
.. note::
927+
928+
Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
929+
details on the transformation algorithm provided by this array context.
930+
"""
931+
# FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
932+
# should be calling, or should it be left for more-concrete derived array
933+
# contexts? If the latter, where should it live?
934+
def _materialize_einsum_inputs_and_outputs(
935+
self, dag: pytato.AbstractResultWithNamedArrays
936+
) -> pytato.AbstractResultWithNamedArrays:
937+
import pytato as pt
938+
939+
from .utils import (
940+
get_inputs_and_outputs_of_einsum,
941+
get_inputs_and_outputs_of_reduction_nodes,
942+
)
943+
944+
einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
945+
redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
946+
reduction_inputs_outputs = (
947+
einsum_inputs | einsum_outputs | redn_inputs | redn_outputs)
948+
949+
def materialize(
950+
expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames:
951+
if expr in reduction_inputs_outputs:
952+
if isinstance(expr, pt.InputArgumentBase):
953+
return expr
954+
else:
955+
return expr.tagged(pt.tags.ImplStored())
956+
else:
957+
return expr
958+
959+
return pt.transform.map_and_copy(dag, materialize)
960+
961+
@override
962+
def transform_dag(
963+
self, dag: pytato.AbstractResultWithNamedArrays
964+
) -> pytato.AbstractResultWithNamedArrays:
965+
r"""
966+
Returns a transformed version of *dag*, where the applied transform is:
967+
968+
#. Materialize as per MPMS materialization heuristic.
969+
#. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs.
970+
"""
971+
import pytato as pt
972+
973+
dag = pt.transform.deduplicate_data_wrappers(dag)
974+
975+
dag = pt.tag_all_calls_to_be_inlined(dag)
976+
dag = pt.inline_calls(dag)
977+
978+
dag = pt.transform.materialize_with_mpms(dag)
979+
dag = self._materialize_einsum_inputs_and_outputs(dag)
980+
981+
return dag
982+
983+
def _parallelize_across_device(
984+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
985+
from .parallelize import (
986+
add_gbarrier_between_disjoint_loop_sets,
987+
alias_global_temporaries,
988+
split_iteration_domain_across_work_items,
989+
)
990+
991+
# Must add barriers before parallelizing, because some parallelization
992+
# transformations create new loop sets (for example, scalar reductions) and
993+
# create their own barriers as part of that process
994+
t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit)
995+
996+
t_unit = split_iteration_domain_across_work_items(
997+
t_unit, self.queue.device.max_compute_units)
998+
999+
# FIXME: Is this something that this abstract-ish
1000+
# PytatoParallelPyOpenCLArrayContext class should be calling, or should it
1001+
# be left for more-concrete derived array contexts? If the latter, where
1002+
# should it live?
1003+
t_unit = alias_global_temporaries(t_unit)
1004+
1005+
return t_unit
1006+
1007+
def transform_loopy_program(
1008+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
1009+
r"""
1010+
Returns a transformed version of *t_unit*, where the applied transform is:
1011+
1012+
#. An execution grid size :math:`G` is selected based on *self*'s
1013+
OpenCL-device.
1014+
#. The iteration domain for each statement in the *t_unit* is divided to
1015+
equally among the work-items in :math:`G`.
1016+
#. Kernel boundaries are drawn between every set of disjoint loops.
1017+
#. Once the kernel boundaries are inferred, :func:`alias_global_temporaries`
1018+
is invoked to reduce the memory peak memory used by the transformed
1019+
program.
1020+
"""
1021+
return self._parallelize_across_device(t_unit)
1022+
1023+
9121024
# {{{ PytatoJAXArrayContext
9131025

9141026
class PytatoJAXArrayContext(_BasePytatoArrayContext):

0 commit comments

Comments
 (0)