
Commit 2e4a5be

JohannesGaessler and ArberSephirotheca authored and committed

CUDA: manage NCCL communicators in context (ggml-org#21891)

* CUDA: manage NCCL communicators in context
* add check that all backends are CUDA
* remove unused vector, limit init to > 1 GPUs
* fix warnings
* fix cuda device, cache allreduce
1 parent 38a6ffb commit 2e4a5be

4 files changed: 119 additions & 47 deletions


ggml/include/ggml-backend.h

Lines changed: 5 additions & 2 deletions
@@ -202,8 +202,11 @@ extern "C" {
 
     // Common functions that may be obtained using ggml_backend_reg_get_proc_address
 
-    // AllReduce operation for tensor parallelism (meta backend)
-    typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
+    // Context management and operations for faster communication between backends, used for tensor parallelism (meta backend)
+    typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends);
+    typedef void (*ggml_backend_comm_free_t)(void * comm_ctx);
+    typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors);
+
     // Split buffer type for tensor parallelism (old)
     typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
     // Set the number of threads for the backend
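
The header change replaces the single allreduce entry point with an explicit init/use/free lifecycle. As a rough sketch of how a consumer drives the new API (the variables `backends` and `n_backends` are assumed inputs describing the participating simple backends; the real consumer is the meta backend in the next file):

    // Resolve all three entry points from the first backend's registration,
    // mirroring what ggml-backend-meta.cpp does below. Sketch only; error paths elided.
    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backends[0]));

    ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t)
        ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_init");
    void * comm_ctx = comm_init != nullptr ? comm_init(backends, n_backends) : nullptr;

    if (comm_ctx != nullptr) {
        ggml_backend_comm_allreduce_tensor_t comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_allreduce_tensor");
        // ... per subgraph: comm_allreduce(comm_ctx, tensors), one tensor per backend ...
        ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_comm_free");
        comm_free(comm_ctx);
    }

A backend that does not implement the API simply returns nullptr from the proc-address lookup (or from comm_init), and the caller falls back to the generic synchronization path.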

ggml/src/ggml-backend-meta.cpp

Lines changed: 29 additions & 8 deletions
@@ -1419,22 +1419,48 @@ struct ggml_backend_meta_context {
     size_t max_tmp_size = 0;
     size_t max_subgraphs = 0;
 
+    void * comm_ctx = nullptr;
+    ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
+
     ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
         const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
         name = "Meta(";
+        std::vector<ggml_backend_t> simple_backends;
         backend_configs.reserve(n_devs);
+        simple_backends.reserve(n_devs);
         for (size_t i = 0; i < n_devs; i++) {
             ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
             if (i > 0) {
                 name += ",";
             }
             name += ggml_backend_dev_name(simple_dev);
-            backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params));
+            simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
+            backend_configs.emplace_back(simple_backends.back());
         }
         name += ")";
+
+        if (n_devs > 1) {
+            ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
+                ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
+            if (comm_init != nullptr) {
+                comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
+            }
+        }
+        if (comm_ctx != nullptr) {
+            comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
+                ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
+                    ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
+            GGML_ASSERT(comm_allreduce != nullptr);
+        }
     }
 
     ~ggml_backend_meta_context() {
+        if (comm_ctx != nullptr) {
+            ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
+                ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
+            GGML_ASSERT(comm_free != nullptr);
+            comm_free(comm_ctx);
+        }
         for (auto & bc : backend_configs) {
             ggml_backend_free(bc.backend);
         }
@@ -1845,20 +1871,15 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
 
         if (n_backends > 1 && i < n_subgraphs - 1) {
             bool backend_allreduce_success = false;
-            ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address(
-                ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor");
-            if (allreduce_tensor) {
-                std::vector<ggml_backend_t> backends;
-                backends.reserve(n_backends);
+            if (backend_ctx->comm_ctx) {
                 std::vector<ggml_tensor *> nodes;
                 nodes.reserve(n_backends);
                 for (size_t j = 0; j < n_backends; j++) {
                     auto & bcj = backend_ctx->backend_configs[j];
-                    backends.push_back(bcj.backend);
                     ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
                     nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
                 }
-                backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends);
+                backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
             }
 
             if (!backend_allreduce_success) {

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 4 deletions
@@ -1092,10 +1092,6 @@ struct ggml_cuda_device_info {
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
 
     std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
-
-#ifdef GGML_USE_NCCL
-    ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
-#endif // GGML_USE_NCCL
 };
 
 const ggml_cuda_device_info & ggml_cuda_info();
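
This removes the last piece of global NCCL state: previously ggml_cuda_init() eagerly created one communicator per visible device and kept them in ggml_cuda_device_info for the lifetime of the process. They now live in a ggml_backend_cuda_comm_context created on demand (see ggml-cuda.cu below), whose lifecycle is roughly this sketch (assumes GGML_USE_NCCL and ggml's NCCL_CHECK macro; the device ids are illustrative):

    // Per-context communicator lifecycle after this commit.
    std::vector<int>        dev_ids = {0, 1};   // CUDA device ids of the participating backends
    std::vector<ncclComm_t> comms(dev_ids.size());
    // A single call creates a clique of communicators, one per listed device:
    NCCL_CHECK(ncclCommInitAll(comms.data(), (int) dev_ids.size(), dev_ids.data()));
    // ... grouped ncclAllReduce calls on each device's stream ...
    for (ncclComm_t comm : comms) {
        NCCL_CHECK(ncclCommDestroy(comm));      // mirrors ~ggml_backend_cuda_comm_context()
    }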

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 85 additions & 33 deletions
@@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
         }
     }
 
-#ifdef GGML_USE_NCCL
-    int dev_ids[GGML_CUDA_MAX_DEVICES];
-    for (int id = 0; id < info.device_count; ++id) {
-        dev_ids[id] = id;
-    }
-    NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
-#endif // GGML_USE_NCCL
-
     return info;
 }
 
@@ -1125,29 +1117,94 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
     /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
 };
 
-bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
+#ifdef GGML_USE_NCCL
+struct ggml_backend_cuda_comm_context {
+    std::vector<ggml_backend_t> backends;
+    std::vector<ncclComm_t> comms;
+
+    ~ggml_backend_cuda_comm_context() {
+        for (ncclComm_t comm : comms) {
+            NCCL_CHECK(ncclCommDestroy(comm));
+        }
+    }
+};
+#endif // GGML_USE_NCCL
+
+static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
+#ifdef GGML_USE_NCCL
+    if (comm_ctx_v == nullptr) {
+        return;
+    }
+    ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
+    delete comm_ctx;
+#else
+    GGML_UNUSED(comm_ctx_v);
+#endif // GGML_USE_NCCL
+}
+
+static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
+#ifdef GGML_USE_NCCL
+    for (size_t i = 0; i < n_backends; i++) {
+        if (!ggml_backend_is_cuda(backends[i])) {
+            return nullptr;
+        }
+    }
+    ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context;
+    std::vector<int> dev_ids;
+    ret->backends.reserve(n_backends);
+    dev_ids.reserve(n_backends);
+    for (size_t i = 0; i < n_backends; i++) {
+        ret->backends.push_back(backends[i]);
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        dev_ids.push_back(cuda_ctx->device);
+    }
+
+    ret->comms.resize(n_backends);
+    NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data()));
+    return ret;
+#else
+    // If NCCL is installed it is used by default for optimal performance.
+    // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
+    // RCCL is disabled by default, users are explicitly opting in.
+    // Therefore print no warning for RCCL.
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    static bool warning_printed = false;
+    if (!warning_printed) {
+        GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__);
+        warning_printed = true;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    GGML_UNUSED_VARS(backends, n_backends);
+    return nullptr;
+#endif // GGML_USE_NCCL
+}
+
+static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
 #ifdef GGML_USE_NCCL
     const int64_t ne = ggml_nelements(tensors[0]);
     // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
     // This then causes a crash in this function
     if (ne == 0) {
         return true;
     }
+
+    GGML_ASSERT(comm_ctx_v != nullptr);
+    ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
+    const size_t n_backends = comm_ctx->backends.size();
+
     for (size_t i = 0; i < n_backends; ++i) {
         GGML_ASSERT(tensors[i] != nullptr);
         GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
         GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
     }
 
-    const ggml_cuda_device_info info = ggml_cuda_info();
-
     // For small tensors, simply reduce them as FP32.
     // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
     if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
         NCCL_CHECK(ncclGroupStart());
         for (size_t i = 0; i < n_backends; ++i) {
-            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
+            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
+            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
         }
         NCCL_CHECK(ncclGroupEnd());
 
@@ -1160,44 +1217,33 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
 
     ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
         tmp[i].pool = &cuda_ctx->pool();
         tmp[i].alloc(ne);
 
-        ggml_cuda_set_device(i);
+        ggml_cuda_set_device(cuda_ctx->device);
        to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
         CUDA_CHECK(cudaGetLastError());
     }
 
     NCCL_CHECK(ncclGroupStart());
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
+        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
     }
     NCCL_CHECK(ncclGroupEnd());
 
     for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
 
-        ggml_cuda_set_device(i);
+        ggml_cuda_set_device(cuda_ctx->device);
         to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
         CUDA_CHECK(cudaGetLastError());
     }
 
     return true;
 #else
-    // If NCCL is installed it is used by default for optimal performance.
-    // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
-    // RCCL is disabled by default, users are explicitly opting in.
-    // Therefore print no warning for RCCL.
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    static bool warning_printed = false;
-    if (!warning_printed) {
-        GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__);
-        warning_printed = true;
-    }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    GGML_UNUSED_VARS(backends, tensors, n_backends);
+    GGML_UNUSED_VARS(comm_ctx_v, tensors);
     return false;
 #endif // GGML_USE_NCCL
 }
@@ -5220,8 +5266,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) {
 
 static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
     GGML_UNUSED(reg);
-    if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) {
-        return (void *)ggml_backend_cuda_allreduce_tensor;
+    if (strcmp(name, "ggml_backend_comm_init") == 0) {
+        return (void *)ggml_backend_cuda_comm_init;
+    }
+    if (strcmp(name, "ggml_backend_comm_free") == 0) {
+        return (void *)ggml_backend_cuda_comm_free;
+    }
+    if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
+        return (void *)ggml_backend_cuda_comm_allreduce_tensor;
     }
     if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
         return (void *)ggml_backend_cuda_split_buffer_type;