|
| 1 | +""" |
| 2 | +.. currentmodule:: pytato |
| 3 | +
|
| 4 | +.. autofunction:: execute_distributed_partition |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +__copyright__ = """ |
| 10 | +Copyright (C) 2021 University of Illinois Board of Trustees |
| 11 | +""" |
| 12 | + |
| 13 | +__license__ = """ |
| 14 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 15 | +of this software and associated documentation files (the "Software"), to deal |
| 16 | +in the Software without restriction, including without limitation the rights |
| 17 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 18 | +copies of the Software, and to permit persons to whom the Software is |
| 19 | +furnished to do so, subject to the following conditions: |
| 20 | +
|
| 21 | +The above copyright notice and this permission notice shall be included in |
| 22 | +all copies or substantial portions of the Software. |
| 23 | +
|
| 24 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 25 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 26 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 27 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 28 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 29 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 30 | +THE SOFTWARE. |
| 31 | +""" |
| 32 | + |
| 33 | +from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING |
| 34 | + |
| 35 | + |
| 36 | +from pytato.target import BoundProgram |
| 37 | +from pytato.scalar_expr import INT_CLASSES |
| 38 | + |
| 39 | +import numpy as np |
| 40 | + |
| 41 | + |
| 42 | +from pytato.distributed.nodes import ( |
| 43 | + DistributedRecv, DistributedSend) |
| 44 | +from pytato.distributed.partition import ( |
| 45 | + DistributedGraphPartition, DistributedGraphPart) |
| 46 | + |
| 47 | +import logging |
| 48 | +logger = logging.getLogger(__name__) |
| 49 | + |
| 50 | + |
| 51 | +if TYPE_CHECKING: |
| 52 | + import mpi4py.MPI |
| 53 | + |
| 54 | + |
| 55 | +# {{{ distributed execute |
| 56 | + |
def _post_receive(mpi_communicator: mpi4py.MPI.Comm,
                  recv: DistributedRecv) -> Tuple[Any, np.ndarray[Any, Any]]:
    """
    Post a non-blocking MPI receive for *recv*.

    A buffer of shape ``recv.shape`` and dtype ``recv.dtype`` is allocated
    and handed to ``mpi_communicator.Irecv`` with ``recv.src_rank`` as the
    source and ``recv.comm_tag`` as the tag.

    :returns: a tuple ``(request, buffer)`` consisting of the object
        returned by ``Irecv`` and the array the data is received into.
    :raises NotImplementedError: if any entry of ``recv.shape`` is not an
        instance of ``INT_CLASSES``.
    """
    for dim in recv.shape:
        if not isinstance(dim, INT_CLASSES):
            raise NotImplementedError("Parametric shapes not supported yet.")

    assert isinstance(recv.comm_tag, int)

    # mypy's complaint is legitimate: size parameters in 'recv.shape' would
    # have to be evaluated before they can be used as array extents.
    recv_buf = np.empty(recv.shape, dtype=recv.dtype)  # type: ignore[arg-type]

    request = mpi_communicator.Irecv(
            buf=recv_buf, source=recv.src_rank, tag=recv.comm_tag)

    return request, recv_buf
| 68 | + |
| 69 | + |
| 70 | +def _mpi_send(mpi_communicator: Any, send_node: DistributedSend, |
| 71 | + data: np.ndarray[Any, Any]) -> Any: |
| 72 | + # Must use-non-blocking send, as blocking send may wait for a corresponding |
| 73 | + # receive to be posted (but if sending to self, this may only occur later). |
| 74 | + return mpi_communicator.Isend( |
| 75 | + data, dest=send_node.dest_rank, tag=send_node.comm_tag) |
| 76 | + |
| 77 | + |
def execute_distributed_partition(
        partition: DistributedGraphPartition, prg_per_partition:
        Dict[Hashable, BoundProgram],
        queue: Any, mpi_communicator: Any,
        *,
        allocator: Optional[Any] = None,
        input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Execute all parts of *partition* on the calling rank, exchanging data
    with other ranks through *mpi_communicator*.

    All receives of all parts are posted up front. A part is executed once
    every part listed in its ``needed_pids`` has been executed and every
    receive in its ``input_name_to_recv_node`` has completed. After a part
    has run, a non-blocking send is started for each entry of its
    ``output_name_to_send_node``. All sends are waited on before returning.

    :arg partition: the partition to execute.
    :arg prg_per_partition: maps each part ID in ``partition.parts`` to the
        program for that part. A program is invoked as
        ``prg(queue, allocator=allocator, **inputs)`` and must return a
        tuple ``(event, result_dict)``.
    :arg queue: passed to every program; also used to move received buffers
        to the device and to fetch non-:mod:`numpy` results back to the host
        before sending (presumably a :class:`pyopencl.CommandQueue`).
    :arg mpi_communicator: the communicator on which ``Irecv`` and ``Isend``
        are issued.
    :arg allocator: forwarded to every program and to the device transfer of
        received buffers.
    :arg input_args: initial name-to-value mapping. It is copied; the
        caller's dictionary is not modified.
    :returns: the name-to-value mapping after all parts have run. Names
        that are inputs of some part are removed once their last consumer
        has been executed.
    :raises NotImplementedError: if a receive has a shape entry that is not
        an integer.
    :raises RuntimeError: if parts remain to be executed, none of them is
        ready, and no receive is pending, i.e. no further progress is
        possible.
    """

    if input_args is None:
        input_args = {}

    from mpi4py import MPI

    # {{{ post receives

    # Post every receive before running anything so that incoming data can
    # arrive while parts execute. The lists are built with a plain loop
    # (rather than unpacking a 'zip(*...)'), which also covers the cases of
    # several parts without any receive and of a single part with receives.
    recv_names = []
    recv_requests = []
    recv_buffers = []

    for part in partition.parts.values():
        for name, recv in part.input_name_to_recv_node.items():
            request, buf = _post_receive(mpi_communicator, recv)
            recv_names.append(name)
            recv_requests.append(request)
            recv_buffers.append(buf)

    # }}}

    context: Dict[str, Any] = input_args.copy()

    pids_to_execute = set(partition.parts)
    pids_executed = set()
    recv_names_completed = set()
    send_requests = []

    # {{{ Input name refcount

    # Keep a count on how often each input name is used
    # in order to be able to free them.

    from pytools import memoize_on_first_arg

    @memoize_on_first_arg
    def _get_partition_input_name_refcount(partition: DistributedGraphPartition) \
            -> Dict[str, int]:
        refcount: Dict[str, int] = {}
        for part in partition.parts.values():
            for name in part.all_input_names():
                refcount[name] = refcount.get(name, 0) + 1

        return refcount

    # Work on a copy: the counts are decremented below, and the memoized
    # result must stay intact.
    partition_input_names_refcount = \
        _get_partition_input_name_refcount(partition).copy()

    # }}}

    def exec_ready_part(part: DistributedGraphPart) -> None:
        """Run *part*, merge its results into the context, start its sends."""
        inputs = {k: context[k] for k in part.all_input_names()}

        _evt, result_dict = prg_per_partition[part.pid](queue,
                                                        allocator=allocator,
                                                        **inputs)

        context.update(result_dict)

        for name, send_node in part.output_name_to_send_node.items():
            # FIXME: pytato shouldn't depend on pyopencl
            value = context[name]
            if isinstance(value, np.ndarray):
                data = value
            else:
                # Not a host array: fetch it via its 'get' method.
                data = value.get(queue)
            send_requests.append(_mpi_send(mpi_communicator, send_node, data))

        pids_executed.add(part.pid)
        pids_to_execute.remove(part.pid)

    def wait_for_some_recvs() -> None:
        """Wait for pending receives and move completed ones to the device."""
        complete_recv_indices = MPI.Request.Waitsome(recv_requests)

        # Waitsome is allowed to return None
        if not complete_recv_indices:
            return

        # FIXME: pytato shouldn't depend on pyopencl
        # Import the submodule explicitly: a bare 'import pyopencl' does not
        # guarantee that 'pyopencl.array' has been loaded.
        import pyopencl.array as cl_array

        # reverse to preserve indices
        for idx in sorted(complete_recv_indices, reverse=True):
            name = recv_names.pop(idx)
            recv_requests.pop(idx)
            buf = recv_buffers.pop(idx)

            context[name] = cl_array.to_device(queue, buf, allocator=allocator)
            recv_names_completed.add(name)

    # {{{ main loop

    while pids_to_execute:
        ready_pids = {pid
                for pid in pids_to_execute
                # FIXME: Only O(n**2) altogether. Nobody is going to notice, right?
                if partition.parts[pid].needed_pids <= pids_executed
                and (set(partition.parts[pid].input_name_to_recv_node)
                    <= recv_names_completed)}

        for pid in ready_pids:
            part = partition.parts[pid]
            exec_ready_part(part)

            # Drop inputs whose last consumer has just run.
            for p in part.all_input_names():
                partition_input_names_refcount[p] -= 1
                if partition_input_names_refcount[p] == 0:
                    del context[p]

        if not ready_pids:
            if not recv_requests:
                # Nothing is ready and nothing can arrive anymore: waiting
                # would spin forever, so fail loudly instead.
                raise RuntimeError(
                        "Unable to make progress: no part is ready for "
                        "execution and no receives are pending. "
                        f"Remaining part IDs: {sorted(pids_to_execute, key=str)}")

            wait_for_some_recvs()

    # }}}

    for send_req in send_requests:
        send_req.Wait()

    if __debug__:
        for name, count in partition_input_names_refcount.items():
            assert count == 0
            assert name not in context

    return context
| 210 | + |
| 211 | +# }}} |
| 212 | + |
| 213 | +# vim: foldmethod=marker |
0 commit comments