@@ -624,8 +624,6 @@ struct vk_device_struct {
624624 // floor(log2(maxComputeWorkGroupInvocations))
625625 uint32_t max_workgroup_size_log2 {};
626626
627- bool flash_attention_fp16;
628-
629627 bool coopmat_support;
630628 bool coopmat_acc_f32_support {};
631629 bool coopmat_acc_f16_support {};
@@ -2978,11 +2976,15 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
29782976 }
29792977}
29802978
2981- static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
2979+ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
29822980 bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
2981+ const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
2982+ (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
2983+
29832984 uint32_t flags = (use_mask_opt ? 1 : 0) |
29842985 (use_mask ? 2 : 0) |
2985- (use_logit_softcap ? 4 : 0);
2986+ (use_logit_softcap ? 4 : 0) |
2987+ (old_amd_windows ? 8 : 0);
29862988
29872989 const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
29882990
@@ -3384,7 +3386,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
33843386 } \
33853387 }
33863388
3387- if (device->flash_attention_fp16 ) {
3389+ if (device->fp16 ) {
33883390 CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
33893391 CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
33903392 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -5423,10 +5425,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
54235425 device->mmvq_mode = 1;
54245426 }
54255427
5426- // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613
5427- const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary;
5428- device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn;
5429-
54305428 return device;
54315429 }
54325430
@@ -8567,7 +8565,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
85678565 const uint32_t Br = params.block_rows;
85688566 const uint32_t Bc = params.block_cols;
85698567
8570- const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
8568+ const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
85718569
85728570 // tmpsh is overestimated slightly
85738571 const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8690,7 +8688,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
86908688 uint32_t workgroups_y = (uint32_t)neq2;
86918689 uint32_t workgroups_z = (uint32_t)neq3;
86928690
8693- const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32;
8691+ const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
86948692
86958693 // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
86968694 // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
@@ -8745,7 +8743,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
87458743
87468744 // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
87478745 bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
8748- vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc,
8746+ vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
87498747 mask != nullptr, use_mask_opt, logit_softcap != 0);
87508748
87518749 vk_pipeline pipeline = nullptr;
0 commit comments