Skip to content

Commit 1db37c6

Browse files
committed
Merge branch 'cuda-aware' into 'master'
Enable CUDA-aware MPI/NCCL for ChASE-GPU. See merge request SLai/ChASE!35.
2 parents 9bf636a + 9cde423 commit 1db37c6

14 files changed

Lines changed: 1518 additions & 171 deletions

ChASE-MPI/CMakeLists.txt

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
# Specify the minimum version for CMake
44
cmake_minimum_required( VERSION 3.8 )
55

6-
# project( ChASE-MPI LANGUAGES C CXX CUDA )
6+
option(ENABLE_CUDA_AWARE "Enable CUDA aware MPI for collective communications" OFF)
77

88
add_library( chase_seq INTERFACE )
99
target_link_libraries(chase_seq INTERFACE chase_algorithm)
@@ -39,6 +39,10 @@ find_package( MPI REQUIRED )
3939

4040
find_package( SCALAPACK )
4141

42+
# Look for NCCL only when the ENABLE_CUDA_AWARE option is ON.
# find_package is called without REQUIRED, so a missing NCCL does not abort
# configuration here.
# NOTE(review): presumably a FindNCCL module or NCCL config file is provided
# on CMAKE_MODULE_PATH / CMAKE_PREFIX_PATH — confirm.
if(ENABLE_CUDA_AWARE)
43+
find_package( NCCL )
44+
endif()
45+
4246
target_include_directories( chase_seq INTERFACE
4347
${MPI_CXX_INCLUDE_PATH}
4448
)
@@ -160,7 +164,24 @@ if(CMAKE_CUDA_COMPILER)
160164
target_link_libraries( chase_seq INTERFACE
161165
${CUDA_nvToolsExt_LIBRARY}
162166
)
163-
167+
168+
# CUDA-aware communication setup for the chase_cuda INTERFACE target.
# With ENABLE_CUDA_AWARE ON, consumers of chase_cuda get the CUDA_AWARE
# compile definition. If NCCL_FOUND is additionally set, they also get the
# NCCL libraries, the NCCL include directories and the HAS_NCCL definition;
# otherwise only CUDA_AWARE is defined.
if(ENABLE_CUDA_AWARE)
169+
target_compile_definitions( chase_cuda INTERFACE
170+
"-DCUDA_AWARE"
171+
)
172+
if(NCCL_FOUND)
173+
target_link_libraries(chase_cuda INTERFACE
174+
${NCCL_LIBRARIES}
175+
)
176+
target_include_directories( chase_cuda INTERFACE
177+
${NCCL_INCLUDE_DIRS}
178+
)
179+
target_compile_definitions( chase_cuda INTERFACE
180+
"-DHAS_NCCL"
181+
)
182+
endif()
183+
endif()
184+
164185
if(ENABLE_NSIGHT)
165186
target_compile_definitions(chase_cuda INTERFACE USE_NSIGHT)
166187
target_compile_definitions(chase_seq INTERFACE USE_NSIGHT)

ChASE-MPI/blas_cuda_wrapper.hpp

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,58 @@
1313
#include <cuda_runtime.h>
1414
#include <cusolverDn.h>
1515

16+
// Single-precision real overload of cublasTgemv.
// Forwards all arguments unchanged to cublasSgemv and returns the
// cublasStatus_t it produces; the caller is responsible for checking it.
// NOTE(review): this is a non-inline, non-template function defined in a
// .hpp file; including the header from several translation units would
// presumably cause duplicate-symbol link errors — confirm how it is included.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
17+
int m, int n,
18+
const float* alpha, const float* A, int lda,
19+
const float* x, int incx, const float* beta, float* y,
20+
int incy)
21+
{
22+
return cublasSgemv(handle, transa, m, n, alpha, A, lda, x, incx,
23+
beta, y, incy);
24+
}
25+
26+
// Double-precision real overload of cublasTgemv.
// Forwards all arguments unchanged to cublasDgemv and returns the
// cublasStatus_t it produces; the caller is responsible for checking it.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
27+
int m, int n,
28+
const double* alpha, const double* A, int lda,
29+
const double* x, int incx, const double* beta,
30+
double* y, int incy)
31+
{
32+
return cublasDgemv(handle, transa, m, n, alpha, A, lda, x, incx,
33+
beta, y, incy);
34+
}
35+
36+
// Single-precision complex overload of cublasTgemv.
// Each std::complex<float> pointer (alpha, A, x, beta, y) is
// reinterpret_cast to the corresponding cuComplex pointer and passed to
// cublasCgemv; the integer arguments are forwarded unchanged. Returns the
// cublasStatus_t from cublasCgemv.
// NOTE(review): the casts assume std::complex<float> and cuComplex share
// the same memory layout — confirm.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
37+
int m, int n,
38+
const std::complex<float>* alpha,
39+
const std::complex<float>* A, int lda,
40+
const std::complex<float>* x, int incx,
41+
const std::complex<float>* beta,
42+
std::complex<float>* y, int incy)
43+
{
44+
return cublasCgemv(handle, transa, m, n,
45+
reinterpret_cast<const cuComplex*>(alpha),
46+
reinterpret_cast<const cuComplex*>(A), lda,
47+
reinterpret_cast<const cuComplex*>(x), incx,
48+
reinterpret_cast<const cuComplex*>(beta),
49+
reinterpret_cast<cuComplex*>(y), incy);
50+
}
51+
52+
// Double-precision complex overload of cublasTgemv.
// Each std::complex<double> pointer (alpha, A, x, beta, y) is
// reinterpret_cast to the corresponding cuDoubleComplex pointer and passed
// to cublasZgemv; the integer arguments are forwarded unchanged. Returns the
// cublasStatus_t from cublasZgemv.
// NOTE(review): the casts assume std::complex<double> and cuDoubleComplex
// share the same memory layout — confirm.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
53+
int m, int n,
54+
const std::complex<double>* alpha,
55+
const std::complex<double>* A, int lda,
56+
const std::complex<double>* x, int incx,
57+
const std::complex<double>* beta,
58+
std::complex<double>* y, int incy)
59+
{
60+
return cublasZgemv(handle, transa, m, n,
61+
reinterpret_cast<const cuDoubleComplex*>(alpha),
62+
reinterpret_cast<const cuDoubleComplex*>(A), lda,
63+
reinterpret_cast<const cuDoubleComplex*>(x), incx,
64+
reinterpret_cast<const cuDoubleComplex*>(beta),
65+
reinterpret_cast<cuDoubleComplex*>(y), incy);
66+
}
67+
1668
cublasStatus_t cublasTgemm(cublasHandle_t handle, cublasOperation_t transa,
1769
cublasOperation_t transb, int m, int n, int k,
1870
const float* alpha, const float* A, int lda,

ChASE-MPI/blas_templates.inc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1036,7 +1036,7 @@ std::size_t t_heevd(int matrix_layout, char jobz, char uplo, std::size_t n,
10361036
}
10371037

10381038
// Overload of ?gemv functions
1039-
/*
1039+
10401040
template <>
10411041
void
10421042
t_gemv(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE trans,
@@ -1059,7 +1059,7 @@ t_gemv(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE trans,
10591059
BlasInt incx_ = incx;
10601060
BlasInt incy_ = incy;
10611061

1062-
FC_GLOBAL(cgemv, CGEMV)
1062+
FC_GLOBAL(dgemv, DGEMV)
10631063
(&TA, &m_, &n_, alpha, a, &lda_, x, &incx_, beta, y, &incy_);
10641064
}
10651065
template <>
@@ -1138,7 +1138,7 @@ void t_gemv(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE trans,
11381138
FC_GLOBAL(cgemv, CGEMV)
11391139
(&TA, &m_, &n_, alpha, a, &lda_, x, &incx_, beta, y, &incy_);
11401140
}
1141-
*/
1141+
11421142
// Overload of ?stemr functions
11431143

11441144
template <>

ChASE-MPI/chase_mpi_properties.hpp

Lines changed: 109 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
#include "algorithm/types.hpp"
1616
#include "chase_mpi_matrices.hpp"
17+
#include "mpi_wrapper.hpp"
1718

1819
namespace chase
1920
{
@@ -436,6 +437,41 @@ class ChaseMpiProperties
436437
#else
437438
V_.reset(new T[N_ * max_block_]());
438439
#endif
440+
441+
comm_2 row_comm_dup;
442+
comm_2 col_comm_dup;
443+
#if defined(HAS_NCCL)
444+
ncclUniqueId nccl_id, nccl_ids[nprocs_];
445+
ncclGetUniqueId(&nccl_id);
446+
MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T,
447+
&nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
448+
449+
450+
for(auto i = 0; i < dims_[0]; i++){
451+
if(coord_[0] == i){
452+
ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]);
453+
}
454+
}
455+
456+
//col_comm
457+
ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_];
458+
ncclGetUniqueId(&nccl_id_2);
459+
MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T,
460+
&nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
461+
462+
463+
for(auto i = 0; i < dims_[1]; i++){
464+
if(coord_[1] == i){
465+
ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]);
466+
}
467+
}
468+
#else
469+
MPI_Comm_dup(row_comm_, &row_comm_dup);
470+
MPI_Comm_dup(col_comm_, &col_comm_dup);
471+
#endif
472+
mpi_wrapper_.add(row_comm_, row_comm_dup);
473+
mpi_wrapper_.add(col_comm_, col_comm_dup);
474+
439475
#ifdef USE_NSIGHT
440476
nvtxRangePop();
441477
#endif
@@ -649,6 +685,41 @@ class ChaseMpiProperties
649685
#else
650686
V_.reset(new T[N_ * max_block_]());
651687
#endif
688+
689+
comm_2 row_comm_dup;
690+
comm_2 col_comm_dup;
691+
#if defined(HAS_NCCL)
692+
ncclUniqueId nccl_id, nccl_ids[nprocs_];
693+
ncclGetUniqueId(&nccl_id);
694+
MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T,
695+
&nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
696+
697+
698+
for(auto i = 0; i < dims_[0]; i++){
699+
if(coord_[0] == i){
700+
ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]);
701+
}
702+
}
703+
704+
//col_comm
705+
ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_];
706+
ncclGetUniqueId(&nccl_id_2);
707+
MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T,
708+
&nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
709+
710+
711+
for(auto i = 0; i < dims_[1]; i++){
712+
if(coord_[1] == i){
713+
ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]);
714+
}
715+
}
716+
#else
717+
MPI_Comm_dup(row_comm_, &row_comm_dup);
718+
MPI_Comm_dup(col_comm_, &col_comm_dup);
719+
#endif
720+
mpi_wrapper_.add(row_comm_, row_comm_dup);
721+
mpi_wrapper_.add(col_comm_, col_comm_dup);
722+
652723
#ifdef USE_NSIGHT
653724
nvtxRangePop();
654725
#endif
@@ -847,11 +918,47 @@ class ChaseMpiProperties
847918
#else
848919
V_.reset(new T[N_ * max_block_]());
849920
#endif
921+
922+
comm_2 row_comm_dup;
923+
comm_2 col_comm_dup;
924+
#if defined(HAS_NCCL)
925+
ncclUniqueId nccl_id, nccl_ids[nprocs_];
926+
ncclGetUniqueId(&nccl_id);
927+
MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T,
928+
&nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
929+
930+
931+
for(auto i = 0; i < dims_[0]; i++){
932+
if(coord_[0] == i){
933+
ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]);
934+
}
935+
}
936+
937+
//col_comm
938+
ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_];
939+
ncclGetUniqueId(&nccl_id_2);
940+
MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T,
941+
&nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm);
942+
943+
944+
for(auto i = 0; i < dims_[1]; i++){
945+
if(coord_[1] == i){
946+
ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]);
947+
}
948+
}
949+
#else
950+
MPI_Comm_dup(row_comm_, &row_comm_dup);
951+
MPI_Comm_dup(col_comm_, &col_comm_dup);
952+
#endif
953+
mpi_wrapper_.add(row_comm_, row_comm_dup);
954+
mpi_wrapper_.add(col_comm_, col_comm_dup);
955+
850956
#ifdef USE_NSIGHT
851957
nvtxRangePop();
852958
#endif
853959
}
854960

961+
//! Returns the communicator wrapper held in `mpi_wrapper_`.
//! The return type is `Comm_t` by value, so the caller receives a copy.
Comm_t get_mpi_wrapper()
{
    return this->mpi_wrapper_;
}
855962
#if defined(HAS_SCALAPACK)
856963
int get_colcomm_ctxt() { return colcomm_ctxt_; }
857964

@@ -1600,6 +1707,8 @@ class ChaseMpiProperties
16001707
//! It is allocated only when no ScaLAPACK is detected.
16011708
std::unique_ptr<T[]> V_;
16021709
#endif
1710+
1711+
Comm_t mpi_wrapper_;
16031712
};
16041713
} // namespace mpi
16051714
} // namespace chase

ChASE-MPI/chase_mpidla_interface.hpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -343,6 +343,16 @@ class ChaseMpiDLAInterface
343343

344344
virtual void B2C(T* B, std::size_t off1, T* C, std::size_t off2, std::size_t block) = 0;
345345

346+
// Pure-virtual member functions added to the interface: each is declared
// `= 0`, so every concrete class deriving from it must provide an override.
// NOTE(review): the semantics below are inferred from names and signatures
// only — confirm against the implementing classes.
//  - getMpiWorkSpace: takes pointers-to-pointers, so presumably hands back
//    internal buffer addresses through its out-parameters.
//  - getMpiCollectiveBackend: writes two int out-parameters, presumably
//    identifying the backends used for allreduce and broadcast.
//  - isCudaAware: presumably reports whether CUDA-aware communication is used.
//  - lacpy / shiftMatrixForQR: presumably a matrix copy (LAPACK-lacpy-like
//    arguments) and a shift applied to an n-sized matrix A.
//  - retrieveC / retrieveB / retrieveResid / putC: presumably move a range
//    described by (locked, block) between backend storage and the caller.
virtual void getMpiWorkSpace(T **C, T **B, T **A, T **C2, T **B2, T **vv, Base<T> **rsd, T **w) = 0;
347+
virtual void getMpiCollectiveBackend(int *allreduce_backend, int *bcast_backend) = 0;
348+
virtual bool isCudaAware() = 0;
349+
virtual void lacpy(char uplo, std::size_t m, std::size_t n,
350+
T* a, std::size_t lda, T* b, std::size_t ldb) = 0;
351+
virtual void shiftMatrixForQR(T *A, std::size_t n, T shift) = 0;
352+
virtual void retrieveC(T **C, std::size_t locked, std::size_t block, bool copy) = 0;
353+
virtual void retrieveB(T **B, std::size_t locked, std::size_t block, bool copy) = 0;
354+
virtual void retrieveResid(Base<T> **rsd, std::size_t locked, std::size_t block) = 0;
355+
virtual void putC(T *C, std::size_t locked, std::size_t block) = 0;
346356
};
347357
} // namespace mpi
348358
} // namespace chase

0 commit comments

Comments
 (0)