Skip to content

Commit dd55dff

Browse files
perf: fix cuda-aware mpi in v3 (#4977)
This pull request updates the MPI CUDA-awareness detection and handling logic in the `Border` autograd function, simplifying how CUDA support is determined and removing some legacy checks. The changes ensure that CUDA-aware MPI support is queried more directly, and some unnecessary device synchronization calls are removed. * The logic for checking CUDA-aware MPI support has been simplified: version checks and redundant branches have been removed, and the code now directly queries `MPIX_Query_cuda_support()` unless `NO_CUDA_AWARE` is defined. [[1]](diffhunk://#diff-7b7590fd4222d9c50f1dd7dde5ce7ed4b27695fbe591b536787db7575c35e32cL102-L112) [[2]](diffhunk://#diff-7b7590fd4222d9c50f1dd7dde5ce7ed4b27695fbe591b536787db7575c35e32cL227-L237) * Removed explicit `gpuDeviceSynchronize()` calls from both the forward and backward paths, relying on PyTorch's internal synchronization mechanisms instead. [[1]](diffhunk://#diff-7b7590fd4222d9c50f1dd7dde5ce7ed4b27695fbe591b536787db7575c35e32cL196-L198) [[2]](diffhunk://#diff-7b7590fd4222d9c50f1dd7dde5ce7ed4b27695fbe591b536787db7575c35e32cL332-L334) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - Performance - Reduced explicit GPU synchronization, potentially improving throughput during distributed forward/backward operations. - Compatibility - Safer default when CUDA-aware MPI isn’t present: automatically falls back to CPU-based transfers unless support is detected, improving stability across varied clusters. - Reliability - Simplified CUDA-aware detection reduces edge-case misconfigurations in mixed MPI environments. - No API Changes - Public interfaces remain unchanged; existing workflows continue to work. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 34df2b4 commit dd55dff

1 file changed

Lines changed: 6 additions & 29 deletions

File tree

source/op/pt/comm.cc

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Border : public torch::autograd::Function<Border> {
8686
#ifdef USE_MPI
8787
int mpi_init = 0;
8888
MPI_Initialized(&mpi_init);
89-
int cuda_aware = 1;
89+
int cuda_aware = 0;
9090
int me = 0;
9191
MPI_Comm world;
9292
int world_size = 0;
@@ -99,17 +99,9 @@ class Border : public torch::autograd::Function<Border> {
9999
MPI_Request request;
100100
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
101101
if (world_size >= 1) {
102-
int version, subversion;
103-
MPI_Get_version(&version, &subversion);
104-
if (version >= 4) {
105-
#ifdef NO_CUDA_AWARE
106-
cuda_aware = 0;
107-
#else
108-
cuda_aware = MPIX_Query_cuda_support();
102+
#ifndef NO_CUDA_AWARE
103+
cuda_aware = MPIX_Query_cuda_support();
109104
#endif
110-
} else {
111-
cuda_aware = 0;
112-
}
113105
if (cuda_aware == 0) {
114106
recv_g1_tensor = torch::empty_like(g1).to(torch::kCPU);
115107
recv_g1_tensor.copy_(g1);
@@ -193,10 +185,6 @@ class Border : public torch::autograd::Function<Border> {
193185
static torch::autograd::variable_list backward_t(
194186
torch::autograd::AutogradContext* ctx,
195187
torch::autograd::variable_list grad_output) {
196-
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
197-
gpuDeviceSynchronize();
198-
#endif
199-
200188
torch::autograd::variable_list saved_variables = ctx->get_saved_variables();
201189
torch::Tensor sendlist_tensor = saved_variables[0];
202190
torch::Tensor sendproc_tensor = saved_variables[1];
@@ -212,7 +200,7 @@ class Border : public torch::autograd::Function<Border> {
212200
int mpi_init = 0;
213201
MPI_Initialized(&mpi_init);
214202
int world_size = 0;
215-
int cuda_aware = 1;
203+
int cuda_aware = 0;
216204
int me = 0;
217205
MPI_Comm world;
218206
if (mpi_init) {
@@ -224,17 +212,9 @@ class Border : public torch::autograd::Function<Border> {
224212
MPI_Request request;
225213
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
226214
if (world_size >= 1) {
227-
int version, subversion;
228-
MPI_Get_version(&version, &subversion);
229-
if (version >= 4) {
230-
#ifdef NO_CUDA_AWARE
231-
cuda_aware = 0;
232-
#else
233-
cuda_aware = MPIX_Query_cuda_support();
215+
#ifndef NO_CUDA_AWARE
216+
cuda_aware = MPIX_Query_cuda_support();
234217
#endif
235-
} else {
236-
cuda_aware = 0;
237-
}
238218
if (cuda_aware == 0) {
239219
d_local_g1_tensor = torch::empty_like(grad_output[0]).to(torch::kCPU);
240220
d_local_g1_tensor.copy_(grad_output[0]);
@@ -329,9 +309,6 @@ class Border : public torch::autograd::Function<Border> {
329309
recv_g1_tensor.slice(0, 0, nrecv));
330310
}
331311
}
332-
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
333-
gpuDeviceSynchronize();
334-
#endif
335312
#ifdef USE_MPI
336313
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
337314
if (cuda_aware == 0) {

0 commit comments

Comments
 (0)