11/* *
2- * Copyright 2023, XGBoost Contributors
2+ * Copyright 2023-2026 , XGBoost Contributors
33 */
44#include < cstdint> // for int8_t, int32_t
55#include < memory> // for dynamic_pointer_cast
@@ -18,7 +18,8 @@ Coll *FederatedColl::MakeCUDAVar() {
1818 return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this ->shared_from_this ())};
1919}
2020
21- [[nodiscard]] Result CUDAFederatedColl::Allreduce (Comm const &comm, common::Span<std::int8_t > data,
21+ [[nodiscard]] Result CUDAFederatedColl::Allreduce (Context const *ctx, Comm const &comm,
22+ common::Span<std::int8_t > data,
2223 ArrayInterfaceHandler::Type type, Op op) {
2324 auto cufed = dynamic_cast <CUDAFederatedComm const *>(&comm);
2425 CHECK (cufed);
@@ -29,14 +30,15 @@ Coll *FederatedColl::MakeCUDAVar() {
2930 return GetCUDAResult (
3031 cudaMemcpy (h_data.data (), data.data (), data.size (), cudaMemcpyDeviceToHost));
3132 } << [&] {
32- return p_impl_->Allreduce (comm, common::Span{h_data.data (), h_data.size ()}, type, op);
33+ return p_impl_->Allreduce (ctx, comm, common::Span{h_data.data (), h_data.size ()}, type, op);
3334 } << [&] {
3435 return GetCUDAResult (cudaMemcpyAsync (data.data (), h_data.data (), data.size (),
3536 cudaMemcpyHostToDevice, cufed->Stream ()));
3637 };
3738}
3839
39- [[nodiscard]] Result CUDAFederatedColl::Broadcast (Comm const &comm, common::Span<std::int8_t > data,
40+ [[nodiscard]] Result CUDAFederatedColl::Broadcast (Context const *ctx, Comm const &comm,
41+ common::Span<std::int8_t > data,
4042 std::int32_t root) {
4143 auto cufed = dynamic_cast <CUDAFederatedComm const *>(&comm);
4244 CHECK (cufed);
@@ -46,14 +48,15 @@ Coll *FederatedColl::MakeCUDAVar() {
4648 return GetCUDAResult (
4749 cudaMemcpy (h_data.data (), data.data (), data.size (), cudaMemcpyDeviceToHost));
4850 } << [&] {
49- return p_impl_->Broadcast (comm, common::Span{h_data.data (), h_data.size ()}, root);
51+ return p_impl_->Broadcast (ctx, comm, common::Span{h_data.data (), h_data.size ()}, root);
5052 } << [&] {
5153 return GetCUDAResult (cudaMemcpyAsync (data.data (), h_data.data (), data.size (),
5254 cudaMemcpyHostToDevice, cufed->Stream ()));
5355 };
5456}
5557
56- [[nodiscard]] Result CUDAFederatedColl::Allgather (Comm const &comm, common::Span<std::int8_t > data) {
58+ [[nodiscard]] Result CUDAFederatedColl::Allgather (Context const *ctx, Comm const &comm,
59+ common::Span<std::int8_t > data) {
5760 auto cufed = dynamic_cast <CUDAFederatedComm const *>(&comm);
5861 CHECK (cufed);
5962 std::vector<std::int8_t > h_data (data.size ());
@@ -62,16 +65,19 @@ Coll *FederatedColl::MakeCUDAVar() {
6265 return GetCUDAResult (
6366 cudaMemcpy (h_data.data (), data.data (), data.size (), cudaMemcpyDeviceToHost));
6467 } << [&] {
65- return p_impl_->Allgather (comm, common::Span{h_data.data (), h_data.size ()});
68+ return p_impl_->Allgather (ctx, comm, common::Span{h_data.data (), h_data.size ()});
6669 } << [&] {
6770 return GetCUDAResult (cudaMemcpyAsync (data.data (), h_data.data (), data.size (),
6871 cudaMemcpyHostToDevice, cufed->Stream ()));
6972 };
7073}
7174
72- [[nodiscard]] Result CUDAFederatedColl::AllgatherV (
73- Comm const &comm, common::Span<std::int8_t const > data, common::Span<std::int64_t const > sizes,
74- common::Span<std::int64_t > recv_segments, common::Span<std::int8_t > recv, AllgatherVAlgo algo) {
75+ [[nodiscard]] Result CUDAFederatedColl::AllgatherV (Context const *ctx, Comm const &comm,
76+ common::Span<std::int8_t const > data,
77+ common::Span<std::int64_t const > sizes,
78+ common::Span<std::int64_t > recv_segments,
79+ common::Span<std::int8_t > recv,
80+ AllgatherVAlgo algo) {
7581 auto cufed = dynamic_cast <CUDAFederatedComm const *>(&comm);
7682 CHECK (cufed);
7783
@@ -82,7 +88,7 @@ Coll *FederatedColl::MakeCUDAVar() {
8288 return GetCUDAResult (
8389 cudaMemcpy (h_data.data (), data.data (), data.size (), cudaMemcpyDeviceToHost));
8490 } << [&] {
85- return this ->p_impl_ ->AllgatherV (comm, h_data, sizes, recv_segments, h_recv, algo);
91+ return this ->p_impl_ ->AllgatherV (ctx, comm, h_data, sizes, recv_segments, h_recv, algo);
8692 } << [&] {
8793 return GetCUDAResult (cudaMemcpyAsync (recv.data (), h_recv.data (), h_recv.size (),
8894 cudaMemcpyHostToDevice, cufed->Stream ()));
0 commit comments