Skip to content

Commit abb1270

Browse files
authored
Fix AllreduceV with CUDA stream. (#12171)
We use async NCCL to implement timeout. However, when NCCL is in async mode, it uses a thread pool, and cannot work with per-thread CUDA stream, which resolves to the wrong per-thread stream in its pool. The existing allreduce implementation works because we use a custom CUDA stream in the NCCL coll wrapper. This PR moves that stream into NCCLComm, to expose it to the allreduce V implementation.
1 parent 5e7b49c commit abb1270

23 files changed

Lines changed: 343 additions & 237 deletions

plugin/federated/federated_coll.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023, XGBoost contributors
2+
* Copyright 2023-2026, XGBoost contributors
33
*/
44
#include "federated_coll.h"
55

@@ -61,7 +61,8 @@ Coll *FederatedColl::MakeCUDAVar() {
6161
}
6262
#endif
6363

64-
[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
64+
[[nodiscard]] Result FederatedColl::Allreduce(Context const * /*ctx*/, Comm const &comm,
65+
common::Span<std::int8_t> data,
6566
ArrayInterfaceHandler::Type type, Op op) {
6667
using namespace federated; // NOLINT
6768
auto fed = dynamic_cast<FederatedComm const *>(&comm);
@@ -87,12 +88,13 @@ Coll *FederatedColl::MakeCUDAVar() {
8788
return Success();
8889
}
8990

90-
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
91-
std::int32_t root) {
91+
[[nodiscard]] Result FederatedColl::Broadcast(Context const * /*ctx*/, Comm const &comm,
92+
common::Span<std::int8_t> data, std::int32_t root) {
9293
return BroadcastImpl(comm, &this->sequence_number_, data, root);
9394
}
9495

95-
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
96+
[[nodiscard]] Result FederatedColl::Allgather(Context const * /*ctx*/, Comm const &comm,
97+
common::Span<std::int8_t> data) {
9698
using namespace federated; // NOLINT
9799
auto fed = dynamic_cast<FederatedComm const *>(&comm);
98100
CHECK(fed);
@@ -120,7 +122,7 @@ Coll *FederatedColl::MakeCUDAVar() {
120122
return Success();
121123
}
122124

123-
[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
125+
[[nodiscard]] Result FederatedColl::AllgatherV(Context const * /*ctx*/, Comm const &comm,
124126
common::Span<std::int8_t const> data,
125127
common::Span<std::int64_t const>,
126128
common::Span<std::int64_t>,

plugin/federated/federated_coll.cu

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include <cstdint> // for int8_t, int32_t
55
#include <memory> // for dynamic_pointer_cast
@@ -18,7 +18,8 @@ Coll *FederatedColl::MakeCUDAVar() {
1818
return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
1919
}
2020

21-
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
21+
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Context const *ctx, Comm const &comm,
22+
common::Span<std::int8_t> data,
2223
ArrayInterfaceHandler::Type type, Op op) {
2324
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
2425
CHECK(cufed);
@@ -29,14 +30,15 @@ Coll *FederatedColl::MakeCUDAVar() {
2930
return GetCUDAResult(
3031
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
3132
} << [&] {
32-
return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
33+
return p_impl_->Allreduce(ctx, comm, common::Span{h_data.data(), h_data.size()}, type, op);
3334
} << [&] {
3435
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
3536
cudaMemcpyHostToDevice, cufed->Stream()));
3637
};
3738
}
3839

39-
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
40+
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Context const *ctx, Comm const &comm,
41+
common::Span<std::int8_t> data,
4042
std::int32_t root) {
4143
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
4244
CHECK(cufed);
@@ -46,14 +48,15 @@ Coll *FederatedColl::MakeCUDAVar() {
4648
return GetCUDAResult(
4749
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
4850
} << [&] {
49-
return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
51+
return p_impl_->Broadcast(ctx, comm, common::Span{h_data.data(), h_data.size()}, root);
5052
} << [&] {
5153
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
5254
cudaMemcpyHostToDevice, cufed->Stream()));
5355
};
5456
}
5557

56-
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
58+
[[nodiscard]] Result CUDAFederatedColl::Allgather(Context const *ctx, Comm const &comm,
59+
common::Span<std::int8_t> data) {
5760
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
5861
CHECK(cufed);
5962
std::vector<std::int8_t> h_data(data.size());
@@ -62,16 +65,19 @@ Coll *FederatedColl::MakeCUDAVar() {
6265
return GetCUDAResult(
6366
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
6467
} << [&] {
65-
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
68+
return p_impl_->Allgather(ctx, comm, common::Span{h_data.data(), h_data.size()});
6669
} << [&] {
6770
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
6871
cudaMemcpyHostToDevice, cufed->Stream()));
6972
};
7073
}
7174

72-
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
73-
Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
74-
common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
75+
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(Context const *ctx, Comm const &comm,
76+
common::Span<std::int8_t const> data,
77+
common::Span<std::int64_t const> sizes,
78+
common::Span<std::int64_t> recv_segments,
79+
common::Span<std::int8_t> recv,
80+
AllgatherVAlgo algo) {
7581
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
7682
CHECK(cufed);
7783

@@ -82,7 +88,7 @@ Coll *FederatedColl::MakeCUDAVar() {
8288
return GetCUDAResult(
8389
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
8490
} << [&] {
85-
return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
91+
return this->p_impl_->AllgatherV(ctx, comm, h_data, sizes, recv_segments, h_recv, algo);
8692
} << [&] {
8793
return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
8894
cudaMemcpyHostToDevice, cufed->Stream()));

plugin/federated/federated_coll.cuh

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost contributors
2+
* Copyright 2023-2026, XGBoost contributors
33
*/
44
#include "../../src/collective/comm.h" // for Comm, Coll
55
#include "federated_coll.h" // for FederatedColl
@@ -12,12 +12,15 @@ class CUDAFederatedColl : public Coll {
1212

1313
public:
1414
explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
15-
[[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
16-
ArrayInterfaceHandler::Type type, Op op) override;
17-
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
18-
std::int32_t root) override;
19-
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
20-
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
15+
[[nodiscard]] Result Allreduce(Context const *ctx, Comm const &comm,
16+
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type type,
17+
Op op) override;
18+
[[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm,
19+
common::Span<std::int8_t> data, std::int32_t root) override;
20+
[[nodiscard]] Result Allgather(Context const *ctx, Comm const &,
21+
common::Span<std::int8_t> data) override;
22+
[[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm,
23+
common::Span<std::int8_t const> data,
2124
common::Span<std::int64_t const> sizes,
2225
common::Span<std::int64_t> recv_segments,
2326
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;

plugin/federated/federated_coll.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
/**
2-
* Copyright 2023-2024, XGBoost contributors
2+
* Copyright 2023-2026, XGBoost contributors
33
*/
44
#pragma once
5-
#include "../../src/collective/coll.h" // for Coll
6-
#include "../../src/collective/comm.h" // for Comm
5+
#include "../../src/collective/coll.h" // for Coll
6+
#include "../../src/collective/comm.h" // for Comm
77

88
namespace xgboost::collective {
99
class FederatedColl : public Coll {
@@ -13,12 +13,14 @@ class FederatedColl : public Coll {
1313
public:
1414
Coll *MakeCUDAVar() override;
1515

16-
[[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
16+
[[nodiscard]] Result Allreduce(Context const *ctx, Comm const &, common::Span<std::int8_t> data,
1717
ArrayInterfaceHandler::Type type, Op op) override;
18-
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
19-
std::int32_t root) override;
20-
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
21-
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
18+
[[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm,
19+
common::Span<std::int8_t> data, std::int32_t root) override;
20+
[[nodiscard]] Result Allgather(Context const *ctx, Comm const &,
21+
common::Span<std::int8_t> data) override;
22+
[[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm,
23+
common::Span<std::int8_t const> data,
2224
common::Span<std::int64_t const> sizes,
2325
common::Span<std::int64_t> recv_segments,
2426
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;

src/collective/allgather.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ template <typename T>
115115

116116
auto const& cctx = comm.Ctx(ctx, data.Device());
117117
auto backend = comm.Backend(data.Device());
118-
return backend->Allgather(cctx, erased);
118+
return backend->Allgather(ctx, cctx, erased);
119119
}
120120

121121
/**
@@ -144,7 +144,7 @@ template <typename T>
144144
sizes[comm.Rank()] = data.Values().size_bytes();
145145
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
146146
auto rc =
147-
comm.Backend(DeviceOrd::CPU())->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
147+
comm.Backend(DeviceOrd::CPU())->Allgather(ctx, comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
148148
if (!rc.OK()) {
149149
return rc;
150150
}
@@ -161,8 +161,8 @@ template <typename T>
161161
auto erased = common::EraseType(data.Values());
162162

163163
return backend->AllgatherV(
164-
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
165-
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(),
164+
ctx, comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()},
165+
s_segments, data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(),
166166
data.Device().IsCUDA() ? AllgatherVAlgo::kBcast : AllgatherVAlgo::kRing);
167167
}
168168

src/collective/allreduce.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ template <typename T, std::int32_t kDim>
5555
auto type = ToDType<T>::kType;
5656

5757
auto backend = comm.Backend(data.Device());
58-
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
58+
return backend->Allreduce(ctx, comm.Ctx(ctx, data.Device()), erased, type, op);
5959
}
6060

6161
template <typename T, std::int32_t kDim>
@@ -366,8 +366,8 @@ template <typename T, typename Fn>
366366
std::enable_if_t<std::is_invocable_v<Fn, dh::device_vector<T> const&, dh::device_vector<T> const&,
367367
dh::device_vector<T>*, cudaStream_t>,
368368
Result>
369-
AllreduceV(Comm const& comm, dh::device_vector<T>* data, AllreduceVScratch<T>* scratch,
370-
Fn&& redop) {
369+
AllreduceV(Context const* ctx, Comm const& comm, dh::device_vector<T>* data,
370+
AllreduceVScratch<T>* scratch, Fn&& redop) {
371371
if (!comm.IsDistributed() || comm.World() == 1) {
372372
return Success();
373373
}
@@ -377,7 +377,7 @@ AllreduceV(Comm const& comm, dh::device_vector<T>* data, AllreduceVScratch<T>* s
377377
return Fail("Distributed GPU AllreduceV requires NCCL support.");
378378
}
379379

380-
return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward<Fn>(redop));
380+
return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward<Fn>(redop));
381381
}
382382

383383
template <typename T, typename Fn>
@@ -396,7 +396,7 @@ AllreduceV(Context const* ctx, CommGroup const& comm, dh::device_vector<T>* data
396396
auto const& cctx = comm.Ctx(ctx, ctx->Device());
397397
auto nccl = dynamic_cast<NCCLComm const*>(&cctx);
398398
if (nccl != nullptr) {
399-
return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward<Fn>(redop));
399+
return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward<Fn>(redop));
400400
}
401401
return gpu_detail::AllreduceVHostFallback(ctx, comm, data, scratch, std::forward<Fn>(redop));
402402
}

0 commit comments

Comments
 (0)