@@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
        }
    }
 
-#ifdef GGML_USE_NCCL
-    int dev_ids[GGML_CUDA_MAX_DEVICES];
-    for (int id = 0; id < info.device_count; ++id) {
-        dev_ids[id] = id;
-    }
-    NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
-#endif // GGML_USE_NCCL
-
     return info;
 }
 
@@ -1125,29 +1117,94 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
     /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
 };
 
-bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
+#ifdef GGML_USE_NCCL
+struct ggml_backend_cuda_comm_context {
+    std::vector<ggml_backend_t> backends;
+    std::vector<ncclComm_t>     comms;
+
+    ~ggml_backend_cuda_comm_context() {
+        for (ncclComm_t comm : comms) {
+            NCCL_CHECK(ncclCommDestroy(comm));
+        }
+    }
+};
+#endif // GGML_USE_NCCL
+
+static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
+#ifdef GGML_USE_NCCL
+    if (comm_ctx_v == nullptr) {
+        return;
+    }
+    ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
+    delete comm_ctx;
+#else
+    GGML_UNUSED(comm_ctx_v);
+#endif // GGML_USE_NCCL
+}
+
+static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
+#ifdef GGML_USE_NCCL
+    for (size_t i = 0; i < n_backends; i++) {
+        if (!ggml_backend_is_cuda(backends[i])) {
+            return nullptr;
+        }
+    }
+    ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context;
+    std::vector<int> dev_ids;
+    ret->backends.reserve(n_backends);
+    dev_ids.reserve(n_backends);
+    for (size_t i = 0; i < n_backends; i++) {
+        ret->backends.push_back(backends[i]);
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        dev_ids.push_back(cuda_ctx->device);
+    }
+
+    ret->comms.resize(n_backends);
+    NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data()));
+    return ret;
+#else
+    // If NCCL is installed it is used by default for optimal performance.
+    // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
+    // RCCL is disabled by default, users are explicitly opting in.
+    // Therefore print no warning for RCCL.
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    static bool warning_printed = false;
+    if (!warning_printed) {
+        GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__);
+        warning_printed = true;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    GGML_UNUSED_VARS(backends, n_backends);
+    return nullptr;
+#endif // GGML_USE_NCCL
+}
+
+static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
 #ifdef GGML_USE_NCCL
     const int64_t ne = ggml_nelements(tensors[0]);
     // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
     // This then causes a crash in this function
     if (ne == 0) {
         return true;
     }
+
+    GGML_ASSERT(comm_ctx_v != nullptr);
+    ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
+    const size_t n_backends = comm_ctx->backends.size();
+
     for (size_t i = 0; i < n_backends; ++i) {
         GGML_ASSERT(tensors[i] != nullptr);
         GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
         GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
     }
 
-    const ggml_cuda_device_info info = ggml_cuda_info();
-
     // For small tensors, simply reduce them as FP32.
     // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
     if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
         NCCL_CHECK(ncclGroupStart());
         for (size_t i = 0; i < n_backends; ++i) {
-            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
+            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
+            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
         }
         NCCL_CHECK(ncclGroupEnd());
 
@@ -1160,44 +1217,33 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
 
     ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
         tmp[i].pool = &cuda_ctx->pool();
         tmp[i].alloc(ne);
 
-        ggml_cuda_set_device(i);
+        ggml_cuda_set_device(cuda_ctx->device);
         to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
         CUDA_CHECK(cudaGetLastError());
     }
 
     NCCL_CHECK(ncclGroupStart());
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
+        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
     }
     NCCL_CHECK(ncclGroupEnd());
 
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
 
-        ggml_cuda_set_device(i);
+        ggml_cuda_set_device(cuda_ctx->device);
         to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
         CUDA_CHECK(cudaGetLastError());
     }
 
     return true;
 #else
-    // If NCCL is installed it is used by default for optimal performance.
-    // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
-    // RCCL is disabled by default, users are explicitly opting in.
-    // Therefore print no warning for RCCL.
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    static bool warning_printed = false;
-    if (!warning_printed) {
-        GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__);
-        warning_printed = true;
-    }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    GGML_UNUSED_VARS(backends, tensors, n_backends);
+    GGML_UNUSED_VARS(comm_ctx_v, tensors);
     return false;
 #endif // GGML_USE_NCCL
 }
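
Note on the hunk above: the all-reduce keeps two regimes. Small tensors are summed directly in FP32, while larger ones are converted to BF16, reduced, and converted back, which halves the bytes moved over the interconnect at the cost of two extra conversion kernels per GPU. As a reading aid, here is a minimal sketch of that cutoff as a standalone predicate; the helper name is illustrative and not part of the commit, and the thresholds are copied from the diff, where they are described as tuned on RTX 4090s connected via 16x PCIe 4.0.

// Illustrative helper (not in the commit): true if a plain FP32 all-reduce is
// expected to beat the BF16 round trip for a tensor of ne elements.
static bool allreduce_prefer_fp32(size_t n_backends, int64_t ne) {
    return (n_backends <= 2 && ne <  32768) ||
           (n_backends == 3 && ne < 131072) ||
           (n_backends >= 4 && ne < 262144);
}
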
@@ -5220,8 +5266,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
 
 static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
     GGML_UNUSED(reg);
-    if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) {
-        return (void *)ggml_backend_cuda_allreduce_tensor;
+    if (strcmp(name, "ggml_backend_comm_init") == 0) {
+        return (void *)ggml_backend_cuda_comm_init;
+    }
+    if (strcmp(name, "ggml_backend_comm_free") == 0) {
+        return (void *)ggml_backend_cuda_comm_free;
+    }
+    if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
+        return (void *)ggml_backend_cuda_comm_allreduce_tensor;
     }
     if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
         return (void *)ggml_backend_cuda_split_buffer_type;
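
Taken together, the registry hunk above replaces the single "ggml_backend_allreduce_tensor" entry point with a three-call lifecycle: create a communicator context for a set of CUDA backends, all-reduce through it, and free it. Below is a hedged caller-side sketch of how these entry points might be resolved and driven via the generic ggml_backend_reg_get_proc_address lookup; the typedefs and the example function are assumptions inferred from the signatures in this diff, not code from the commit.

// Caller-side sketch (illustrative, not part of the commit).
// The function-pointer types mirror the signatures defined in the diff above.
#include "ggml-backend.h"

typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends);
typedef void   (*ggml_backend_comm_free_t)(void * comm_ctx);
typedef bool   (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors);

static bool example_allreduce(ggml_backend_reg_t reg, ggml_backend_t * backends, size_t n_backends,
                              struct ggml_tensor ** tensors) {
    // Resolve the optional entry points; a backend without collective support returns NULL here.
    auto comm_init      = (ggml_backend_comm_init_t)             ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_init");
    auto comm_free      = (ggml_backend_comm_free_t)             ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_free");
    auto comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_allreduce_tensor");
    if (comm_init == nullptr || comm_free == nullptr || comm_allreduce == nullptr) {
        return false;
    }

    // comm_init returns NULL when NCCL is unavailable or when a backend is not a CUDA backend.
    void * comm_ctx = comm_init(backends, n_backends);
    if (comm_ctx == nullptr) {
        return false;
    }
    const bool ok = comm_allreduce(comm_ctx, tensors); // tensors[i] lives on backends[i]; all are summed in place
    comm_free(comm_ctx);
    return ok;
}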