From d08d77f728be729b0a4729f34aba6d795eb551fc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 17 Apr 2026 22:24:28 +0800 Subject: [PATCH 1/9] Fix AllreduceV with CUDA stream. We use async NCCL to implement timeout. However, when NCCL is in async mode, it uses a thread pool, and cannot work with per-thread CUDA stream, which resolves to the wrong per-thread stream in its pool. The existing allreduce implementation works because we use a custom CUDA stream in the NCCL coll wrapper. This PR moves that stream into NCCLComm, to expose it to the allreduce V implementation. --- src/collective/allreduce.h | 7 ++- src/collective/allreduce_v.cuh | 97 +++++++++++++++++++++------------- src/collective/coll.cu | 48 ++++++----------- src/collective/coll.cuh | 4 +- src/collective/comm.cu | 8 +-- src/collective/comm.cuh | 10 ++-- 6 files changed, 94 insertions(+), 80 deletions(-) diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index 88852aef030c..2dd6d095e8b6 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -377,7 +377,9 @@ AllreduceV(Comm const& comm, dh::device_vector* data, AllreduceVScratch* s return Fail("Distributed GPU AllreduceV requires NCCL support."); } - return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward(redop)); + // No separate user stream available; let the NCCL stream drive everything. + // Cross-stream events inside `gpu_impl::AllreduceV` degenerate to same-stream no-ops. + return gpu_impl::AllreduceV(nccl->Stream(), *nccl, data, scratch, std::forward(redop)); } template @@ -396,7 +398,8 @@ AllreduceV(Context const* ctx, CommGroup const& comm, dh::device_vector* data auto const& cctx = comm.Ctx(ctx, ctx->Device()); auto nccl = dynamic_cast(&cctx); if (nccl != nullptr) { - return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward(redop)); + return gpu_impl::AllreduceV(ctx->CUDACtx()->Stream(), *nccl, data, scratch, + std::forward(redop)); } return gpu_detail::AllreduceVHostFallback(ctx, comm, data, scratch, std::forward(redop)); } diff --git a/src/collective/allreduce_v.cuh b/src/collective/allreduce_v.cuh index a7b53ab3ed55..10da4425f3bd 100644 --- a/src/collective/allreduce_v.cuh +++ b/src/collective/allreduce_v.cuh @@ -11,6 +11,7 @@ #include #include "../common/device_helpers.cuh" // for device_vector +#include "../common/utils.h" // for MakeCleanup #include "comm_group.h" // for GlobalCommGroup #include "comm.cuh" // for NCCLComm, GetCUDAResult #include "topo.h" // for binomial tree helpers @@ -31,13 +32,33 @@ struct AllreduceVScratch { } }; +// Bracket a block of NCCL work with CUDA events so that `nccl_stream` sees any prior +// writes on `user_stream`, and `user_stream`'s subsequent reads see the NCCL work. Events +// are recorded on the caller's thread, so magic stream handles (cudaStreamPerThread) are +// resolved on the correct thread. +template +std::enable_if_t, Result>, Result> BracketNccl( + curt::StreamRef user_stream, curt::StreamRef nccl_stream, Fn&& fn) { + curt::Event before; + before.Record(user_stream); + nccl_stream.Wait(before); + + auto after = common::MakeCleanup([&] { + curt::Event ev; + ev.Record(nccl_stream); + user_stream.Wait(ev); + }); + + return std::forward(fn)(); +} + template std::enable_if_t< std::is_invocable_v const&, dh::device_vector const&, dh::device_vector*, cudaStream_t>, Result> -AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratch* scratch, - Fn&& redop) { +AllreduceV(curt::StreamRef user_stream, NCCLComm const& nccl, dh::device_vector* data, + AllreduceVScratch* scratch, Fn&& redop) { static_assert(std::is_standard_layout_v && std::is_trivially_copyable_v, "AllreduceV requires trivially-copyable payload elements."); CHECK(data); @@ -50,7 +71,8 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchsize.resize(1); } - auto stream = cudaStream_t{nccl.Stream()}; + auto nccl_stream = nccl.Stream(); + auto stream = cudaStream_t{nccl_stream}; // Nonblocking NCCL communicators can keep returning `ncclInProgress` after the p2p launch. // Wait for communicator progress here so the next tree edge doesn't race the previous one. auto wait_p2p = [&] { return BusyWait(nccl.Stub(), nccl.Handle(), nccl.Timeout()); }; @@ -61,11 +83,7 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchBlock(); + return wait_p2p(); }; auto recv_all = [&](std::int32_t peer, std::int8_t* ptr, std::size_t n_bytes) { @@ -74,11 +92,7 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchBlock(); + return wait_p2p(); }; auto send_size = [&](std::int32_t peer, std::int64_t n) { @@ -106,31 +120,37 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchsize` + // copies are already stream-ordered on the NCCL stream. auto send_vec = [&](std::int32_t peer, dh::device_vector const& payload) { - auto rc = send_size(peer, static_cast(payload.size())); - if (!rc.OK() || payload.empty()) { - return rc; - } - - auto count = payload.size() * sizeof(T); - return send_all(peer, reinterpret_cast(payload.data().get()), count); + return BracketNccl(user_stream, nccl_stream, [&]() -> Result { + auto rc = send_size(peer, static_cast(payload.size())); + if (!rc.OK() || payload.empty()) { + return rc; + } + auto count = payload.size() * sizeof(T); + return send_all(peer, reinterpret_cast(payload.data().get()), count); + }); }; auto recv_vec = [&](std::int32_t peer, dh::device_vector* payload) { CHECK(payload); - std::int64_t n = 0; - auto rc = recv_size(peer, &n); - if (!rc.OK()) { - return rc; - } - CHECK_GE(n, 0); - payload->resize(static_cast(n)); - if (n == 0) { - return Success(); - } - - auto count = static_cast(n) * sizeof(T); - return recv_all(peer, reinterpret_cast(payload->data().get()), count); + return BracketNccl(user_stream, nccl_stream, [&]() -> Result { + std::int64_t n = 0; + auto rc = recv_size(peer, &n); + if (!rc.OK()) { + return rc; + } + CHECK_GE(n, 0); + payload->resize(static_cast(n)); + if (n == 0) { + return Success(); + } + auto count = static_cast(n) * sizeof(T); + return recv_all(peer, reinterpret_cast(payload->data().get()), count); + }); }; auto rank = nccl.Rank(); @@ -156,7 +176,9 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchpayload, &scratch->next, stream); + // `recv_vec`'s BracketNccl already made `user_stream` wait for the NCCL kernel, so + // `redop` may run freely on `user_stream`. + redop(*data, scratch->payload, &scratch->next, cudaStream_t{user_stream}); std::swap(*data, scratch->next); } } @@ -166,9 +188,10 @@ AllreduceV(NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratchBackend(DeviceOrd::CUDA(0)); auto broadcast = [&](void* ptr, std::size_t n_bytes) { - return coll->Broadcast(nccl, common::Span{reinterpret_cast(ptr), - n_bytes}, - kRoot); + return BracketNccl(user_stream, nccl_stream, [&] { + return coll->Broadcast( + nccl, common::Span{reinterpret_cast(ptr), n_bytes}, kRoot); + }); }; if (rank == kRoot) { n = static_cast(data->size()); diff --git a/src/collective/coll.cu b/src/collective/coll.cu index d327e03ba29b..3bc0bd41db40 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -91,17 +91,8 @@ struct Chan { template > [[nodiscard]] std::enable_if_t, Result> AsyncLaunch( - common::ThreadPool* pool, NCCLComm const* nccl, std::shared_ptr stub, - curt::StreamRef stream, Fn&& fn) { - curt::Event e0; - e0.Record(nccl->Stream()); - stream.Wait(e0); - - auto cleanup = common::MakeCleanup([&] { - curt::Event e1; - e1.Record(stream); - nccl->Stream().Wait(e1); - }); + common::ThreadPool* pool, NCCLComm const* nccl, std::shared_ptr stub, Fn&& fn) { + auto stream = nccl->Stream(); Chan chan; @@ -195,14 +186,13 @@ void RunBitwiseAllreduce(curt::StreamRef stream, common::Span out_b } [[nodiscard]] Result BitwiseAllReduce(common::ThreadPool* pool, NCCLComm const* pcomm, - common::Span data, Op op, - curt::StreamRef stream) { + common::Span data, Op op) { dh::device_vector buffer(data.size() * pcomm->World()); auto* device_buffer = buffer.data().get(); auto stub = pcomm->Stub(); // First gather data from all the workers. - auto rc = AsyncLaunch(pool, pcomm, stub, stream, [&](curt::StreamRef s) { + auto rc = AsyncLaunch(pool, pcomm, stub, [&](curt::StreamRef s) { return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), s); }); if (!rc.OK()) { @@ -259,16 +249,15 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { return Success() << [&] { if (IsBitwiseOp(op)) { - return BitwiseAllReduce(&this->pool_, nccl, data, op, this->stream_.View()); + return BitwiseAllReduce(&this->pool_, nccl, data, op); } else { return DispatchDType(type, [&](auto t) { using T = decltype(t); auto rdata = common::RestoreType(data); - return AsyncLaunch( - &this->pool_, nccl, stub, this->stream_.View(), [&](curt::StreamRef s) { - return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type), - GetNCCLRedOp(op), nccl->Handle(), s); - }); + return AsyncLaunch(&this->pool_, nccl, stub, [&](curt::StreamRef s) { + return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type), + GetNCCLRedOp(op), nccl->Handle(), s); + }); }); } } << [&] { @@ -286,11 +275,10 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto stub = nccl->Stub(); return Success() << [&] { - return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(), - [data, nccl, root, stub](curt::StreamRef s) { - return stub->Broadcast(data.data(), data.data(), data.size_bytes(), - ncclInt8, root, nccl->Handle(), s); - }); + return AsyncLaunch(&this->pool_, nccl, stub, [data, nccl, root, stub](curt::StreamRef s) { + return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root, + nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; @@ -307,11 +295,9 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto send = data.subspan(comm.Rank() * size, size); return Success() << [&] { - return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(), - [send, data, size, nccl, stub](curt::StreamRef s) { - return stub->Allgather(send.data(), data.data(), size, ncclInt8, - nccl->Handle(), s); - }); + return AsyncLaunch(&this->pool_, nccl, stub, [send, data, size, nccl, stub](curt::StreamRef s) { + return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; @@ -381,7 +367,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, curt::StreamRef s, }; } case AllgatherVAlgo::kBcast: { - return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(), [&](curt::StreamRef s) { + return AsyncLaunch(&this->pool_, nccl, stub, [&](curt::StreamRef s) { return cuda_impl::BroadcastAllgatherV(nccl, s, data, sizes, recv); }); } diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh index 084f89402866..649cf152eec2 100644 --- a/src/collective/coll.cuh +++ b/src/collective/coll.cuh @@ -1,11 +1,10 @@ /** - * Copyright 2023-2025, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #pragma once #include // for int8_t, int64_t -#include "../common/cuda_stream.h" // for Stream #include "../common/threadpool.h" // for ThreadPool #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "coll.h" // for Coll @@ -15,7 +14,6 @@ namespace xgboost::collective { class NCCLColl : public Coll { common::ThreadPool pool_; - curt::Stream stream_; public: NCCLColl(); diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 627f9cb27935..e40f3f4aed78 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -7,7 +7,6 @@ #include // for uint64_t, int8_t #include // for memcpy #include // for shared_ptr -#include // for stringstream #include // for vector #include "../common/cuda_context.cuh" // for CUDAContext @@ -47,7 +46,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p StringView nccl_path) : Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(), root.TaskID()}, - stream_{ctx->CUDACtx()->Stream()} { + stream_{} { this->world_ = root.World(); this->rank_ = root.Rank(); this->domain_ = root.Domain(); @@ -97,11 +96,14 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p // Keep point-to-point channel launches on the communicator stream so helper-local staging // work and the NCCL send/recv edges share one ordering domain. this->channels_.emplace_back( - std::make_shared(root, r, nccl_comm_, stub_, stream_)); + std::make_shared(root, r, nccl_comm_, stub_, stream_.View())); } } NCCLComm::~NCCLComm() { + // Drain any kernels NCCL's non-blocking async thread may have queued on our stream + // before `stream_` destructs the underlying cudaStream_t. + (void)stream_.Sync(); if (nccl_comm_) { auto rc = Success() << [this] { return this->stub_->CommFinalize(this->nccl_comm_); diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index 4ccdc8b6a480..4501c2a4ee78 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023-2025, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #pragma once @@ -30,10 +30,12 @@ inline Result GetCUDAResult(cudaError rc) { #if defined(XGBOOST_USE_NCCL) class NCCLComm : public Comm { private: - ncclComm_t nccl_comm_{nullptr}; + // stream_ is declared first so it is destroyed LAST, after stub_/nccl_comm_ and after + // the base class's channels_. + curt::Stream stream_; std::shared_ptr stub_; + ncclComm_t nccl_comm_{nullptr}; ncclUniqueId nccl_unique_id_{}; - curt::StreamRef stream_; std::string nccl_path_; public: @@ -48,7 +50,7 @@ class NCCLComm : public Comm { } ~NCCLComm() override; [[nodiscard]] bool IsFederated() const override { return false; } - [[nodiscard]] curt::StreamRef Stream() const { return stream_; } + [[nodiscard]] curt::StreamRef Stream() const { return stream_.View(); } [[nodiscard]] Result Block() const override { auto rc = this->Stream().Sync(false); return GetCUDAResult(rc); From af87a85778795807c0a97065c6572dbb30eea719 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 17 Apr 2026 22:27:45 +0800 Subject: [PATCH 2/9] lint. --- src/collective/coll.cu | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/collective/coll.cu b/src/collective/coll.cu index 3bc0bd41db40..5a53ccfe39bb 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -2,17 +2,17 @@ * Copyright 2023-2025, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) -#include // for chrono, chrono_literals -#include // for size_t -#include // for int8_t, int64_t -#include // for bit_and, bit_or, bit_xor -#include // for future, future_status -#include // for shared_ptr -#include // for mutex, unique_lock -#include // for string -#include // for this_thread -#include // for invoke_result_t, is_same_v, enable_if_t -#include // for move +#include // for chrono, chrono_literals +#include // for size_t +#include // for int8_t, int64_t +#include // for bit_and, bit_or, bit_xor +#include // for future, future_status +#include // for shared_ptr +#include // for mutex, unique_lock +#include // for string +#include // for this_thread +#include // for invoke_result_t, is_same_v, enable_if_t +#include // for move #include "../common/cuda_stream.h" // for StreamRef, Event #include "../common/device_helpers.cuh" // for device_vector From ec6e67ed8fbb01087c2d16083529424579d203ad Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 01:04:29 +0800 Subject: [PATCH 3/9] thread through the context. --- plugin/federated/federated_coll.cc | 14 ++++++---- plugin/federated/federated_coll.cu | 24 +++++++++------- plugin/federated/federated_coll.cuh | 17 +++++++----- plugin/federated/federated_coll.h | 14 ++++++---- src/collective/allgather.h | 8 +++--- src/collective/allreduce.h | 5 ++-- src/collective/allreduce_v.cuh | 30 ++++---------------- src/collective/broadcast.h | 4 +-- src/collective/coll.cc | 15 ++++++---- src/collective/coll.cu | 38 ++++++++++++++++---------- src/collective/coll.cuh | 13 +++++---- src/collective/coll.h | 4 ++- tests/cpp/collective/test_allgather.cu | 8 ++++-- tests/cpp/collective/test_allreduce.cu | 8 +++--- 14 files changed, 107 insertions(+), 95 deletions(-) diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc index b62abdada5a5..bdd90d00ba5f 100644 --- a/plugin/federated/federated_coll.cc +++ b/plugin/federated/federated_coll.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2026, XGBoost contributors */ #include "federated_coll.h" @@ -61,7 +61,8 @@ Coll *FederatedColl::MakeCUDAVar() { } #endif -[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span data, +[[nodiscard]] Result FederatedColl::Allreduce(Context const * /*ctx*/, Comm const &comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op) { using namespace federated; // NOLINT auto fed = dynamic_cast(&comm); @@ -87,12 +88,13 @@ Coll *FederatedColl::MakeCUDAVar() { return Success(); } -[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span data, - std::int32_t root) { +[[nodiscard]] Result FederatedColl::Broadcast(Context const * /*ctx*/, Comm const &comm, + common::Span data, std::int32_t root) { return BroadcastImpl(comm, &this->sequence_number_, data, root); } -[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data) { +[[nodiscard]] Result FederatedColl::Allgather(Context const * /*ctx*/, Comm const &comm, + common::Span data) { using namespace federated; // NOLINT auto fed = dynamic_cast(&comm); CHECK(fed); @@ -120,7 +122,7 @@ Coll *FederatedColl::MakeCUDAVar() { return Success(); } -[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm, +[[nodiscard]] Result FederatedColl::AllgatherV(Context const * /*ctx*/, Comm const &comm, common::Span data, common::Span, common::Span, diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu index 3f604c50d2d2..979254dbbadc 100644 --- a/plugin/federated/federated_coll.cu +++ b/plugin/federated/federated_coll.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #include // for int8_t, int32_t #include // for dynamic_pointer_cast @@ -18,7 +18,8 @@ Coll *FederatedColl::MakeCUDAVar() { return new CUDAFederatedColl{std::dynamic_pointer_cast(this->shared_from_this())}; } -[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span data, +[[nodiscard]] Result CUDAFederatedColl::Allreduce(Context const *ctx, Comm const &comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op) { auto cufed = dynamic_cast(&comm); CHECK(cufed); @@ -29,14 +30,15 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op); + return p_impl_->Allreduce(ctx, comm, common::Span{h_data.data(), h_data.size()}, type, op); } << [&] { return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), cudaMemcpyHostToDevice, cufed->Stream())); }; } -[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span data, +[[nodiscard]] Result CUDAFederatedColl::Broadcast(Context const *ctx, Comm const &comm, + common::Span data, std::int32_t root) { auto cufed = dynamic_cast(&comm); CHECK(cufed); @@ -46,14 +48,15 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root); + return p_impl_->Broadcast(ctx, comm, common::Span{h_data.data(), h_data.size()}, root); } << [&] { return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), cudaMemcpyHostToDevice, cufed->Stream())); }; } -[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data) { +[[nodiscard]] Result CUDAFederatedColl::Allgather(Context const *ctx, Comm const &comm, + common::Span data) { auto cufed = dynamic_cast(&comm); CHECK(cufed); std::vector h_data(data.size()); @@ -62,7 +65,7 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}); + return p_impl_->Allgather(ctx, comm, common::Span{h_data.data(), h_data.size()}); } << [&] { return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), cudaMemcpyHostToDevice, cufed->Stream())); @@ -70,8 +73,9 @@ Coll *FederatedColl::MakeCUDAVar() { } [[nodiscard]] Result CUDAFederatedColl::AllgatherV( - Comm const &comm, common::Span data, common::Span sizes, - common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) { + Context const *ctx, Comm const &comm, common::Span data, + common::Span sizes, common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) { auto cufed = dynamic_cast(&comm); CHECK(cufed); @@ -82,7 +86,7 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo); + return this->p_impl_->AllgatherV(ctx, comm, h_data, sizes, recv_segments, h_recv, algo); } << [&] { return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(), cudaMemcpyHostToDevice, cufed->Stream())); diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh index 6a690a33d889..6ae83feaa5e0 100644 --- a/plugin/federated/federated_coll.cuh +++ b/plugin/federated/federated_coll.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost contributors + * Copyright 2023-2026, XGBoost contributors */ #include "../../src/collective/comm.h" // for Comm, Coll #include "federated_coll.h" // for FederatedColl @@ -12,12 +12,15 @@ class CUDAFederatedColl : public Coll { public: explicit CUDAFederatedColl(std::shared_ptr pimpl) : p_impl_{std::move(pimpl)} {} - [[nodiscard]] Result Allreduce(Comm const &comm, common::Span data, - ArrayInterfaceHandler::Type type, Op op) override; - [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, - std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; - [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, + [[nodiscard]] Result Allreduce(Context const *ctx, Comm const &comm, + common::Span data, ArrayInterfaceHandler::Type type, + Op op) override; + [[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm, + common::Span data, std::int32_t root) override; + [[nodiscard]] Result Allgather(Context const *ctx, Comm const &, + common::Span data) override; + [[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) override; diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h index 12443a3e1b5a..219d35bf1985 100644 --- a/plugin/federated/federated_coll.h +++ b/plugin/federated/federated_coll.h @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost contributors + * Copyright 2023-2026, XGBoost contributors */ #pragma once #include "../../src/collective/coll.h" // for Coll @@ -13,12 +13,14 @@ class FederatedColl : public Coll { public: Coll *MakeCUDAVar() override; - [[nodiscard]] Result Allreduce(Comm const &, common::Span data, + [[nodiscard]] Result Allreduce(Context const *ctx, Comm const &, common::Span data, ArrayInterfaceHandler::Type type, Op op) override; - [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, - std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; - [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, + [[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm, + common::Span data, std::int32_t root) override; + [[nodiscard]] Result Allgather(Context const *ctx, Comm const &, + common::Span data) override; + [[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) override; diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 28cd488fb571..bd01e19904a4 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -115,7 +115,7 @@ template auto const& cctx = comm.Ctx(ctx, data.Device()); auto backend = comm.Backend(data.Device()); - return backend->Allgather(cctx, erased); + return backend->Allgather(ctx, cctx, erased); } /** @@ -144,7 +144,7 @@ template sizes[comm.Rank()] = data.Values().size_bytes(); auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()}); auto rc = - comm.Backend(DeviceOrd::CPU())->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes); + comm.Backend(DeviceOrd::CPU())->Allgather(ctx, comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes); if (!rc.OK()) { return rc; } @@ -161,8 +161,8 @@ template auto erased = common::EraseType(data.Values()); return backend->AllgatherV( - comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments, - data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), + ctx, comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, + s_segments, data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), data.Device().IsCUDA() ? AllgatherVAlgo::kBcast : AllgatherVAlgo::kRing); } diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index 2dd6d095e8b6..a9edf789ed6f 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -55,7 +55,7 @@ template auto type = ToDType::kType; auto backend = comm.Backend(data.Device()); - return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op); + return backend->Allreduce(ctx, comm.Ctx(ctx, data.Device()), erased, type, op); } template @@ -398,8 +398,7 @@ AllreduceV(Context const* ctx, CommGroup const& comm, dh::device_vector* data auto const& cctx = comm.Ctx(ctx, ctx->Device()); auto nccl = dynamic_cast(&cctx); if (nccl != nullptr) { - return gpu_impl::AllreduceV(ctx->CUDACtx()->Stream(), *nccl, data, scratch, - std::forward(redop)); + return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward(redop)); } return gpu_detail::AllreduceVHostFallback(ctx, comm, data, scratch, std::forward(redop)); } diff --git a/src/collective/allreduce_v.cuh b/src/collective/allreduce_v.cuh index 10da4425f3bd..41f1fe3a6654 100644 --- a/src/collective/allreduce_v.cuh +++ b/src/collective/allreduce_v.cuh @@ -11,9 +11,8 @@ #include #include "../common/device_helpers.cuh" // for device_vector -#include "../common/utils.h" // for MakeCleanup +#include "comm.cuh" // for NCCLComm, BracketNccl #include "comm_group.h" // for GlobalCommGroup -#include "comm.cuh" // for NCCLComm, GetCUDAResult #include "topo.h" // for binomial tree helpers #include "xgboost/collective/result.h" #include "xgboost/logging.h" @@ -32,35 +31,16 @@ struct AllreduceVScratch { } }; -// Bracket a block of NCCL work with CUDA events so that `nccl_stream` sees any prior -// writes on `user_stream`, and `user_stream`'s subsequent reads see the NCCL work. Events -// are recorded on the caller's thread, so magic stream handles (cudaStreamPerThread) are -// resolved on the correct thread. -template -std::enable_if_t, Result>, Result> BracketNccl( - curt::StreamRef user_stream, curt::StreamRef nccl_stream, Fn&& fn) { - curt::Event before; - before.Record(user_stream); - nccl_stream.Wait(before); - - auto after = common::MakeCleanup([&] { - curt::Event ev; - ev.Record(nccl_stream); - user_stream.Wait(ev); - }); - - return std::forward(fn)(); -} - template std::enable_if_t< std::is_invocable_v const&, dh::device_vector const&, dh::device_vector*, cudaStream_t>, Result> -AllreduceV(curt::StreamRef user_stream, NCCLComm const& nccl, dh::device_vector* data, +AllreduceV(Context const* ctx, NCCLComm const& nccl, dh::device_vector* data, AllreduceVScratch* scratch, Fn&& redop) { static_assert(std::is_standard_layout_v && std::is_trivially_copyable_v, "AllreduceV requires trivially-copyable payload elements."); + CHECK(ctx); CHECK(data); CHECK(scratch); @@ -71,6 +51,7 @@ AllreduceV(curt::StreamRef user_stream, NCCLComm const& nccl, dh::device_vector< scratch->size.resize(1); } + auto user_stream = ctx->CUDACtx()->Stream(); auto nccl_stream = nccl.Stream(); auto stream = cudaStream_t{nccl_stream}; // Nonblocking NCCL communicators can keep returning `ncclInProgress` after the p2p launch. @@ -190,7 +171,8 @@ AllreduceV(curt::StreamRef user_stream, NCCLComm const& nccl, dh::device_vector< auto broadcast = [&](void* ptr, std::size_t n_bytes) { return BracketNccl(user_stream, nccl_stream, [&] { return coll->Broadcast( - nccl, common::Span{reinterpret_cast(ptr), n_bytes}, kRoot); + ctx, nccl, + common::Span{reinterpret_cast(ptr), n_bytes}, kRoot); }); }; if (rank == kRoot) { diff --git a/src/collective/broadcast.h b/src/collective/broadcast.h index 61cab8cdd8f6..7fc98c999cc7 100644 --- a/src/collective/broadcast.h +++ b/src/collective/broadcast.h @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #pragma once #include // for int32_t, int8_t @@ -37,7 +37,7 @@ template CHECK(data.Contiguous()); auto erased = common::EraseType(data.Values()); auto backend = comm.Backend(data.Device()); - return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root); + return backend->Broadcast(ctx, comm.Ctx(ctx, data.Device()), erased, root); } template diff --git a/src/collective/coll.cc b/src/collective/coll.cc index b720d09b7eb9..c539bafb20e4 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #include "coll.h" @@ -31,7 +31,8 @@ bool constexpr IsFloatingPointV() { #endif // defined(XGBOOST_USE_CUDA) } -[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span data, +[[nodiscard]] Result Coll::Allreduce(Context const* /*ctx*/, Comm const& comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op) { namespace coll = ::xgboost::collective; @@ -103,16 +104,18 @@ bool constexpr IsFloatingPointV() { return std::move(rc) << [&] { return comm.Block(); }; } -[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span data, - std::int32_t root) { +[[nodiscard]] Result Coll::Broadcast(Context const* /*ctx*/, Comm const& comm, + common::Span data, std::int32_t root) { return cpu_impl::Broadcast(comm, data, root); } -[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data) { +[[nodiscard]] Result Coll::Allgather(Context const* /*ctx*/, Comm const& comm, + common::Span data) { return RingAllgather(comm, data); } -[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span data, +[[nodiscard]] Result Coll::AllgatherV(Context const* /*ctx*/, Comm const& comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) { diff --git a/src/collective/coll.cu b/src/collective/coll.cu index 5a53ccfe39bb..306d20485770 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -238,8 +238,10 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { } } // namespace -[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span data, +[[nodiscard]] Result NCCLColl::Allreduce(Context const* ctx, Comm const& comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op) { + CHECK(ctx); if (!comm.IsDistributed()) { return Success(); } @@ -249,12 +251,12 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { return Success() << [&] { if (IsBitwiseOp(op)) { - return BitwiseAllReduce(&this->pool_, nccl, data, op); + return BitwiseAllReduce(ctx, &this->pool_, nccl, data, op); } else { return DispatchDType(type, [&](auto t) { using T = decltype(t); auto rdata = common::RestoreType(data); - return AsyncLaunch(&this->pool_, nccl, stub, [&](curt::StreamRef s) { + return AsyncLaunch(ctx, &this->pool_, nccl, stub, [&](curt::StreamRef s) { return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type), GetNCCLRedOp(op), nccl->Handle(), s); }); @@ -265,8 +267,9 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { }; } -[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span data, - std::int32_t root) { +[[nodiscard]] Result NCCLColl::Broadcast(Context const* ctx, Comm const& comm, + common::Span data, std::int32_t root) { + CHECK(ctx); if (!comm.IsDistributed()) { return Success(); } @@ -275,16 +278,19 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto stub = nccl->Stub(); return Success() << [&] { - return AsyncLaunch(&this->pool_, nccl, stub, [data, nccl, root, stub](curt::StreamRef s) { - return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root, - nccl->Handle(), s); - }); + return AsyncLaunch(ctx, &this->pool_, nccl, stub, + [data, nccl, root, stub](curt::StreamRef s) { + return stub->Broadcast(data.data(), data.data(), data.size_bytes(), + ncclInt8, root, nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; } -[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data) { +[[nodiscard]] Result NCCLColl::Allgather(Context const* ctx, Comm const& comm, + common::Span data) { + CHECK(ctx); if (!comm.IsDistributed()) { return Success(); } @@ -295,9 +301,11 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto send = data.subspan(comm.Rank() * size, size); return Success() << [&] { - return AsyncLaunch(&this->pool_, nccl, stub, [send, data, size, nccl, stub](curt::StreamRef s) { - return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), s); - }); + return AsyncLaunch(ctx, &this->pool_, nccl, stub, + [send, data, size, nccl, stub](curt::StreamRef s) { + return stub->Allgather(send.data(), data.data(), size, ncclInt8, + nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; @@ -333,10 +341,12 @@ Result BroadcastAllgatherV(NCCLComm const* comm, curt::StreamRef s, } } // namespace cuda_impl -[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span data, +[[nodiscard]] Result NCCLColl::AllgatherV(Context const* ctx, Comm const& comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) { + CHECK(ctx); auto nccl = dynamic_cast(&comm); CHECK(nccl); if (!comm.IsDistributed()) { diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh index 649cf152eec2..f5645df70f7b 100644 --- a/src/collective/coll.cuh +++ b/src/collective/coll.cuh @@ -19,12 +19,15 @@ class NCCLColl : public Coll { NCCLColl(); ~NCCLColl() override; - [[nodiscard]] Result Allreduce(Comm const& comm, common::Span data, + [[nodiscard]] Result Allreduce(Context const* ctx, Comm const& comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op) override; - [[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, - std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const& comm, common::Span data) override; - [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span data, + [[nodiscard]] Result Broadcast(Context const* ctx, Comm const& comm, + common::Span data, std::int32_t root) override; + [[nodiscard]] Result Allgather(Context const* ctx, Comm const& comm, + common::Span data) override; + [[nodiscard]] Result AllgatherV(Context const* ctx, Comm const& comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) override; diff --git a/src/collective/coll.h b/src/collective/coll.h index 96fe35229510..18e172a89094 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -8,6 +8,7 @@ #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "comm.h" // for Comm #include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -60,7 +61,8 @@ class Coll : public std::enable_shared_from_this { * doesn't use the buffer. * @param [out] recv pre-allocated buffer for output. */ - [[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span data, + [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm, + common::Span data, common::Span sizes, common::Span recv_segments, common::Span recv, AllgatherVAlgo algo); diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu index d0c34cdc3843..ebd33028033b 100644 --- a/tests/cpp/collective/test_allgather.cu +++ b/tests/cpp/collective/test_allgather.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -40,7 +40,8 @@ class Worker : public NCCLWorkerForTest { auto s_result = common::EraseType(dh::ToSpan(result)); std::vector recv_seg(nccl_comm_->World() + 1, 0); - rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + rc = nccl_coll_->AllgatherV(&ctx_, *nccl_comm_, s_data, + common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); SafeColl(rc); @@ -65,7 +66,8 @@ class Worker : public NCCLWorkerForTest { auto s_result = common::EraseType(dh::ToSpan(result)); std::vector recv_seg(nccl_comm_->World() + 1, 0); - rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + rc = nccl_coll_->AllgatherV(&ctx_, *nccl_comm_, s_data, + common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); SafeColl(rc); // check segment size diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu index 4b750acdcdca..650dc88e3c49 100644 --- a/tests/cpp/collective/test_allreduce.cu +++ b/tests/cpp/collective/test_allreduce.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost Contributors + * Copyright 2023-2026, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -41,7 +41,7 @@ class Worker : public NCCLWorkerForTest { void BitOr() { dh::device_vector data(comm_.World(), 0); data[comm_.Rank()] = ~std::uint32_t{0}; - auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + auto rc = nccl_coll_->Allreduce(&ctx_, *nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); SafeColl(rc); thrust::host_vector h_data(data.size()); @@ -53,7 +53,7 @@ class Worker : public NCCLWorkerForTest { void Acc() { dh::device_vector data(314, 1.5); - auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + auto rc = nccl_coll_->Allreduce(&ctx_, *nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kF8, Op::kSum); SafeColl(rc); for (std::size_t i = 0; i < data.size(); ++i) { @@ -64,7 +64,7 @@ class Worker : public NCCLWorkerForTest { Result NoCheck() { dh::device_vector data(314, 1.5); - auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + auto rc = nccl_coll_->Allreduce(&ctx_, *nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kF8, Op::kSum); return rc; } From 13ad8a74829fb8f4cdfd92aaf160f51324e8c6cd Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 01:08:19 +0800 Subject: [PATCH 4/9] more. --- src/collective/coll.cu | 103 ++++++++++-------- src/collective/coll.h | 12 +- src/collective/comm.cu | 17 ++- src/collective/comm.cuh | 34 +++++- tests/cpp/collective/test_allgather.cc | 6 +- tests/cpp/collective/test_allreduce.cc | 4 +- .../plugin/federated/test_federated_coll.cc | 18 ++- .../plugin/federated/test_federated_coll.cu | 15 ++- 8 files changed, 129 insertions(+), 80 deletions(-) diff --git a/src/collective/coll.cu b/src/collective/coll.cu index 306d20485770..e4c8ca0e3c39 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -185,38 +185,43 @@ void RunBitwiseAllreduce(curt::StreamRef stream, common::Span out_b }); } -[[nodiscard]] Result BitwiseAllReduce(common::ThreadPool* pool, NCCLComm const* pcomm, - common::Span data, Op op) { +[[nodiscard]] Result BitwiseAllReduce(Context const* ctx, common::ThreadPool* pool, + NCCLComm const* pcomm, common::Span data, + Op op) { dh::device_vector buffer(data.size() * pcomm->World()); auto* device_buffer = buffer.data().get(); auto stub = pcomm->Stub(); - // First gather data from all the workers. - auto rc = AsyncLaunch(pool, pcomm, stub, [&](curt::StreamRef s) { - return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), s); + // Outer bracket so the post-allgather reduce kernel (run on the NCCL + // stream) is synchronised back to the caller's stream. + return BracketNccl(ctx->CUDACtx()->Stream(), pcomm->Stream(), [&]() -> Result { + auto rc = AsyncLaunch(ctx, pool, pcomm, stub, [&](curt::StreamRef s) { + return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), + s); + }); + if (!rc.OK()) { + return rc; + } + // Reduce on the NCCL stream (ordered after the allgather kernel queued + // by `AsyncLaunch`). + switch (op) { + case Op::kBitwiseAND: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_and{}, pcomm->World(), + data.size()); + break; + case Op::kBitwiseOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_or{}, pcomm->World(), + data.size()); + break; + case Op::kBitwiseXOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_xor{}, pcomm->World(), + data.size()); + break; + default: + LOG(FATAL) << "Not a bitwise reduce operation."; + } + return Success(); }); - if (!rc.OK()) { - return rc; - } - - // Then reduce locally. - switch (op) { - case Op::kBitwiseAND: - RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_and{}, pcomm->World(), - data.size()); - break; - case Op::kBitwiseOR: - RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_or{}, pcomm->World(), - data.size()); - break; - case Op::kBitwiseXOR: - RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, std::bit_xor{}, pcomm->World(), - data.size()); - break; - default: - LOG(FATAL) << "Not a bitwise reduce operation."; - } - return Success(); } ncclRedOp_t GetNCCLRedOp(Op const& op) { @@ -356,28 +361,32 @@ Result BroadcastAllgatherV(NCCLComm const* comm, curt::StreamRef s, switch (algo) { case AllgatherVAlgo::kRing: { - return Success() << [&] { - return stub->GroupStart(); - } << [&] { - // get worker offset - detail::AllgatherVOffset(sizes, recv_segments); - // copy data - auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); - if (current.data() != data.data()) { - dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(), - cudaMemcpyDeviceToDevice, nccl->Stream())); - } - return detail::RingAllgatherV(comm, sizes, recv_segments, recv); - } << [&] { - return stub->GroupEnd(); - } << [&] { - return nccl->Block(); - } << [&] { - return BusyWait(stub, nccl->Handle(), nccl->Timeout()); - }; + // kRing talks to `NCCLChannel` directly without `AsyncLaunch`; bracket + // with the caller's stream explicitly. + return BracketNccl(ctx->CUDACtx()->Stream(), nccl->Stream(), [&] { + return Success() << [&] { + return stub->GroupStart(); + } << [&] { + // get worker offset + detail::AllgatherVOffset(sizes, recv_segments); + // copy data + auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); + if (current.data() != data.data()) { + dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(), + cudaMemcpyDeviceToDevice, nccl->Stream())); + } + return detail::RingAllgatherV(comm, sizes, recv_segments, recv); + } << [&] { + return stub->GroupEnd(); + } << [&] { + return nccl->Block(); + } << [&] { + return BusyWait(stub, nccl->Handle(), nccl->Timeout()); + }; + }); } case AllgatherVAlgo::kBcast: { - return AsyncLaunch(&this->pool_, nccl, stub, [&](curt::StreamRef s) { + return AsyncLaunch(ctx, &this->pool_, nccl, stub, [&](curt::StreamRef s) { return cuda_impl::BroadcastAllgatherV(nccl, s, data, sizes, recv); }); } diff --git a/src/collective/coll.h b/src/collective/coll.h index 18e172a89094..22f31edd1f02 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -35,22 +35,24 @@ class Coll : public std::enable_shared_from_this { * @param [in] op Reduce operation. For custom operation, user needs to reach down to * the CPU implementation. */ - [[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span data, + [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm, + common::Span data, ArrayInterfaceHandler::Type type, Op op); /** * @brief Broadcast * * @param [in,out] data Data buffer for input and output. - * @param [in] root Root rank for broadcast. + * @param [in] root Root rank for broadcast. */ - [[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span data, - std::int32_t root); + [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm, + common::Span data, std::int32_t root); /** * @brief Allgather * * @param [in,out] data Data buffer for input and output. */ - [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data); + [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm, + common::Span data); /** * @brief Allgather with variable length. * diff --git a/src/collective/comm.cu b/src/collective/comm.cu index e40f3f4aed78..94d29f333cf7 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -9,7 +9,6 @@ #include // for shared_ptr #include // for vector -#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/cuda_rt_utils.h" // for SetDevice, GetUuid, PrintUuid #include "../common/type.h" // for EraseType #include "comm.cuh" // for NCCLComm @@ -20,8 +19,8 @@ namespace xgboost::collective { namespace { -Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared_ptr coll, - ncclUniqueId* pid) { +Result GetUniqueId(Context const* ctx, Comm const& comm, std::shared_ptr stub, + std::shared_ptr coll, ncclUniqueId* pid) { static const int kRootRank = 0; ncclUniqueId id; if (comm.Rank() == kRootRank) { @@ -29,7 +28,8 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared SafeColl(rc); } auto rc = coll->Broadcast( - comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, kRootRank); + ctx, comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, + kRootRank); if (!rc.OK()) { return rc; } @@ -46,7 +46,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p StringView nccl_path) : Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(), root.TaskID()}, - stream_{} { + stream_{ctx->Ordinal()} { this->world_ = root.World(); this->rank_ = root.Rank(); this->domain_ = root.Domain(); @@ -62,7 +62,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p auto s_this_uuid = s_uuid.subspan(root.Rank() * curt::kUuidLength, curt::kUuidLength); curt::GetUuid(s_this_uuid, ctx->Ordinal()); - auto rc = pimpl->Allgather(root, common::EraseType(s_uuid)); + auto rc = pimpl->Allgather(ctx, root, common::EraseType(s_uuid)); SafeColl(rc); std::vector> converted(root.World()); @@ -81,7 +81,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p << "device is not supported. " << curt::PrintUuid(s_this_uuid) << "\n"; rc = std::move(rc) << [&] { - return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); + return GetUniqueId(ctx, root, this->stub_, pimpl, &nccl_unique_id_); } << [&] { ncclConfig_t config = NCCL_CONFIG_INITIALIZER; config.blocking = 0; @@ -101,8 +101,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p } NCCLComm::~NCCLComm() { - // Drain any kernels NCCL's non-blocking async thread may have queued on our stream - // before `stream_` destructs the underlying cudaStream_t. + // Drain pending NCCL kernels before `stream_` is destroyed. (void)stream_.Sync(); if (nccl_comm_) { auto rc = Success() << [this] { diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index 4501c2a4ee78..12a4bc47be00 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -7,11 +7,13 @@ #include "nccl.h" #endif // XGBOOST_USE_NCCL -#include // for int32_t -#include // for shared_ptr -#include // for move +#include // for int32_t +#include // for shared_ptr +#include // for enable_if_t, invoke_result_t, is_same_v +#include // for move, forward -#include "../common/cuda_stream.h" // for StreamRef +#include "../common/cuda_stream.h" // for StreamRef, Stream, Event +#include "../common/utils.h" // for MakeCleanup #include "coll.h" #include "comm.h" #include "nccl_stub.h" // for NcclStub @@ -27,11 +29,31 @@ inline Result GetCUDAResult(cudaError rc) { return Fail(msg); } +#if defined(XGBOOST_USE_NCCL) +// Cross-stream bracket for a block of NCCL work: `nccl_stream` waits for +// prior `user_stream` work on entry, `user_stream` waits for the NCCL work +// on exit. Events are recorded on the calling thread. +template +[[nodiscard]] std::enable_if_t, Result>, Result> +BracketNccl(curt::StreamRef user_stream, curt::StreamRef nccl_stream, Fn&& fn) { + curt::Event before; + before.Record(user_stream); + nccl_stream.Wait(before); + + auto after = common::MakeCleanup([&] { + curt::Event ev; + ev.Record(nccl_stream); + user_stream.Wait(ev); + }); + + return std::forward(fn)(); +} +#endif // defined(XGBOOST_USE_NCCL) + #if defined(XGBOOST_USE_NCCL) class NCCLComm : public Comm { private: - // stream_ is declared first so it is destroyed LAST, after stub_/nccl_comm_ and after - // the base class's channels_. + // Declared first so it outlives stub_/nccl_comm_ and the base class's channels_. curt::Stream stream_; std::shared_ptr stub_; ncclComm_t nccl_comm_{nullptr}; diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index 7764a2adcc07..c1c6f690d855 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -113,7 +113,8 @@ class Worker : public WorkerForTest { auto s_recv = common::Span{recv.data(), recv.size()}; - rc = pcoll->AllgatherV(comm_, common::EraseType(s_data), + Context ctx; + rc = pcoll->AllgatherV(&ctx, comm_, common::EraseType(s_data), common::Span{sizes.data(), sizes.size()}, common::Span{recv_segments.data(), recv_segments.size()}, common::EraseType(s_recv), AllgatherVAlgo::kBcast); @@ -126,7 +127,8 @@ class Worker : public WorkerForTest { auto current = s_recv.subspan(recv_segments[comm_.Rank()], recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]); std::copy_n(data.data(), data.size(), current.data()); - rc = pcoll->AllgatherV(comm_, common::EraseType(current), + Context ctx; + rc = pcoll->AllgatherV(&ctx, comm_, common::EraseType(current), common::Span{sizes.data(), sizes.size()}, common::Span{recv_segments.data(), recv_segments.size()}, common::EraseType(s_recv), algo); diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index c3744d557274..d35da490da5a 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -73,7 +73,9 @@ class AllreduceWorker : public WorkerForTest { std::vector data(comm_.World(), 0); data[comm_.Rank()] = ~std::uint32_t{0}; auto pcoll = std::shared_ptr{new Coll{}}; - auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), + Context ctx; + auto rc = pcoll->Allreduce(&ctx, comm_, + common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); SafeColl(rc); for (auto v : data) { diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc index 6c5c74f4f087..1c6d83b6bbd6 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cc +++ b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -26,7 +26,9 @@ TEST_F(FederatedCollTest, Allreduce) { [=](auto i) { return i * n_workers; }); auto coll = std::make_shared(); - auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), + Context ctx; + auto rc = coll->Allreduce(&ctx, *comm, + common::EraseType(common::Span{buffer.data(), buffer.size()}), ArrayInterfaceHandler::kI4, Op::kSum); SafeColl(rc); for (auto i = 0; i < 5; i++) { @@ -39,14 +41,17 @@ TEST_F(FederatedCollTest, Broadcast) { std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { FederatedColl coll{}; + Context ctx; auto rc = Success(); if (comm->Rank() == 0) { std::string buffer{"hello"}; - rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); + rc = coll.Broadcast(&ctx, *comm, + common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); ASSERT_EQ(buffer, "hello"); } else { std::string buffer{" "}; - rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); + rc = coll.Broadcast(&ctx, *comm, + common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); ASSERT_EQ(buffer, "hello"); } SafeColl(rc); @@ -60,7 +65,9 @@ TEST_F(FederatedCollTest, Allgather) { std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); + Context ctx; + auto rc = coll.Allgather(&ctx, *comm, + common::EraseType(common::Span{buffer.data(), buffer.size()})); SafeColl(rc); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); @@ -80,8 +87,9 @@ TEST_F(FederatedCollTest, AllgatherV) { static_cast(inputs[1].size())}; r.resize(sizes[0] + sizes[1]); + Context ctx; auto rc = coll.AllgatherV( - *comm, + &ctx, *comm, common::EraseType(common::Span{inputs[comm->Rank()].data(), inputs[comm->Rank()].size()}), common::Span{sizes.data(), sizes.size()}, recv_segments, common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing); diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index 67bf0ebc66e2..b1dbc624f4b0 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -39,7 +39,8 @@ void TestAllreduce(std::shared_ptr comm, std::int32_t rank, std:: thrust::transform(buffer.cbegin(), buffer.cend(), expected.begin(), [=] XGBOOST_DEVICE(std::int32_t i) { return i * n_workers; }); - auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), + Context ctx; + auto rc = w.coll->Allreduce(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), ArrayInterfaceHandler::kI4, Op::kSum); SafeColl(rc); for (auto i = 0; i < 5; i++) { @@ -53,14 +54,15 @@ void TestBroadcast(std::shared_ptr comm, std::int32_t rank) { auto rc = Success(); std::vector expect{0, 1, 2, 3}; + Context ctx; if (comm->Rank() == 0) { dh::device_vector buffer{expect}; - rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); + rc = w.coll->Broadcast(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); std::vector expect{0, 1, 2, 3}; ASSERT_EQ(buffer, expect); } else { dh::device_vector buffer(std::vector{4, 5, 6, 7}); - rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); + rc = w.coll->Broadcast(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); ASSERT_EQ(buffer, expect); } SafeColl(rc); @@ -71,7 +73,8 @@ void TestAllgather(std::shared_ptr comm, std::int32_t rank, std:: dh::device_vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); + Context ctx; + auto rc = w.coll->Allgather(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); SafeColl(rc); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); @@ -89,7 +92,9 @@ void TestAllgatherV(std::shared_ptr comm, std::int32_t rank) { static_cast(inputs[1].size())}; r.resize(sizes[0] + sizes[1]); - auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), + Context ctx; + auto rc = w.coll->AllgatherV(&ctx, *w.nccl_comm, + common::EraseType(dh::ToSpan(inputs[comm->Rank()])), common::Span{sizes.data(), sizes.size()}, recv_segments, common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); SafeColl(rc); From 488241edf83eb33b39dabec2aa887e6693c7358f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 01:14:40 +0800 Subject: [PATCH 5/9] more. --- src/collective/allreduce.h | 8 ++---- src/collective/coll.cu | 36 +++++++++++++++--------- src/common/cuda_stream.h | 4 +++ tests/cpp/collective/test_allreduce_v.cu | 2 +- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index a9edf789ed6f..f38a26d92dc6 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -366,8 +366,8 @@ template std::enable_if_t const&, dh::device_vector const&, dh::device_vector*, cudaStream_t>, Result> -AllreduceV(Comm const& comm, dh::device_vector* data, AllreduceVScratch* scratch, - Fn&& redop) { +AllreduceV(Context const* ctx, Comm const& comm, dh::device_vector* data, + AllreduceVScratch* scratch, Fn&& redop) { if (!comm.IsDistributed() || comm.World() == 1) { return Success(); } @@ -377,9 +377,7 @@ AllreduceV(Comm const& comm, dh::device_vector* data, AllreduceVScratch* s return Fail("Distributed GPU AllreduceV requires NCCL support."); } - // No separate user stream available; let the NCCL stream drive everything. - // Cross-stream events inside `gpu_impl::AllreduceV` degenerate to same-stream no-ops. - return gpu_impl::AllreduceV(nccl->Stream(), *nccl, data, scratch, std::forward(redop)); + return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward(redop)); } template diff --git a/src/collective/coll.cu b/src/collective/coll.cu index e4c8ca0e3c39..64b55239795c 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -14,6 +14,7 @@ #include // for invoke_result_t, is_same_v, enable_if_t #include // for move +#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/cuda_stream.h" // for StreamRef, Event #include "../common/device_helpers.cuh" // for device_vector #include "../common/threadpool.h" // for ThreadPool @@ -91,8 +92,20 @@ struct Chan { template > [[nodiscard]] std::enable_if_t, Result> AsyncLaunch( - common::ThreadPool* pool, NCCLComm const* nccl, std::shared_ptr stub, Fn&& fn) { + Context const* ctx, common::ThreadPool* pool, NCCLComm const* nccl, + std::shared_ptr stub, Fn&& fn) { auto stream = nccl->Stream(); + auto user_stream = ctx->CUDACtx()->Stream(); + + curt::Event before; + before.Record(user_stream); + stream.Wait(before); + + auto user_after = common::MakeCleanup([&] { + curt::Event ev; + ev.Record(stream); + user_stream.Wait(ev); + }); Chan chan; @@ -196,8 +209,7 @@ void RunBitwiseAllreduce(curt::StreamRef stream, common::Span out_b // stream) is synchronised back to the caller's stream. return BracketNccl(ctx->CUDACtx()->Stream(), pcomm->Stream(), [&]() -> Result { auto rc = AsyncLaunch(ctx, pool, pcomm, stub, [&](curt::StreamRef s) { - return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), - s); + return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), s); }); if (!rc.OK()) { return rc; @@ -283,11 +295,10 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto stub = nccl->Stub(); return Success() << [&] { - return AsyncLaunch(ctx, &this->pool_, nccl, stub, - [data, nccl, root, stub](curt::StreamRef s) { - return stub->Broadcast(data.data(), data.data(), data.size_bytes(), - ncclInt8, root, nccl->Handle(), s); - }); + return AsyncLaunch(ctx, &this->pool_, nccl, stub, [data, nccl, root, stub](curt::StreamRef s) { + return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root, + nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; @@ -306,11 +317,10 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { auto send = data.subspan(comm.Rank() * size, size); return Success() << [&] { - return AsyncLaunch(ctx, &this->pool_, nccl, stub, - [send, data, size, nccl, stub](curt::StreamRef s) { - return stub->Allgather(send.data(), data.data(), size, ncclInt8, - nccl->Handle(), s); - }); + return AsyncLaunch( + ctx, &this->pool_, nccl, stub, [send, data, size, nccl, stub](curt::StreamRef s) { + return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), s); + }); } << [&] { return nccl->Block(); }; diff --git a/src/common/cuda_stream.h b/src/common/cuda_stream.h index 3714b8cfa5e9..6b345b0b782a 100644 --- a/src/common/cuda_stream.h +++ b/src/common/cuda_stream.h @@ -90,6 +90,10 @@ class Stream { public: Stream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } + explicit Stream(std::int32_t device) { + dh::safe_cuda(cudaSetDevice(device)); + dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); + } ~Stream() { dh::safe_cuda(cudaStreamDestroy(stream_)); } [[nodiscard]] StreamRef View() const { return StreamRef{stream_}; } diff --git a/tests/cpp/collective/test_allreduce_v.cu b/tests/cpp/collective/test_allreduce_v.cu index 7538be8ce206..dea61ff2cf9e 100644 --- a/tests/cpp/collective/test_allreduce_v.cu +++ b/tests/cpp/collective/test_allreduce_v.cu @@ -129,7 +129,7 @@ class AllreduceVWorker : public NCCLWorkerForTest { template void TreeAllreduceV(dh::device_vector* data, collective::AllreduceVScratch* scratch) { auto rc = - collective::AllreduceV(*nccl_comm_, data, scratch, + collective::AllreduceV(&ctx_, *nccl_comm_, data, scratch, [&](dh::device_vector const& lhs, dh::device_vector const& rhs, dh::device_vector* out, cudaStream_t stream) { this->DeviceSumToMaxLen(lhs, rhs, out, stream); From 5a7f56c1239339373a4e86a6ad11ee68809bb0f3 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 01:14:58 +0800 Subject: [PATCH 6/9] lint. --- plugin/federated/federated_coll.cu | 10 ++++++---- plugin/federated/federated_coll.h | 4 ++-- src/collective/coll.cc | 4 +++- src/collective/comm.cu | 6 +++--- tests/cpp/collective/test_allgather.cc | 2 +- tests/cpp/collective/test_allreduce.cc | 6 +++--- tests/cpp/plugin/federated/test_federated_coll.cc | 10 +++++----- tests/cpp/plugin/federated/test_federated_coll.cu | 8 ++++---- 8 files changed, 27 insertions(+), 23 deletions(-) diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu index 979254dbbadc..f7ae45e5971f 100644 --- a/plugin/federated/federated_coll.cu +++ b/plugin/federated/federated_coll.cu @@ -72,10 +72,12 @@ Coll *FederatedColl::MakeCUDAVar() { }; } -[[nodiscard]] Result CUDAFederatedColl::AllgatherV( - Context const *ctx, Comm const &comm, common::Span data, - common::Span sizes, common::Span recv_segments, - common::Span recv, AllgatherVAlgo algo) { +[[nodiscard]] Result CUDAFederatedColl::AllgatherV(Context const *ctx, Comm const &comm, + common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, + AllgatherVAlgo algo) { auto cufed = dynamic_cast(&comm); CHECK(cufed); diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h index 219d35bf1985..3f28b2b32351 100644 --- a/plugin/federated/federated_coll.h +++ b/plugin/federated/federated_coll.h @@ -2,8 +2,8 @@ * Copyright 2023-2026, XGBoost contributors */ #pragma once -#include "../../src/collective/coll.h" // for Coll -#include "../../src/collective/comm.h" // for Comm +#include "../../src/collective/coll.h" // for Coll +#include "../../src/collective/comm.h" // for Comm namespace xgboost::collective { class FederatedColl : public Coll { diff --git a/src/collective/coll.cc b/src/collective/coll.cc index c539bafb20e4..2ac5edb1bfc4 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -101,7 +101,9 @@ bool constexpr IsFloatingPointV() { return Fail("Invalid op."); }); - return std::move(rc) << [&] { return comm.Block(); }; + return std::move(rc) << [&] { + return comm.Block(); + }; } [[nodiscard]] Result Coll::Broadcast(Context const* /*ctx*/, Comm const& comm, diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 94d29f333cf7..c45069f03a6e 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -27,9 +27,9 @@ Result GetUniqueId(Context const* ctx, Comm const& comm, std::shared_ptrGetUniqueId(&id); SafeColl(rc); } - auto rc = coll->Broadcast( - ctx, comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, - kRootRank); + auto rc = coll->Broadcast(ctx, comm, + common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, + kRootRank); if (!rc.OK()) { return rc; } diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index c1c6f690d855..51257edd6fb4 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -122,7 +122,7 @@ class Worker : public WorkerForTest { CheckV(s_recv); // Test inplace - auto test_inplace = [&] (AllgatherVAlgo algo) { + auto test_inplace = [&](AllgatherVAlgo algo) { std::fill_n(s_recv.data(), s_recv.size(), 0); auto current = s_recv.subspan(recv_segments[comm_.Rank()], recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]); diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index d35da490da5a..1528c9ad2901 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -74,9 +74,9 @@ class AllreduceWorker : public WorkerForTest { data[comm_.Rank()] = ~std::uint32_t{0}; auto pcoll = std::shared_ptr{new Coll{}}; Context ctx; - auto rc = pcoll->Allreduce(&ctx, comm_, - common::EraseType(common::Span{data.data(), data.size()}), - ArrayInterfaceHandler::kU4, Op::kBitwiseOR); + auto rc = + pcoll->Allreduce(&ctx, comm_, common::EraseType(common::Span{data.data(), data.size()}), + ArrayInterfaceHandler::kU4, Op::kBitwiseOR); SafeColl(rc); for (auto v : data) { ASSERT_EQ(v, ~std::uint32_t{0}); diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc index 1c6d83b6bbd6..ccaa554003e3 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cc +++ b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -27,9 +27,9 @@ TEST_F(FederatedCollTest, Allreduce) { auto coll = std::make_shared(); Context ctx; - auto rc = coll->Allreduce(&ctx, *comm, - common::EraseType(common::Span{buffer.data(), buffer.size()}), - ArrayInterfaceHandler::kI4, Op::kSum); + auto rc = + coll->Allreduce(&ctx, *comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), + ArrayInterfaceHandler::kI4, Op::kSum); SafeColl(rc); for (auto i = 0; i < 5; i++) { ASSERT_EQ(buffer[i], expected[i]); @@ -66,8 +66,8 @@ TEST_F(FederatedCollTest, Allgather) { std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); Context ctx; - auto rc = coll.Allgather(&ctx, *comm, - common::EraseType(common::Span{buffer.data(), buffer.size()})); + auto rc = + coll.Allgather(&ctx, *comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); SafeColl(rc); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index b1dbc624f4b0..de22dc5bc3a8 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -93,10 +93,10 @@ void TestAllgatherV(std::shared_ptr comm, std::int32_t rank) { r.resize(sizes[0] + sizes[1]); Context ctx; - auto rc = w.coll->AllgatherV(&ctx, *w.nccl_comm, - common::EraseType(dh::ToSpan(inputs[comm->Rank()])), - common::Span{sizes.data(), sizes.size()}, recv_segments, - common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); + auto rc = + w.coll->AllgatherV(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), + common::Span{sizes.data(), sizes.size()}, recv_segments, + common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); SafeColl(rc); ASSERT_EQ(r[0], 1); From c1def2a33fd74bdeb0139c6f9dc21ba1c3e3c4df Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 06:55:28 +0800 Subject: [PATCH 7/9] fixes. --- src/collective/comm.cu | 4 ++++ src/collective/comm.cuh | 5 +---- src/common/cuda_stream.h | 1 + tests/cpp/plugin/federated/test_federated_coll.cu | 8 ++++---- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/collective/comm.cu b/src/collective/comm.cu index c45069f03a6e..abb280c1173b 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -54,6 +54,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p return; } + CHECK(ctx && ctx->IsCUDA()); curt::SetDevice(ctx->Ordinal()); stub_ = std::make_shared(nccl_path); @@ -103,6 +104,9 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p NCCLComm::~NCCLComm() { // Drain pending NCCL kernels before `stream_` is destroyed. (void)stream_.Sync(); + // Release channels while `stream_` is still alive. + this->channels_.clear(); + if (nccl_comm_) { auto rc = Success() << [this] { return this->stub_->CommFinalize(this->nccl_comm_); diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index 12a4bc47be00..9b317fc5b9a6 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -30,9 +30,6 @@ inline Result GetCUDAResult(cudaError rc) { } #if defined(XGBOOST_USE_NCCL) -// Cross-stream bracket for a block of NCCL work: `nccl_stream` waits for -// prior `user_stream` work on entry, `user_stream` waits for the NCCL work -// on exit. Events are recorded on the calling thread. template [[nodiscard]] std::enable_if_t, Result>, Result> BracketNccl(curt::StreamRef user_stream, curt::StreamRef nccl_stream, Fn&& fn) { @@ -53,7 +50,7 @@ BracketNccl(curt::StreamRef user_stream, curt::StreamRef nccl_stream, Fn&& fn) { #if defined(XGBOOST_USE_NCCL) class NCCLComm : public Comm { private: - // Declared first so it outlives stub_/nccl_comm_ and the base class's channels_. + // Declared first so among this class's own members it is destroyed last curt::Stream stream_; std::shared_ptr stub_; ncclComm_t nccl_comm_{nullptr}; diff --git a/src/common/cuda_stream.h b/src/common/cuda_stream.h index 6b345b0b782a..55373bad491c 100644 --- a/src/common/cuda_stream.h +++ b/src/common/cuda_stream.h @@ -91,6 +91,7 @@ class Stream { public: Stream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } explicit Stream(std::int32_t device) { + CHECK_GE(device, 0); dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index de22dc5bc3a8..29bdd6732dff 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -39,7 +39,7 @@ void TestAllreduce(std::shared_ptr comm, std::int32_t rank, std:: thrust::transform(buffer.cbegin(), buffer.cend(), expected.begin(), [=] XGBOOST_DEVICE(std::int32_t i) { return i * n_workers; }); - Context ctx; + auto ctx = MakeCUDACtx(rank); auto rc = w.coll->Allreduce(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), ArrayInterfaceHandler::kI4, Op::kSum); SafeColl(rc); @@ -54,7 +54,7 @@ void TestBroadcast(std::shared_ptr comm, std::int32_t rank) { auto rc = Success(); std::vector expect{0, 1, 2, 3}; - Context ctx; + auto ctx = MakeCUDACtx(rank); if (comm->Rank() == 0) { dh::device_vector buffer{expect}; rc = w.coll->Broadcast(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); @@ -73,7 +73,7 @@ void TestAllgather(std::shared_ptr comm, std::int32_t rank, std:: dh::device_vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - Context ctx; + auto ctx = MakeCUDACtx(rank); auto rc = w.coll->Allgather(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); SafeColl(rc); for (auto i = 0; i < n_workers; i++) { @@ -92,7 +92,7 @@ void TestAllgatherV(std::shared_ptr comm, std::int32_t rank) { static_cast(inputs[1].size())}; r.resize(sizes[0] + sizes[1]); - Context ctx; + auto ctx = MakeCUDACtx(rank); auto rc = w.coll->AllgatherV(&ctx, *w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), common::Span{sizes.data(), sizes.size()}, recv_segments, From 8b3e109399e42e1679732875804136e1e693a305 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 18:40:26 +0800 Subject: [PATCH 8/9] Restore device. --- src/collective/comm.cu | 1 - src/common/cuda_stream.cc | 16 ++++++++++++++++ src/common/cuda_stream.h | 6 +----- 3 files changed, 17 insertions(+), 6 deletions(-) create mode 100644 src/common/cuda_stream.cc diff --git a/src/collective/comm.cu b/src/collective/comm.cu index abb280c1173b..62554236c25f 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -54,7 +54,6 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p return; } - CHECK(ctx && ctx->IsCUDA()); curt::SetDevice(ctx->Ordinal()); stub_ = std::make_shared(nccl_path); diff --git a/src/common/cuda_stream.cc b/src/common/cuda_stream.cc new file mode 100644 index 000000000000..18efe1eec388 --- /dev/null +++ b/src/common/cuda_stream.cc @@ -0,0 +1,16 @@ +/** + * Copyright 2024-2026, XGBoost contributors + */ +#include "cuda_stream.h" + +#include "cuda_rt_utils.h" // for CurrentDevice +#include "utils.h" + +namespace xgboost::curt { +Stream::Stream(std::int32_t device) { + std::int32_t cur = CurrentDevice(); + auto guard = common::MakeCleanup([=] { SetDevice(cur); }); + SetDevice(device); + dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} +} // namespace xgboost::curt diff --git a/src/common/cuda_stream.h b/src/common/cuda_stream.h index 55373bad491c..f142918a2a99 100644 --- a/src/common/cuda_stream.h +++ b/src/common/cuda_stream.h @@ -90,11 +90,7 @@ class Stream { public: Stream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } - explicit Stream(std::int32_t device) { - CHECK_GE(device, 0); - dh::safe_cuda(cudaSetDevice(device)); - dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); - } + explicit Stream(std::int32_t device); ~Stream() { dh::safe_cuda(cudaStreamDestroy(stream_)); } [[nodiscard]] StreamRef View() const { return StreamRef{stream_}; } From da4bb6a4dc696340ceba8de49b9f322c6f42c2d2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 18 Apr 2026 18:47:04 +0800 Subject: [PATCH 9/9] cpu build. --- src/common/cuda_stream.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/common/cuda_stream.cc b/src/common/cuda_stream.cc index 18efe1eec388..2fd9d7e2ba89 100644 --- a/src/common/cuda_stream.cc +++ b/src/common/cuda_stream.cc @@ -1,6 +1,7 @@ /** * Copyright 2024-2026, XGBoost contributors */ +#if defined(XGBOOST_USE_CUDA) #include "cuda_stream.h" #include "cuda_rt_utils.h" // for CurrentDevice @@ -14,3 +15,4 @@ Stream::Stream(std::int32_t device) { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } } // namespace xgboost::curt +#endif // defined(XGBOOST_USE_CUDA)