Skip to content

Commit 4573cdc

Browse files
committed
Merge branch 'cuda-aware' into 'master'
Restructure the code that allocates buffers for matrices. See merge request SLai/ChASE!36
2 parents 1db37c6 + dab7391 commit 4573cdc

25 files changed

Lines changed: 1614 additions & 2049 deletions

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@ project( ChASE LANGUAGES C CXX VERSION 1.3.0 )
55
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
66
# ## algorithm ##
77

8+
set(CMAKE_CXX_STANDARD 14)
9+
810
add_library(chase_algorithm INTERFACE)
911

1012
include(GNUInstallDirs)

ChASE-MPI/CMakeLists.txt

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
# Specify the minimum version for CMake
44
cmake_minimum_required( VERSION 3.8 )
55

6-
option(ENABLE_CUDA_AWARE "Enable CUDA aware MPI for collective communications" OFF)
6+
option(ENABLE_CUDA_AWARE_MPI "Enable CUDA aware MPI for collective communications" ON)
7+
option(ENABLE_NCCL "Enable Nvidia NCCL for collective communications" ON)
78

89
add_library( chase_seq INTERFACE )
910
target_link_libraries(chase_seq INTERFACE chase_algorithm)
@@ -39,10 +40,6 @@ find_package( MPI REQUIRED )
3940

4041
find_package( SCALAPACK )
4142

42-
if(ENABLE_CUDA_AWARE)
43-
find_package( NCCL )
44-
endif()
45-
4643
target_include_directories( chase_seq INTERFACE
4744
${MPI_CXX_INCLUDE_PATH}
4845
)
@@ -164,11 +161,20 @@ if(CMAKE_CUDA_COMPILER)
164161
target_link_libraries( chase_seq INTERFACE
165162
${CUDA_nvToolsExt_LIBRARY}
166163
)
167-
168-
if(ENABLE_CUDA_AWARE)
164+
165+
target_compile_definitions( chase_cuda INTERFACE
166+
"-DHAS_CUDA"
167+
)
168+
169+
if(ENABLE_CUDA_AWARE_MPI)
169170
target_compile_definitions( chase_cuda INTERFACE
170171
"-DCUDA_AWARE"
171172
)
173+
174+
if(ENABLE_NCCL)
175+
find_package( NCCL REQUIRED )
176+
endif()
177+
172178
if(NCCL_FOUND)
173179
target_link_libraries(chase_cuda INTERFACE
174180
${NCCL_LIBRARIES}

ChASE-MPI/blas_cuda_wrapper.hpp

Lines changed: 12 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -14,28 +14,25 @@
1414
#include <cusolverDn.h>
1515

1616
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
17-
int m, int n,
18-
const float* alpha, const float* A, int lda,
19-
const float* x, int incx, const float* beta, float* y,
20-
int incy)
17+
int m, int n, const float* alpha, const float* A,
18+
int lda, const float* x, int incx, const float* beta,
19+
float* y, int incy)
2120
{
22-
return cublasSgemv(handle, transa, m, n, alpha, A, lda, x, incx,
23-
beta, y, incy);
21+
return cublasSgemv(handle, transa, m, n, alpha, A, lda, x, incx, beta, y,
22+
incy);
2423
}
2524

2625
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
27-
int m, int n,
28-
const double* alpha, const double* A, int lda,
29-
const double* x, int incx, const double* beta,
30-
double* y, int incy)
26+
int m, int n, const double* alpha, const double* A,
27+
int lda, const double* x, int incx,
28+
const double* beta, double* y, int incy)
3129
{
32-
return cublasDgemv(handle, transa, m, n, alpha, A, lda, x, incx,
33-
beta, y, incy);
30+
return cublasDgemv(handle, transa, m, n, alpha, A, lda, x, incx, beta, y,
31+
incy);
3432
}
3533

3634
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
37-
int m, int n,
38-
const std::complex<float>* alpha,
35+
int m, int n, const std::complex<float>* alpha,
3936
const std::complex<float>* A, int lda,
4037
const std::complex<float>* x, int incx,
4138
const std::complex<float>* beta,
@@ -50,8 +47,7 @@ cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
5047
}
5148

5249
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t transa,
53-
int m, int n,
54-
const std::complex<double>* alpha,
50+
int m, int n, const std::complex<double>* alpha,
5551
const std::complex<double>* A, int lda,
5652
const std::complex<double>* x, int incx,
5753
const std::complex<double>* beta,

ChASE-MPI/blas_fortran.hpp

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -352,28 +352,34 @@ extern "C"
352352
const BlasInt* lda, const dcomplex* b,
353353
const BlasInt* ldb);
354354

355-
void FC_GLOBAL(sgesvd, SGESVD)(const char *jobu, const char *jobvt,
355+
void FC_GLOBAL(sgesvd, SGESVD)(const char* jobu, const char* jobvt,
356+
const BlasInt* m, const BlasInt* n, float* A,
357+
const BlasInt* lda, float* S, float* U,
358+
const BlasInt* ldu, float* Vt,
359+
const BlasInt* ldvt, float* work,
360+
const BlasInt* lwork, float* rwork,
361+
BlasInt* info);
362+
void FC_GLOBAL(dgesvd, DGESVD)(const char* jobu, const char* jobvt,
363+
const BlasInt* m, const BlasInt* n,
364+
double* A, const BlasInt* lda, double* S,
365+
double* U, const BlasInt* ldu, double* Vt,
366+
const BlasInt* ldvt, double* work,
367+
const BlasInt* lwork, double* rwork,
368+
BlasInt* info);
369+
void FC_GLOBAL(cgesvd, CGESVD)(const char* jobu, const char* jobvt,
356370
const BlasInt* m, const BlasInt* n,
357-
float *A, const BlasInt* lda, float *S,
358-
float *U, const BlasInt *ldu, float *Vt,
359-
const BlasInt *ldvt, float *work,
360-
const BlasInt *lwork, float *rwork, BlasInt *info );
361-
void FC_GLOBAL(dgesvd, DGESVD)(const char *jobu, const char *jobvt,
371+
scomplex* A, const BlasInt* lda, float* S,
372+
scomplex* U, const BlasInt* ldu,
373+
scomplex* Vt, const BlasInt* ldvt,
374+
scomplex* work, const BlasInt* lwork,
375+
float* rwork, BlasInt* info);
376+
void FC_GLOBAL(zgesvd, ZGESVD)(const char* jobu, const char* jobvt,
362377
const BlasInt* m, const BlasInt* n,
363-
double *A, const BlasInt* lda, double *S,
364-
double *U, const BlasInt *ldu, double *Vt,
365-
const BlasInt *ldvt, double *work,
366-
const BlasInt *lwork, double *rwork, BlasInt *info );
367-
void FC_GLOBAL(cgesvd, CGESVD)(const char *jobu, const char *jobvt, const BlasInt* m,
368-
const BlasInt* n, scomplex *A, const BlasInt* lda,
369-
float *S, scomplex *U, const BlasInt *ldu, scomplex *Vt,
370-
const BlasInt *ldvt, scomplex *work, const BlasInt *lwork,
371-
float *rwork, BlasInt *info );
372-
void FC_GLOBAL(zgesvd, ZGESVD)(const char *jobu, const char *jobvt, const BlasInt* m,
373-
const BlasInt* n, dcomplex *A, const BlasInt* lda, double *S,
374-
dcomplex *U, const BlasInt *ldu, dcomplex *Vt,
375-
const BlasInt *ldvt, dcomplex *work, const BlasInt *lwork,
376-
double *rwork, BlasInt *info );
378+
dcomplex* A, const BlasInt* lda, double* S,
379+
dcomplex* U, const BlasInt* ldu,
380+
dcomplex* Vt, const BlasInt* ldvt,
381+
dcomplex* work, const BlasInt* lwork,
382+
double* rwork, BlasInt* info);
377383
} // extern "C"
378384
} // namespace mpi
379385
} // namespace chase

ChASE-MPI/blas_templates.hpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -108,10 +108,10 @@ void t_trsm(const char side, const char uplo, const char trans, const char diag,
108108
const T* a, const std::size_t lda, const T* b,
109109
const std::size_t ldb);
110110

111-
template<typename T>
112-
void t_gesvd(const char jobu, const char jobvt, const std::size_t m, const std::size_t n,
113-
T *A, const std::size_t lda, Base<T> *S, T *U, const std::size_t ldu, T *Vt,
114-
const std::size_t ldvt);
111+
template <typename T>
112+
void t_gesvd(const char jobu, const char jobvt, const std::size_t m,
113+
const std::size_t n, T* A, const std::size_t lda, Base<T>* S, T* U,
114+
const std::size_t ldu, T* Vt, const std::size_t ldvt);
115115
// scalapack
116116
// BLACS
117117
void t_descinit(std::size_t* desc, std::size_t* m, std::size_t* n,

0 commit comments

Comments
 (0)