Skip to content

Commit f9bd518

Browse files
authored
vulkan: make FA mask/softcap enables spec constants (ggml-org#19309)
* vulkan: make FA mask/softcap enables spec constants * don't specialize for sinks * bump timeout a little bit
1 parent 7fcf1ef commit f9bd518

6 files changed

Lines changed: 45 additions & 38 deletions

File tree

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ jobs:
468468
export GGML_VK_VISIBLE_DEVICES=0
469469
export GGML_VK_DISABLE_F16=1
470470
# This is using llvmpipe and runs slower than other backends
471-
ctest -L main --verbose --timeout 4200
471+
ctest -L main --verbose --timeout 4800
472472
473473
ubuntu-24-cmake-webgpu:
474474
runs-on: ubuntu-24.04

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -402,19 +402,19 @@ enum FaCodePath {
402402
};
403403

404404
struct vk_fa_pipeline_state {
405-
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
406-
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
405+
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
406+
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
407407

408408
uint32_t HSK, HSV;
409409
bool small_rows, small_cache;
410410
FaCodePath path;
411411
bool aligned;
412412
bool f32acc;
413-
bool use_mask_opt;
413+
uint32_t flags;
414414

415415
bool operator<(const vk_fa_pipeline_state &b) const {
416-
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
417-
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
416+
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
417+
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
418418
}
419419
};
420420

@@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
31933193
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
31943194
};
31953195

3196-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
3196+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
31973197
// For large number of rows, 128 invocations seems to work best.
31983198
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
31993199
// can't use 256 for D==80.
@@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
32253225
// AMD prefers loading K directly from global memory
32263226
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
32273227

3228-
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
3228+
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
32293229
};
32303230

32313231
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
32373237
FaCodePath path = fa.first.path; \
32383238
bool aligned = fa.first.aligned; \
32393239
bool f32acc = fa.first.f32acc; \
3240-
bool use_mask_opt = fa.first.use_mask_opt; \
3240+
uint32_t flags = fa.first.flags; \
32413241
if (path == FAPATH) { \
32423242
if (aligned) { \
32433243
if (f32acc) { \
3244-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3244+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32453245
} else { \
3246-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3246+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32473247
} \
32483248
} else { \
32493249
if (f32acc) { \
3250-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3250+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32513251
} else { \
3252-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3252+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32533253
} \
32543254
} \
32553255
} \
@@ -8595,10 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
85958595

85968596
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
85978597

8598+
float scale = 1.0f;
8599+
float max_bias = 0.0f;
8600+
float logit_softcap = 0.0f;
8601+
8602+
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8603+
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8604+
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8605+
8606+
if (logit_softcap != 0) {
8607+
scale /= logit_softcap;
8608+
}
8609+
85988610
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
85998611
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
86008612

8601-
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
8613+
uint32_t flags = (use_mask_opt ? 1 : 0) |
8614+
(mask != nullptr ? 2 : 0) |
8615+
(logit_softcap != 0 ? 4 : 0);
8616+
8617+
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
86028618

86038619
vk_pipeline pipeline = nullptr;
86048620

@@ -8678,18 +8694,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
86788694
}
86798695
}
86808696

8681-
float scale = 1.0f;
8682-
float max_bias = 0.0f;
8683-
float logit_softcap = 0.0f;
8684-
8685-
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8686-
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8687-
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8688-
8689-
if (logit_softcap != 0) {
8690-
scale /= logit_softcap;
8691-
}
8692-
86938697
const uint32_t n_head_kv = neq2;
86948698
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
86958699
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -8703,7 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
87038707
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
87048708
vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
87058709

8706-
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
8710+
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
87078711

87088712
if (use_mask_opt)
87098713
{

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ void main() {
127127
continue;
128128
}
129129
// Only load if the block is not all zeros
130-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
130+
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
131131
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
132132

133133
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
@@ -181,15 +181,15 @@ void main() {
181181
}
182182
}
183183

184-
if (p.logit_softcap != 0.0f) {
184+
if (LOGIT_SOFTCAP) {
185185
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
186186
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
187187
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
188188
}
189189
}
190190
}
191191

192-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
192+
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
193193
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
194194
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
195195
float mvf = masksh[c * cols_per_iter + col_tid][r];

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
1010
layout (constant_id = 6) const uint32_t D_split = 16;
1111
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
1212
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
13-
layout (constant_id = 9) const bool USE_MASK_OPT = false;
13+
layout (constant_id = 9) const uint32_t Flags = 0;
14+
15+
const bool USE_MASK_OPT = (Flags & 1) != 0;
16+
const bool MASK_ENABLE = (Flags & 2) != 0;
17+
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
1418

1519
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
1620
const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -60,7 +64,6 @@ layout (push_constant) uniform parameter {
6064
} p;
6165

6266
#define SINK_ENABLE_BIT (1<<24)
63-
#define MASK_ENABLE_BIT (1<<16)
6467
#define N_LOG2_MASK 0xFFFF
6568

6669
layout (binding = 4) readonly buffer S {float data_s[];};

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ void main() {
160160
mask_cache[idx] = f16vec4(0);
161161
}
162162

163-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
163+
if (MASK_ENABLE) {
164164

165165
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
166166
mask_opt_idx = j / 16;
@@ -303,7 +303,7 @@ void main() {
303303
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
304304
barrier();
305305

306-
if (p.logit_softcap != 0.0f) {
306+
if (LOGIT_SOFTCAP) {
307307
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
308308
uint32_t c = (idx + tid) / (Br / 4);
309309
uint32_t r = (idx + tid) % (Br / 4);
@@ -314,7 +314,7 @@ void main() {
314314
barrier();
315315
}
316316

317-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
317+
if (MASK_ENABLE) {
318318
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
319319
uint32_t c = (idx + tid) / (Br / 4);
320320
uint32_t r = (idx + tid) % (Br / 4);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ void main() {
155155
for (uint32_t j = start_j; j < end_j; ++j) {
156156

157157
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
158-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
158+
if (MASK_ENABLE) {
159159

160160
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
161161
mask_opt_idx = j / 16;
@@ -197,14 +197,14 @@ void main() {
197197
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
198198
S = coopMatMulAdd(Qf16, K_T, S);
199199

200-
if (p.logit_softcap != 0.0f) {
200+
if (LOGIT_SOFTCAP) {
201201
[[unroll]]
202202
for (int k = 0; k < S.length(); ++k) {
203203
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
204204
}
205205
}
206206

207-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
207+
if (MASK_ENABLE) {
208208
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
209209
}
210210

0 commit comments

Comments
 (0)