Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9f4ddbc
Adds initial PDL setup.
aendk Jan 21, 2026
73d28e4
Adds PDL barriers based on simple heuristic: place "sync" before firs…
aendk Jan 21, 2026
000f462
Further optimization pass of the first half of kernels
aendk Jan 21, 2026
b68aee7
Optimized PDL barriers for the second batch of kernels
aendk Jan 22, 2026
101583e
Further refinements after rebase.
aendk Feb 4, 2026
0e7aa04
Moves pdl logic to separate function, removes some whitespace
aendk Feb 5, 2026
d8eb8ab
Strips post-hoc PDL logic
aendk Feb 13, 2026
12ddf12
Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to
aendk Feb 13, 2026
adfd442
Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL
aendk Feb 13, 2026
7f1342a
Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL
aendk Feb 18, 2026
f3fe281
Merge branch 'master' into akieslinger/pdl-cuda
aendk Feb 18, 2026
c2d9d47
Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx,
aendk Feb 18, 2026
d942a3a
Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32
aendk Feb 18, 2026
11150f0
Enrolls flash_attn_combine_results
aendk Feb 18, 2026
71f8f58
Fix: Drops needless and broken check of CUDA arch for PDL. PDL either
aendk Feb 19, 2026
8664310
Enrolls flash-attention kernels to pdl
aendk Feb 19, 2026
909ec1f
Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for
aendk Feb 20, 2026
3c584d0
Merge branch 'master' into akieslinger/pdl-cuda
aendk Feb 20, 2026
25bbc88
Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via
aendk Feb 20, 2026
c5044bf
Enrolls all remaining kernels for qwen3-coder-next into PDL
aendk Feb 20, 2026
7e76151
Remove all PDL LC calls to create a baseline
aendk Mar 11, 2026
8746582
Merge branch 'master' into akieslinger/pdl-cuda
aendk Mar 11, 2026
dac466d
Merge branch 'master' into akieslinger/pdl-cuda
aendk Mar 24, 2026
23a24c5
Added LC according to internal guidance and tested kernel performance.
aendk Mar 25, 2026
ef28cda
Enrolls missing qwen3-5 kernels passively into PDL.
aendk Apr 2, 2026
5e318bf
Kernel optimizations (LC signals) for qwen3.5
aendk Apr 10, 2026
f3b8665
Enrolls ssm-scan kernels into PDL
aendk Apr 10, 2026
0a7d8c3
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk Apr 16, 2026
75cd1b0
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk Apr 20, 2026
338477a
Merge branch 'master' into akieslinger-pdl-cuda-merge-test
aendk Apr 29, 2026
83e3c79
Adds GGML_CUDA_PDL command line option to toggle PDL.
aendk Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
option(GGML_CUDA_PDL "ggml: use Programmatic Dependent Launch (NVIDIA CC >= 9.0)" OFF)
option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON)
set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
"ggml: cuda link binary compression mode; requires cuda 12.8+")
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ if (CUDAToolkit_FOUND)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()

if (GGML_CUDA_PDL)
add_compile_definitions(GGML_CUDA_USE_PDL)
endif()

if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
Expand Down
33 changes: 15 additions & 18 deletions ggml/src/ggml-cuda/binbcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <cstdint>
#include <utility>

// Maps any index-sequence element onto a fixed type T: expanding
// `type_for_index<const src1_t *, I>...` over a size_t pack I... yields one
// `const src1_t *` per index. Used to spell out the trailing variadic
// template arguments of k_bin_bcast / k_bin_bcast_unravel explicitly when
// the kernel is passed as a function pointer to ggml_cuda_kernel_launch
// (template argument deduction cannot happen there).
template<typename T, size_t>
using type_for_index = T;

static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
Expand Down Expand Up @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const int s12,
const int s13,
src1_ptrs... src1s) {
GGML_CUDA_PDL_LC(); // BINBCAST try 1; 352.28, 352.62, 352.17, 351.96 on maxq
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
Expand All @@ -69,9 +73,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;

GGML_CUDA_PDL_SYNC();
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

// GGML_CUDA_PDL_LC(); // BINBCAST try 2; 352.44 352.42, 352.05 on maxq
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);

Expand Down Expand Up @@ -136,6 +142,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;

GGML_CUDA_PDL_SYNC();
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

Expand Down Expand Up @@ -282,35 +289,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
{
auto launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
{
auto launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
}
}
}
Expand All @@ -333,6 +329,7 @@ static __global__ void k_repeat_back(
}

T sum = 0;
GGML_CUDA_PDL_SYNC();
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
Expand Down
62 changes: 62 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(GGML_USE_HIP)
Expand All @@ -50,6 +51,7 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_HOPPER 900
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unused?

// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
Expand Down Expand Up @@ -107,6 +109,14 @@
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

// Programmatic Dependent Launch (PDL) intrinsics.
//  - GGML_CUDA_PDL_SYNC(): cudaGridDependencySynchronize() — blocks until the
//    preceding dependent grid has finished; must guard any read of its output.
//  - GGML_CUDA_PDL_LC():   cudaTriggerProgrammaticLaunchCompletion() — signals
//    that the next dependent grid may begin launching early.
// The intrinsics are CUDA-only: HIP/MUSA have no aliases wired up for them,
// so restrict the real definitions to native CUDA builds and fall back to
// no-ops everywhere else (including when GGML_CUDA_PDL is OFF).
#if defined(GGML_CUDA_USE_PDL) && !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
# define GGML_CUDA_PDL_SYNC() cudaGridDependencySynchronize()
# define GGML_CUDA_PDL_LC() cudaTriggerProgrammaticLaunchCompletion()
#else
# define GGML_CUDA_PDL_SYNC() // no-op when PDL is disabled or unsupported (HIP/MUSA)
# define GGML_CUDA_PDL_LC()
#endif

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
Expand Down Expand Up @@ -165,6 +175,58 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)

// Bundles the four standard kernel-launch parameters (grid dims, block dims,
// dynamic shared memory, stream) into one object so they can be handed to
// ggml_cuda_kernel_launch as a single argument.
struct ggml_cuda_kernel_launch_params {
    dim3         block_nums; // grid dimensions
    dim3         block_dims; // block dimensions
    size_t       shmem;      // dynamic shared memory size in bytes
    cudaStream_t stream;     // stream to launch on

    // Canonical constructor: size_t shared-memory size.
    ggml_cuda_kernel_launch_params(const dim3 & grid, const dim3 & block, const size_t shmem_bytes, cudaStream_t stream_)
        : block_nums(grid), block_dims(block), shmem(shmem_bytes), stream(stream_) {}

    // Convenience overload for call sites that pass a plain int shared-memory
    // size (e.g. a literal 0); delegates to the size_t constructor.
    ggml_cuda_kernel_launch_params(const dim3 & grid, const dim3 & block, const int shmem_bytes, cudaStream_t stream_)
        : ggml_cuda_kernel_launch_params(grid, block, (size_t) shmem_bytes, stream_) {}
};

#if defined(GGML_CUDA_USE_PDL)
// Builds the cudaLaunchConfig_t used by cudaLaunchKernelEx with the
// programmatic-stream-serialization attribute enabled, which is what enrolls
// a kernel launch into Programmatic Dependent Launch (PDL).
struct ggml_cuda_pdl_config {
    cudaLaunchAttribute attr; // PDL attribute; cfg.attrs points at this member
    cudaLaunchConfig_t cfg;   // full launch config passed to cudaLaunchKernelEx

    ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
        attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attr.val.programmaticStreamSerializationAllowed = 1;

        cfg = {};
        cfg.gridDim = params.block_nums;
        cfg.blockDim = params.block_dims;
        cfg.dynamicSmemBytes = params.shmem;
        cfg.stream = params.stream;
        cfg.attrs = &attr;  // self-referential: makes the object non-copyable/movable
        cfg.numAttrs = 1;
    }

    // Copying/moving is deleted because cfg.attrs holds a pointer to this
    // object's own attr member; a copied/moved instance would point back at
    // the source (dangling once the source is destroyed). The deleted copy
    // constructor also suppresses the implicit move constructor.
    ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete;
    ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete;
    ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;

};
#endif //defined(GGML_CUDA_USE_PDL)


// Launches `kernel` with the given grid/block/shmem/stream parameters,
// abstracting away the launch mechanism: when PDL is enabled the launch goes
// through cudaLaunchKernelEx with the programmatic-stream-serialization
// attribute set; otherwise it is a plain <<<>>> launch. Arguments are
// perfectly forwarded to the kernel.
template<typename Kernel, typename... Args>
static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) {
#if defined(GGML_CUDA_USE_PDL)
    auto pdl_cfg = ggml_cuda_pdl_config(launch_params);
    CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... ));
#else
    kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... );
    // Match the PDL path's error behavior: surface launch-configuration
    // errors here instead of deferring them to the caller's next check.
    CUDA_CHECK(cudaGetLastError());
#endif //defined(GGML_CUDA_USE_PDL)
}

#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
Expand Down
5 changes: 3 additions & 2 deletions ggml/src/ggml-cuda/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont

const int64_t n = ne0 * ne1 * ne2;

GGML_CUDA_PDL_SYNC();
for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
if constexpr (dim == 0) {
const int64_t row = i / ne0;
Expand Down Expand Up @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x,
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;

if (dim == 0) {
concat_f32_cont<0>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
auto launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
if (dim == 1) {
Expand Down
20 changes: 14 additions & 6 deletions ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
GGML_CUDA_PDL_LC(); // try 1
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

if (i >= ne) {
Expand All @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

GGML_CUDA_PDL_SYNC();
cpy_1(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
int cur_tile_buf = 0;

GGML_CUDA_PDL_SYNC();
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

Expand Down Expand Up @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

GGML_CUDA_PDL_SYNC();
cpy_blck(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

GGML_CUDA_PDL_SYNC();
cpy_blck(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const
const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;

GGML_CUDA_PDL_SYNC();
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}

Expand All @@ -192,8 +198,8 @@ cudaStream_t stream) {

const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
auto launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
}

template<typename src_t, typename dst_t, bool transposed = false>
Expand Down Expand Up @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda(
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
auto launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
auto launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}

Expand Down
15 changes: 11 additions & 4 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ static __global__ void flash_attn_mask_to_KV_max(
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
const int jt = blockIdx.x;
GGML_CUDA_PDL_SYNC();

mask += sequence*s33 + jt*ncols1*s31;

Expand Down Expand Up @@ -774,6 +775,7 @@ static __global__ void flash_attn_stream_k_fixup_general(
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;

GGML_CUDA_PDL_SYNC();
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
Expand Down Expand Up @@ -867,6 +869,7 @@ static __global__ void flash_attn_combine_results(
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst,
const int parallel_blocks) {
GGML_CUDA_PDL_LC(); // FATTN_COMBINE_RESULTS try 1; on maxq
// Dimension 0: threadIdx.x
// Dimension 1: blockIdx.x
// Dimension 2: blockIdx.y
Expand All @@ -890,10 +893,12 @@ static __global__ void flash_attn_combine_results(
__builtin_assume(tid < D);

extern __shared__ float2 meta[];
GGML_CUDA_PDL_SYNC();
for (int i = tid; i < 2*parallel_blocks; i += D) {
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
}

// GGML_CUDA_PDL_LC(); // FATTN_COMBINE_RESULTS try 2; on maxq
__syncthreads();

float kqmax = meta[0].x;
Expand Down Expand Up @@ -1146,7 +1151,9 @@ void launch_fattn(
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);

GGML_ASSERT(block_dim.x % warp_size == 0);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(

auto launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream);
ggml_cuda_kernel_launch(fattn_kernel, launch_params,
(const char *) Q->data,
K_data,
V_data,
Expand Down Expand Up @@ -1204,9 +1211,9 @@ void launch_fattn(
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

flash_attn_combine_results<DV>
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
auto launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream);
ggml_cuda_kernel_launch(flash_attn_combine_results<DV>, launch_params,
dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
}
CUDA_CHECK(cudaGetLastError());
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ static __global__ void flash_attn_tile(
constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.

GGML_CUDA_PDL_SYNC(); // needs to guard Q, K, V, mask, sinks, KV_max, dst, dst_meta data accesses. Conservatively placed, not optimal

// Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
// KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
// KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
GGML_CUDA_PDL_LC(); // FATTN_VEC try 1; on maxq
#ifdef FLASH_ATTN_AVAILABLE

// Skip unused kernel variants for faster compilation:
Expand Down Expand Up @@ -136,6 +137,9 @@ static __global__ void flash_attn_ext_vec(
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];

GGML_CUDA_PDL_SYNC();
// GGML_CUDA_PDL_LC(); // FATTN_VEC try 2; on maxq
Comment on lines +141 to +142
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please clean-up

if constexpr (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

GGML_CUDA_PDL_SYNC(); // needs to guard Q, K, V, mask, sinks, KV_max, dst, dst_meta data accesses. Conservatively placed, not optimal
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Expand Down
Loading