Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) #22522
Changes from all commits
```diff
@@ -27,6 +27,7 @@
 #include <cstdio>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>

 #if defined(GGML_USE_HIP)
```
```diff
@@ -50,6 +51,7 @@
 #define GGML_CUDA_CC_TURING       750
 #define GGML_CUDA_CC_AMPERE       800
 #define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_HOPPER       900
 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
 // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
 #define GGML_CUDA_CC_BLACKWELL    1200
```
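The new constant follows the header's existing encoding of compute capability as major*100 + minor*10 (e.g. Ada Lovelace 8.9 -> 890, Hopper 9.0 -> 900). A minimal sketch (hypothetical helper, not part of the diff) of deriving that value for the current device with the plain CUDA runtime API:

```cuda
// Sketch only: compute the GGML_CUDA_CC_*-style value for a device and
// check whether it is Hopper or newer (the hardware floor for PDL).
#include <cuda_runtime.h>

static bool device_supports_pdl(int device) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return false;
    }
    const int cc = 100*prop.major + 10*prop.minor; // e.g. 9.0 -> 900
    return cc >= 900;                              // GGML_CUDA_CC_HOPPER
}
```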
|
|
```diff
@@ -107,6 +109,14 @@
 # define GGML_CUDA_USE_CUB
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

+#if defined(GGML_CUDA_USE_PDL)
+# define GGML_CUDA_PDL_SYNC() cudaGridDependencySynchronize()
+# define GGML_CUDA_PDL_LC()   cudaTriggerProgrammaticLaunchCompletion()
+#else
+# define GGML_CUDA_PDL_SYNC() // no-op when PDL is disabled (HIP/MUSA/pre-Hopper)
+# define GGML_CUDA_PDL_LC()
+#endif
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
```

Collaborator, on lines +113 to +114: For transpilation we need to add the corresponding aliases for MUSA/HIP (or guard this as CUDA-only for now if those aliases are absent).
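These macros wrap the two device-side intrinsics behind PDL: cudaTriggerProgrammaticLaunchCompletion() lets a grid signal, before it finishes, that its dependent grid may begin launching, and cudaGridDependencySynchronize() makes the dependent grid block until all upstream grids have actually completed. A minimal kernel-side sketch (hypothetical kernel, not from the diff) of the usual placement:

```cuda
// Hypothetical dependent kernel: only the read of `src` needs the upstream
// grid's output, so the prologue can overlap with the tail of that grid.
__global__ void dependent_kernel(const float * src, float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x; // independent prologue

    cudaGridDependencySynchronize(); // block until the upstream grid is done

    if (i < n) {
        dst[i] = 2.0f*src[i];        // now safe to read the upstream output
    }

    cudaTriggerProgrammaticLaunchCompletion(); // let our own dependent start early
}
```

Both intrinsics require compute capability 9.0 (Hopper) and a recent CUDA toolkit, which is why the no-op fallback is needed for HIP, MUSA, and pre-Hopper builds.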
|
|
```diff
@@ -165,6 +175,58 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)

+struct ggml_cuda_kernel_launch_params {
+    dim3         block_nums;
+    dim3         block_dims;
+    size_t       shmem;
+    cudaStream_t stream;
+
+    // size_t shmem overload
+    ggml_cuda_kernel_launch_params(const dim3 & block_nums_, const dim3 & block_dims_, size_t shmem_, cudaStream_t stream_)
+        : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {}
+
+    // int shmem overload
+    ggml_cuda_kernel_launch_params(const dim3 & block_nums_, const dim3 & block_dims_, const int shmem_, cudaStream_t stream_)
+        : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t) shmem_), stream(stream_) {}
+};
+
+#if defined(GGML_CUDA_USE_PDL)
+struct ggml_cuda_pdl_config {
+    cudaLaunchAttribute attr;
+    cudaLaunchConfig_t  cfg;
+
+    ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
+        attr.id                                         = cudaLaunchAttributeProgrammaticStreamSerialization;
+        attr.val.programmaticStreamSerializationAllowed = 1;
+
+        cfg                  = {};
+        cfg.gridDim          = params.block_nums;
+        cfg.blockDim         = params.block_dims;
+        cfg.dynamicSmemBytes = params.shmem;
+        cfg.stream           = params.stream;
+        cfg.attrs            = &attr;
+        cfg.numAttrs         = 1;
+    }
+
+    // Copies and moves are deleted because cfg.attrs points at this object's attr member.
+    ggml_cuda_pdl_config(const ggml_cuda_pdl_config &)              = delete;
+    ggml_cuda_pdl_config & operator=(const ggml_cuda_pdl_config &) = delete;
+    ggml_cuda_pdl_config & operator=(ggml_cuda_pdl_config &&)      = delete;
+};
+#endif // defined(GGML_CUDA_USE_PDL)
+
+template<typename Kernel, typename... Args>
+static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args &&... args) {
+#if defined(GGML_CUDA_USE_PDL)
+    ggml_cuda_pdl_config pdl_cfg(launch_params);
+    CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)...));
+#else
+    kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)...);
+#endif // defined(GGML_CUDA_USE_PDL)
+}
+
 #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
 static const char * cublas_get_error_str(const cublasStatus_t err) {
     return cublasGetStatusString(err);
```
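With this helper, a call site collapses the two launch paths into one call; PDL is applied transparently via cudaLaunchKernelEx when GGML_CUDA_USE_PDL is defined. A sketch of a migrated call site (the kernel name and arguments are hypothetical):

```cuda
// Before: my_kernel<<<block_nums, block_dims, 0, stream>>>(src, dst, n);
// After, using the new launch helper:
const dim3 block_dims(256, 1, 1);
const dim3 block_nums((n + 255)/256, 1, 1);
ggml_cuda_kernel_launch(my_kernel, {block_nums, block_dims, 0, stream}, src, dst, n);
```

The int overload of the shmem parameter lets existing call sites that pass a literal 0 compile without a cast.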
```diff
@@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+    GGML_CUDA_PDL_LC(); // FATTN_VEC try 1; on maxq
 #ifdef FLASH_ATTN_AVAILABLE

     // Skip unused kernel variants for faster compilation:
```
```diff
@@ -136,6 +137,9 @@
 #endif // V_DOT2_F32_F16_AVAILABLE
     int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds [ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+
+    GGML_CUDA_PDL_SYNC();
+    // GGML_CUDA_PDL_LC(); // FATTN_VEC try 2; on maxq
     if constexpr (Q_q8_1) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
```

Collaborator, on lines +141 to +142: Please clean up.
Collaborator: This seems unused?
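Putting the pieces together, here is a self-contained sketch (hypothetical kernels and names, not from the PR) of the full handshake the macros and launch helper implement: the primary grid triggers launch completion, and the secondary grid is launched on the same stream with the programmatic stream serialization attribute and synchronizes on the grid dependency before consuming the primary's output.

```cuda
// Requires sm_90+ and CUDA 11.8+; error handling elided for brevity.
#include <cuda_runtime.h>

__global__ void primary(float * buf, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) buf[i] = 1.0f;
    cudaTriggerProgrammaticLaunchCompletion(); // dependent grid may start now
}

__global__ void secondary(const float * buf, float * out, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x; // overlaps primary's tail
    cudaGridDependencySynchronize();                   // wait for primary's writes
    if (i < n) out[i] = buf[i] + 1.0f;
}

static void launch_pair(float * buf, float * out, const int n, cudaStream_t stream) {
    const dim3 block(256);
    const dim3 grid((n + 255)/256);

    primary<<<grid, block, 0, stream>>>(buf, n);

    cudaLaunchAttribute attr{};
    attr.id                                         = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t cfg{};
    cfg.gridDim  = grid;
    cfg.blockDim = block;
    cfg.stream   = stream;
    cfg.attrs    = &attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, secondary, (const float *) buf, out, n);
}
```

This is why the "try 1" placement puts GGML_CUDA_PDL_LC() at the very top of flash_attn_ext_vec: the earlier the trigger fires, the sooner the next kernel's independent prologue can overlap with the attention kernel's execution.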