Skip to content

Commit 83d405c

Browse files
authored
Do tree reductions instead of allgather in distributed quantile construction (#12061)
1 parent a62f19b commit 83d405c

5 files changed

Lines changed: 648 additions & 327 deletions

File tree

src/collective/allreduce.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "../common/type.h" // for EraseType, RestoreType
1111
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
12+
#include "broadcast.h" // for Broadcast
1213
#include "comm.h" // for Comm, RestoreType
1314
#include "comm_group.h" // for GlobalCommGroup
1415
#include "xgboost/collective/result.h" // for Result
@@ -75,4 +76,142 @@ template <typename T>
7576
Allreduce(Context const* ctx, T* data, Op op) {
7677
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
7778
}
79+
80+
/**
81+
* @brief Allreduce a variable-length vector over `comm`.
82+
*
83+
* The method performs a tree reduction rooted at rank 0 using `redop`, then broadcasts
84+
* the result so every rank ends with the same reduced payload in `data`.
85+
*
86+
* `redop` must have the signature
87+
* `void(Fn(const Span<T const>& lhs, const Span<T const>& rhs, std::vector<T>* out))` and must
88+
* write the combined result into `out`.
89+
*/
90+
template <typename T, typename Fn>
91+
std::enable_if_t<
92+
std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result>
93+
AllreduceV(Comm const& comm, std::vector<T>* data, Fn redop) {
94+
static_assert(std::is_standard_layout_v<T> && std::is_trivially_copyable_v<T>,
95+
"AllreduceV supports only standard-layout trivially-copyable types.");
96+
CHECK(data);
97+
if (!comm.IsDistributed() || comm.World() == 1) {
98+
return Success();
99+
}
100+
101+
auto const world = comm.World();
102+
auto const rank = comm.Rank();
103+
auto constexpr kRoot = 0;
104+
105+
auto send = [&](std::int32_t peer, std::vector<T> const& vec) {
106+
std::int64_t n = static_cast<std::int64_t>(vec.size());
107+
auto n_bytes =
108+
common::Span<std::int8_t const>{reinterpret_cast<std::int8_t const*>(&n), sizeof(n)};
109+
return Success() << [&] {
110+
return comm.Chan(peer)->SendAll(n_bytes);
111+
} << [&] {
112+
if (n == 0) {
113+
return Success();
114+
}
115+
auto payload_bytes = static_cast<std::size_t>(n) * sizeof(T);
116+
auto bytes = common::Span<std::int8_t const>{reinterpret_cast<std::int8_t const*>(vec.data()),
117+
payload_bytes};
118+
return comm.Chan(peer)->SendAll(bytes);
119+
} << [&] {
120+
return comm.Chan(peer)->Block();
121+
};
122+
};
123+
124+
auto recv = [&](std::int32_t peer, std::vector<T>* out) {
125+
std::int64_t n = 0;
126+
auto n_bytes = common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(&n), sizeof(n)};
127+
auto rc = Success() << [&] {
128+
return comm.Chan(peer)->RecvAll(n_bytes);
129+
} << [&] {
130+
return comm.Chan(peer)->Block();
131+
};
132+
if (!rc.OK()) {
133+
return rc;
134+
}
135+
CHECK_GE(n, 0);
136+
out->resize(static_cast<std::size_t>(n));
137+
if (n == 0) {
138+
return Success();
139+
}
140+
auto payload_bytes = static_cast<std::size_t>(n) * sizeof(T);
141+
auto bytes =
142+
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(out->data()), payload_bytes};
143+
return Success() << [&] {
144+
return comm.Chan(peer)->RecvAll(bytes);
145+
} << [&] {
146+
return comm.Chan(peer)->Block();
147+
};
148+
};
149+
150+
auto shifted_rank = rank;
151+
std::vector<T> incoming;
152+
std::vector<T> out;
153+
bool continue_reduce = true;
154+
for (std::int32_t step = 1; step < world; step <<= 1) {
155+
if (!continue_reduce) {
156+
continue;
157+
}
158+
if (shifted_rank % (step * 2) == step) {
159+
auto parent = shifted_rank - step;
160+
auto rc = send(parent, *data);
161+
if (!rc.OK()) {
162+
return Fail("AllreduceV failed to send data to parent.", std::move(rc));
163+
}
164+
continue_reduce = false;
165+
continue;
166+
}
167+
if (shifted_rank % (step * 2) == 0 && shifted_rank + step < world) {
168+
auto child = shifted_rank + step;
169+
auto rc = recv(child, &incoming);
170+
if (!rc.OK()) {
171+
return Fail("AllreduceV failed to receive data from child.", std::move(rc));
172+
}
173+
out.clear();
174+
redop(common::Span<T const>{data->data(), data->size()},
175+
common::Span<T const>{incoming.data(), incoming.size()}, &out);
176+
data->swap(out);
177+
}
178+
}
179+
180+
std::int64_t reduced_size = static_cast<std::int64_t>(rank == kRoot ? data->size() : 0);
181+
auto rc = Broadcast(comm, common::Span<std::int64_t>{&reduced_size, 1}, kRoot);
182+
if (!rc.OK()) {
183+
return Fail("AllreduceV failed to broadcast reduced size.", std::move(rc));
184+
}
185+
if (reduced_size == 0) {
186+
data->clear();
187+
return Success();
188+
}
189+
if (rank != kRoot) {
190+
data->resize(static_cast<std::size_t>(reduced_size));
191+
}
192+
auto reduced = common::Span<T>{data->data(), static_cast<std::size_t>(reduced_size)};
193+
rc = Broadcast(comm, reduced, kRoot);
194+
if (!rc.OK()) {
195+
return Fail("AllreduceV failed to broadcast reduced payload.", std::move(rc));
196+
}
197+
return Success();
198+
}
199+
200+
template <typename T, typename Fn>
201+
std::enable_if_t<
202+
std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result>
203+
AllreduceV(Context const* ctx, CommGroup const& comm, std::vector<T>* data, Fn redop) {
204+
if (!comm.IsDistributed()) {
205+
return Success();
206+
}
207+
auto const& cctx = comm.Ctx(ctx, DeviceOrd::CPU());
208+
return AllreduceV(cctx, data, redop);
209+
}
210+
211+
template <typename T, typename Fn>
212+
std::enable_if_t<
213+
std::is_invocable_v<Fn, common::Span<T const>, common::Span<T const>, std::vector<T>*>, Result>
214+
AllreduceV(Context const* ctx, std::vector<T>* data, Fn redop) {
215+
return AllreduceV(ctx, *GlobalCommGroup(), data, redop);
216+
}
78217
} // namespace xgboost::collective

0 commit comments

Comments
 (0)