Skip to content

Commit 372b2e5

Browse files
Merge remote-tracking branch 'origin/main' into pass-cycle
2 parents a56643a + 5745fe3 commit 372b2e5

16 files changed

Lines changed: 1834 additions & 793 deletions

pytato/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,18 @@ def set_debug_enabled(flag: bool) -> None:
9393
import pytato.analysis as analysis
9494
import pytato.tags as tags
9595
import pytato.transform as transform
96-
from pytato.distributed import (make_distributed_send, make_distributed_recv,
96+
from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv,
9797
DistributedRecv, DistributedSend,
9898
DistributedSendRefHolder,
99-
staple_distributed_send,
100-
find_distributed_partition,
101-
number_distributed_tags,
102-
execute_distributed_partition,
103-
verify_distributed_partition,
104-
)
99+
staple_distributed_send)
100+
from pytato.distributed.partition import (
101+
find_distributed_partition, DistributedGraphPart, DistributedGraphPartition)
102+
from pytato.distributed.tags import number_distributed_tags
103+
from pytato.distributed.verify import verify_distributed_partition
104+
from pytato.distributed.execute import execute_distributed_partition
105+
105106
from pytato.transform.lower_to_index_lambda import to_index_lambda
107+
from pytato.transform.metadata import unify_axes_tags
106108

107109
from pytato.partition import generate_code_for_partition
108110

@@ -152,7 +154,10 @@ def set_debug_enabled(flag: bool) -> None:
152154
"make_distributed_recv", "make_distributed_send", "DistributedRecv",
153155
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",
154156

157+
"DistributedGraphPart",
158+
"DistributedGraphPartition",
155159
"find_distributed_partition",
160+
156161
"number_distributed_tags",
157162
"execute_distributed_partition",
158163
"verify_distributed_partition",
@@ -161,6 +166,8 @@ def set_debug_enabled(flag: bool) -> None:
161166

162167
"to_index_lambda",
163168

169+
"unify_axes_tags",
170+
164171
# sub-modules
165172
"analysis", "tags", "transform",
166173

pytato/analysis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pymbolic.mapper.optimize import optimize_mapper
3737

3838
if TYPE_CHECKING:
39-
from pytato.distributed import DistributedRecv, DistributedSendRefHolder
39+
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
4040

4141
__doc__ = """
4242
.. currentmodule:: pytato.analysis

pytato/distributed/__init__.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
r"""
2+
Distributed-memory evaluation of expression graphs is accomplished
3+
by :ref:`partitioning <partitioning>` the graph to reveal communication-free
4+
pieces of the computation. Communication (i.e. sending/receiving data) is then
5+
accomplished at the boundaries of the parts of the resulting graph partitioning.
6+
7+
Recall the requirement for partitioning that, "no part may depend on its own
8+
outputs as inputs". That sounds obvious, but in the distributed-memory case,
9+
this is harder to decide than it looks, since we do not have full knowledge of
10+
the computation graph. Edges go off to other nodes and then come back.
11+
12+
As a first step towards making this tractable, we currently strengthen the
13+
requirement to create partition boundaries on every edge that goes between
14+
nodes that are/are not a dependency of a receive or that feed/do not feed a send.
15+
16+
.. automodule:: pytato.distributed.nodes
17+
.. automodule:: pytato.distributed.partition
18+
.. automodule:: pytato.distributed.verify
19+
.. automodule:: pytato.distributed.execute
20+
21+
Internal stuff that is only here because the documentation tool wants it
22+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
23+
.. class:: Tag
24+
25+
See :class:`pytools.tag.Tag`.
26+
27+
.. class:: CommTagType
28+
29+
A type representing a communication tag.
30+
31+
.. class:: ShapeType
32+
33+
A type representing a shape.
34+
35+
.. class:: AxesT
36+
37+
A :class:`tuple` of :class:`Axis` objects.
38+
"""
39+
40+
from __future__ import annotations
41+
42+
__copyright__ = """
43+
Copyright (C) 2021 University of Illinois Board of Trustees
44+
"""
45+
46+
__license__ = """
47+
Permission is hereby granted, free of charge, to any person obtaining a copy
48+
of this software and associated documentation files (the "Software"), to deal
49+
in the Software without restriction, including without limitation the rights
50+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
51+
copies of the Software, and to permit persons to whom the Software is
52+
furnished to do so, subject to the following conditions:
53+
54+
The above copyright notice and this permission notice shall be included in
55+
all copies or substantial portions of the Software.
56+
57+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
58+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
59+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
60+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
61+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
62+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
63+
THE SOFTWARE.
64+
"""
65+
66+
from typing import Any
67+
68+
# These are here to support hold versions of grudge.
69+
70+
_depr_names = {
71+
"DistributedGraphPartition",
72+
"number_distributed_tags",
73+
"execute_distributed_partition", "pytato.distributed.execute"
74+
}
75+
76+
77+
def __getattr__(name: str) -> Any:
78+
if name in _depr_names:
79+
from warnings import warn
80+
warn(f"'pytato.distributed.{name}' is deprecated. "
81+
f"Import as 'pytato.{name}' instead. "
82+
"This will stop working in July 2023.",
83+
DeprecationWarning, stacklevel=2)
84+
85+
import pytato
86+
return getattr(pytato, name)
87+
88+
# let name lookup proceed normally
89+
raise AttributeError(name)

pytato/distributed/execute.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
.. currentmodule:: pytato
3+
4+
.. autofunction:: execute_distributed_partition
5+
"""
6+
7+
from __future__ import annotations
8+
9+
__copyright__ = """
10+
Copyright (C) 2021 University of Illinois Board of Trustees
11+
"""
12+
13+
__license__ = """
14+
Permission is hereby granted, free of charge, to any person obtaining a copy
15+
of this software and associated documentation files (the "Software"), to deal
16+
in the Software without restriction, including without limitation the rights
17+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18+
copies of the Software, and to permit persons to whom the Software is
19+
furnished to do so, subject to the following conditions:
20+
21+
The above copyright notice and this permission notice shall be included in
22+
all copies or substantial portions of the Software.
23+
24+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
30+
THE SOFTWARE.
31+
"""
32+
33+
from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING
34+
35+
36+
from pytato.target import BoundProgram
37+
from pytato.scalar_expr import INT_CLASSES
38+
39+
import numpy as np
40+
41+
42+
from pytato.distributed.nodes import (
43+
DistributedRecv, DistributedSend)
44+
from pytato.distributed.partition import (
45+
DistributedGraphPartition, DistributedGraphPart)
46+
47+
import logging
48+
logger = logging.getLogger(__name__)
49+
50+
51+
if TYPE_CHECKING:
52+
import mpi4py.MPI
53+
54+
55+
# {{{ distributed execute
56+
57+
def _post_receive(mpi_communicator: mpi4py.MPI.Comm,
58+
recv: DistributedRecv) -> Tuple[Any, np.ndarray[Any, Any]]:
59+
if not all(isinstance(dim, INT_CLASSES) for dim in recv.shape):
60+
raise NotImplementedError("Parametric shapes not supported yet.")
61+
62+
assert isinstance(recv.comm_tag, int)
63+
# mypy is right here, size params in 'recv.shape' must be evaluated
64+
buf = np.empty(recv.shape, dtype=recv.dtype) # type: ignore[arg-type]
65+
66+
return mpi_communicator.Irecv(
67+
buf=buf, source=recv.src_rank, tag=recv.comm_tag), buf
68+
69+
70+
def _mpi_send(mpi_communicator: Any, send_node: DistributedSend,
71+
data: np.ndarray[Any, Any]) -> Any:
72+
# Must use-non-blocking send, as blocking send may wait for a corresponding
73+
# receive to be posted (but if sending to self, this may only occur later).
74+
return mpi_communicator.Isend(
75+
data, dest=send_node.dest_rank, tag=send_node.comm_tag)
76+
77+
78+
def execute_distributed_partition(
79+
partition: DistributedGraphPartition, prg_per_partition:
80+
Dict[Hashable, BoundProgram],
81+
queue: Any, mpi_communicator: Any,
82+
*,
83+
allocator: Optional[Any] = None,
84+
input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
85+
86+
if input_args is None:
87+
input_args = {}
88+
89+
from mpi4py import MPI
90+
91+
if len(partition.parts) != 1:
92+
recv_names_tup, recv_requests_tup, recv_buffers_tup = zip(*[
93+
(name,) + _post_receive(mpi_communicator, recv)
94+
for part in partition.parts.values()
95+
for name, recv in part.input_name_to_recv_node.items()])
96+
recv_names = list(recv_names_tup)
97+
recv_requests = list(recv_requests_tup)
98+
recv_buffers = list(recv_buffers_tup)
99+
del recv_names_tup
100+
del recv_requests_tup
101+
del recv_buffers_tup
102+
else:
103+
# Only a single partition, no recv requests exist
104+
recv_names = []
105+
recv_requests = []
106+
recv_buffers = []
107+
108+
context: Dict[str, Any] = input_args.copy()
109+
110+
pids_to_execute = set(partition.parts)
111+
pids_executed = set()
112+
recv_names_completed = set()
113+
send_requests = []
114+
115+
# {{{ Input name refcount
116+
117+
# Keep a count on how often each input name is used
118+
# in order to be able to free them.
119+
120+
from pytools import memoize_on_first_arg
121+
122+
@memoize_on_first_arg
123+
def _get_partition_input_name_refcount(partition: DistributedGraphPartition) \
124+
-> Dict[str, int]:
125+
partition_input_names_refcount: Dict[str, int] = {}
126+
for pid in set(partition.parts):
127+
for name in partition.parts[pid].all_input_names():
128+
if name in partition_input_names_refcount:
129+
partition_input_names_refcount[name] += 1
130+
else:
131+
partition_input_names_refcount[name] = 1
132+
133+
return partition_input_names_refcount
134+
135+
partition_input_names_refcount = \
136+
_get_partition_input_name_refcount(partition).copy()
137+
138+
# }}}
139+
140+
def exec_ready_part(part: DistributedGraphPart) -> None:
141+
inputs = {k: context[k] for k in part.all_input_names()}
142+
143+
_evt, result_dict = prg_per_partition[part.pid](queue,
144+
allocator=allocator,
145+
**inputs)
146+
147+
context.update(result_dict)
148+
149+
for name, send_node in part.output_name_to_send_node.items():
150+
# FIXME: pytato shouldn't depend on pyopencl
151+
if isinstance(context[name], np.ndarray):
152+
data = context[name]
153+
else:
154+
data = context[name].get(queue)
155+
send_requests.append(_mpi_send(mpi_communicator, send_node, data))
156+
157+
pids_executed.add(part.pid)
158+
pids_to_execute.remove(part.pid)
159+
160+
def wait_for_some_recvs() -> None:
161+
complete_recv_indices = MPI.Request.Waitsome(recv_requests)
162+
163+
# Waitsome is allowed to return None
164+
if not complete_recv_indices:
165+
complete_recv_indices = []
166+
167+
# reverse to preserve indices
168+
for idx in sorted(complete_recv_indices, reverse=True):
169+
name = recv_names.pop(idx)
170+
recv_requests.pop(idx)
171+
buf = recv_buffers.pop(idx)
172+
173+
# FIXME: pytato shouldn't depend on pyopencl
174+
import pyopencl as cl
175+
context[name] = cl.array.to_device(queue, buf, allocator=allocator)
176+
recv_names_completed.add(name)
177+
178+
# {{{ main loop
179+
180+
while pids_to_execute:
181+
ready_pids = {pid
182+
for pid in pids_to_execute
183+
# FIXME: Only O(n**2) altogether. Nobody is going to notice, right?
184+
if partition.parts[pid].needed_pids <= pids_executed
185+
and (set(partition.parts[pid].input_name_to_recv_node)
186+
<= recv_names_completed)}
187+
for pid in ready_pids:
188+
part = partition.parts[pid]
189+
exec_ready_part(part)
190+
191+
for p in part.all_input_names():
192+
partition_input_names_refcount[p] -= 1
193+
if partition_input_names_refcount[p] == 0:
194+
del context[p]
195+
196+
if not ready_pids:
197+
wait_for_some_recvs()
198+
199+
# }}}
200+
201+
for send_req in send_requests:
202+
send_req.Wait()
203+
204+
if __debug__:
205+
for name, count in partition_input_names_refcount.items():
206+
assert count == 0
207+
assert name not in context
208+
209+
return context
210+
211+
# }}}
212+
213+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)