Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions plugin/federated/federated_coll.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2026, XGBoost contributors
*/
#include "federated_coll.h"

Expand Down Expand Up @@ -61,7 +61,8 @@ Coll *FederatedColl::MakeCUDAVar() {
}
#endif

[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
[[nodiscard]] Result FederatedColl::Allreduce(Context const * /*ctx*/, Comm const &comm,
common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
Expand All @@ -87,12 +88,13 @@ Coll *FederatedColl::MakeCUDAVar() {
return Success();
}

[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
/**
 * Broadcast entry point of the federated collective.
 *
 * The Context pointer is accepted but not used here. The communicator, the
 * buffer and `root` are handed unchanged to BroadcastImpl, together with a
 * pointer to this object's sequence_number_ member.
 *
 * @return Whatever Result BroadcastImpl produces.
 */
[[nodiscard]] Result FederatedColl::Broadcast(Context const * /*ctx*/, Comm const &comm,
                                              common::Span<std::int8_t> data, std::int32_t root) {
  auto *p_sequence_number = &this->sequence_number_;
  return BroadcastImpl(comm, p_sequence_number, data, root);
}

[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
[[nodiscard]] Result FederatedColl::Allgather(Context const * /*ctx*/, Comm const &comm,
common::Span<std::int8_t> data) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
Expand Down Expand Up @@ -120,7 +122,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return Success();
}

[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
[[nodiscard]] Result FederatedColl::AllgatherV(Context const * /*ctx*/, Comm const &comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const>,
common::Span<std::int64_t>,
Expand Down
28 changes: 17 additions & 11 deletions plugin/federated/federated_coll.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2026, XGBoost Contributors
*/
#include <cstdint> // for int8_t, int32_t
#include <memory> // for dynamic_pointer_cast
Expand All @@ -18,7 +18,8 @@ Coll *FederatedColl::MakeCUDAVar() {
return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
}

[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
Expand All @@ -29,14 +30,15 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
return p_impl_->Allreduce(ctx, comm, common::Span{h_data.data(), h_data.size()}, type, op);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}

[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data,
std::int32_t root) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
Expand All @@ -46,14 +48,15 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
return p_impl_->Broadcast(ctx, comm, common::Span{h_data.data(), h_data.size()}, root);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}

[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
[[nodiscard]] Result CUDAFederatedColl::Allgather(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
Expand All @@ -62,16 +65,19 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
return p_impl_->Allgather(ctx, comm, common::Span{h_data.data(), h_data.size()});
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}

[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(Context const *ctx, Comm const &comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv,
AllgatherVAlgo algo) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);

Expand All @@ -82,7 +88,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
return this->p_impl_->AllgatherV(ctx, comm, h_data, sizes, recv_segments, h_recv, algo);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
Expand Down
17 changes: 10 additions & 7 deletions plugin/federated/federated_coll.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023-2024, XGBoost contributors
* Copyright 2023-2026, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
Expand All @@ -12,12 +12,15 @@ class CUDAFederatedColl : public Coll {

public:
explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
[[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
[[nodiscard]] Result Allreduce(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type type,
Op op) override;
[[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data, std::int32_t root) override;
[[nodiscard]] Result Allgather(Context const *ctx, Comm const &,
common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
Expand Down
18 changes: 10 additions & 8 deletions plugin/federated/federated_coll.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/**
* Copyright 2023-2024, XGBoost contributors
* Copyright 2023-2026, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm

namespace xgboost::collective {
class FederatedColl : public Coll {
Expand All @@ -13,12 +13,14 @@ class FederatedColl : public Coll {
public:
Coll *MakeCUDAVar() override;

[[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
[[nodiscard]] Result Allreduce(Context const *ctx, Comm const &, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
[[nodiscard]] Result Broadcast(Context const *ctx, Comm const &comm,
common::Span<std::int8_t> data, std::int32_t root) override;
[[nodiscard]] Result Allgather(Context const *ctx, Comm const &,
common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Context const *ctx, Comm const &comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
Expand Down
8 changes: 4 additions & 4 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ template <typename T>

auto const& cctx = comm.Ctx(ctx, data.Device());
auto backend = comm.Backend(data.Device());
return backend->Allgather(cctx, erased);
return backend->Allgather(ctx, cctx, erased);
}

/**
Expand Down Expand Up @@ -144,7 +144,7 @@ template <typename T>
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc =
comm.Backend(DeviceOrd::CPU())->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
comm.Backend(DeviceOrd::CPU())->Allgather(ctx, comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}
Expand All @@ -161,8 +161,8 @@ template <typename T>
auto erased = common::EraseType(data.Values());

return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(),
ctx, comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()},
s_segments, data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(),
data.Device().IsCUDA() ? AllgatherVAlgo::kBcast : AllgatherVAlgo::kRing);
}

Expand Down
10 changes: 5 additions & 5 deletions src/collective/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ template <typename T, std::int32_t kDim>
auto type = ToDType<T>::kType;

auto backend = comm.Backend(data.Device());
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
return backend->Allreduce(ctx, comm.Ctx(ctx, data.Device()), erased, type, op);
}

template <typename T, std::int32_t kDim>
Expand Down Expand Up @@ -366,8 +366,8 @@ template <typename T, typename Fn>
std::enable_if_t<std::is_invocable_v<Fn, dh::device_vector<T> const&, dh::device_vector<T> const&,
dh::device_vector<T>*, cudaStream_t>,
Result>
AllreduceV(Comm const& comm, dh::device_vector<T>* data, AllreduceVScratch<T>* scratch,
Fn&& redop) {
AllreduceV(Context const* ctx, Comm const& comm, dh::device_vector<T>* data,
AllreduceVScratch<T>* scratch, Fn&& redop) {
if (!comm.IsDistributed() || comm.World() == 1) {
return Success();
}
Expand All @@ -377,7 +377,7 @@ AllreduceV(Comm const& comm, dh::device_vector<T>* data, AllreduceVScratch<T>* s
return Fail("Distributed GPU AllreduceV requires NCCL support.");
}

return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward<Fn>(redop));
return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward<Fn>(redop));
}

template <typename T, typename Fn>
Expand All @@ -396,7 +396,7 @@ AllreduceV(Context const* ctx, CommGroup const& comm, dh::device_vector<T>* data
auto const& cctx = comm.Ctx(ctx, ctx->Device());
auto nccl = dynamic_cast<NCCLComm const*>(&cctx);
if (nccl != nullptr) {
return gpu_impl::AllreduceV(*nccl, data, scratch, std::forward<Fn>(redop));
return gpu_impl::AllreduceV(ctx, *nccl, data, scratch, std::forward<Fn>(redop));
}
return gpu_detail::AllreduceVHostFallback(ctx, comm, data, scratch, std::forward<Fn>(redop));
}
Expand Down
Loading
Loading