|
9 | 9 |
|
10 | 10 | #include "../common/type.h" // for EraseType, RestoreType |
11 | 11 | #include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler |
| 12 | +#include "broadcast.h" // for Broadcast |
12 | 13 | #include "comm.h" // for Comm, RestoreType |
13 | 14 | #include "comm_group.h" // for GlobalCommGroup |
14 | 15 | #include "xgboost/collective/result.h" // for Result |
@@ -75,4 +76,142 @@ template <typename T> |
75 | 76 | Allreduce(Context const* ctx, T* data, Op op) { |
76 | 77 | return Allreduce(ctx, linalg::MakeVec(data, 1), op); |
77 | 78 | } |
| 79 | + |
| 80 | +/** |
| 81 | + * @brief Allreduce a variable-length vector over `comm`. |
| 82 | + * |
| 83 | + * The method performs a tree reduction rooted at rank 0 using `redop`, then broadcasts |
| 84 | + * the result so every rank ends with the same reduced payload in `data`. |
| 85 | + * |
| 86 | + * `redop` must have the signature |
| 87 | + * `void(Fn(const Span<T const>& lhs, const Span<T const>& rhs, std::vector<T>* out))` and must |
| 88 | + * write the combined result into `out`. |
| 89 | + */ |
| 90 | +template <typename T, typename Fn> |
| 91 | +std::enable_if_t< |
| 92 | + std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result> |
| 93 | +AllreduceV(Comm const& comm, std::vector<T>* data, Fn redop) { |
| 94 | + static_assert(std::is_standard_layout_v<T> && std::is_trivially_copyable_v<T>, |
| 95 | + "AllreduceV supports only standard-layout trivially-copyable types."); |
| 96 | + CHECK(data); |
| 97 | + if (!comm.IsDistributed() || comm.World() == 1) { |
| 98 | + return Success(); |
| 99 | + } |
| 100 | + |
| 101 | + auto const world = comm.World(); |
| 102 | + auto const rank = comm.Rank(); |
| 103 | + auto constexpr kRoot = 0; |
| 104 | + |
| 105 | + auto send = [&](std::int32_t peer, std::vector<T> const& vec) { |
| 106 | + std::int64_t n = static_cast<std::int64_t>(vec.size()); |
| 107 | + auto n_bytes = |
| 108 | + common::Span<std::int8_t const>{reinterpret_cast<std::int8_t const*>(&n), sizeof(n)}; |
| 109 | + return Success() << [&] { |
| 110 | + return comm.Chan(peer)->SendAll(n_bytes); |
| 111 | + } << [&] { |
| 112 | + if (n == 0) { |
| 113 | + return Success(); |
| 114 | + } |
| 115 | + auto payload_bytes = static_cast<std::size_t>(n) * sizeof(T); |
| 116 | + auto bytes = common::Span<std::int8_t const>{reinterpret_cast<std::int8_t const*>(vec.data()), |
| 117 | + payload_bytes}; |
| 118 | + return comm.Chan(peer)->SendAll(bytes); |
| 119 | + } << [&] { |
| 120 | + return comm.Chan(peer)->Block(); |
| 121 | + }; |
| 122 | + }; |
| 123 | + |
| 124 | + auto recv = [&](std::int32_t peer, std::vector<T>* out) { |
| 125 | + std::int64_t n = 0; |
| 126 | + auto n_bytes = common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(&n), sizeof(n)}; |
| 127 | + auto rc = Success() << [&] { |
| 128 | + return comm.Chan(peer)->RecvAll(n_bytes); |
| 129 | + } << [&] { |
| 130 | + return comm.Chan(peer)->Block(); |
| 131 | + }; |
| 132 | + if (!rc.OK()) { |
| 133 | + return rc; |
| 134 | + } |
| 135 | + CHECK_GE(n, 0); |
| 136 | + out->resize(static_cast<std::size_t>(n)); |
| 137 | + if (n == 0) { |
| 138 | + return Success(); |
| 139 | + } |
| 140 | + auto payload_bytes = static_cast<std::size_t>(n) * sizeof(T); |
| 141 | + auto bytes = |
| 142 | + common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(out->data()), payload_bytes}; |
| 143 | + return Success() << [&] { |
| 144 | + return comm.Chan(peer)->RecvAll(bytes); |
| 145 | + } << [&] { |
| 146 | + return comm.Chan(peer)->Block(); |
| 147 | + }; |
| 148 | + }; |
| 149 | + |
| 150 | + auto shifted_rank = rank; |
| 151 | + std::vector<T> incoming; |
| 152 | + std::vector<T> out; |
| 153 | + bool continue_reduce = true; |
| 154 | + for (std::int32_t step = 1; step < world; step <<= 1) { |
| 155 | + if (!continue_reduce) { |
| 156 | + continue; |
| 157 | + } |
| 158 | + if (shifted_rank % (step * 2) == step) { |
| 159 | + auto parent = shifted_rank - step; |
| 160 | + auto rc = send(parent, *data); |
| 161 | + if (!rc.OK()) { |
| 162 | + return Fail("AllreduceV failed to send data to parent.", std::move(rc)); |
| 163 | + } |
| 164 | + continue_reduce = false; |
| 165 | + continue; |
| 166 | + } |
| 167 | + if (shifted_rank % (step * 2) == 0 && shifted_rank + step < world) { |
| 168 | + auto child = shifted_rank + step; |
| 169 | + auto rc = recv(child, &incoming); |
| 170 | + if (!rc.OK()) { |
| 171 | + return Fail("AllreduceV failed to receive data from child.", std::move(rc)); |
| 172 | + } |
| 173 | + out.clear(); |
| 174 | + redop(common::Span<T const>{data->data(), data->size()}, |
| 175 | + common::Span<T const>{incoming.data(), incoming.size()}, &out); |
| 176 | + data->swap(out); |
| 177 | + } |
| 178 | + } |
| 179 | + |
| 180 | + std::int64_t reduced_size = static_cast<std::int64_t>(rank == kRoot ? data->size() : 0); |
| 181 | + auto rc = Broadcast(comm, common::Span<std::int64_t>{&reduced_size, 1}, kRoot); |
| 182 | + if (!rc.OK()) { |
| 183 | + return Fail("AllreduceV failed to broadcast reduced size.", std::move(rc)); |
| 184 | + } |
| 185 | + if (reduced_size == 0) { |
| 186 | + data->clear(); |
| 187 | + return Success(); |
| 188 | + } |
| 189 | + if (rank != kRoot) { |
| 190 | + data->resize(static_cast<std::size_t>(reduced_size)); |
| 191 | + } |
| 192 | + auto reduced = common::Span<T>{data->data(), static_cast<std::size_t>(reduced_size)}; |
| 193 | + rc = Broadcast(comm, reduced, kRoot); |
| 194 | + if (!rc.OK()) { |
| 195 | + return Fail("AllreduceV failed to broadcast reduced payload.", std::move(rc)); |
| 196 | + } |
| 197 | + return Success(); |
| 198 | +} |
| 199 | + |
| 200 | +template <typename T, typename Fn> |
| 201 | +std::enable_if_t< |
| 202 | + std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result> |
| 203 | +AllreduceV(Context const* ctx, CommGroup const& comm, std::vector<T>* data, Fn redop) { |
| 204 | + if (!comm.IsDistributed()) { |
| 205 | + return Success(); |
| 206 | + } |
| 207 | + auto const& cctx = comm.Ctx(ctx, DeviceOrd::CPU()); |
| 208 | + return AllreduceV(cctx, data, redop); |
| 209 | +} |
| 210 | + |
| 211 | +template <typename T, typename Fn> |
| 212 | +std::enable_if_t< |
| 213 | + std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result> |
| 214 | +AllreduceV(Context const* ctx, std::vector<T>* data, Fn redop) { |
| 215 | + return AllreduceV(ctx, *GlobalCommGroup(), data, redop); |
| 216 | +} |
78 | 217 | } // namespace xgboost::collective |
0 commit comments