@@ -6269,6 +6269,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at
// Explicit instantiations of kernel_flash_attn_ext, each exported to the host under the
// name given in its [[host_name(...)]] attribute. In every line the two trailing integer
// template arguments equal the dk/dv numbers embedded in the exported name.
// NOTE(review): the leading space inside the host_name strings is absent from the copy of
// the same pattern in the hunk header — presumably a rendering artifact; confirm against
// the real file.
62696269template [[host_name(" kernel_flash_attn_ext_f32_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 192 , 128 >;
62706270template [[host_name(" kernel_flash_attn_ext_f32_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 256 , 256 >;
62716271template [[host_name(" kernel_flash_attn_ext_f32_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 320 , 256 >;
// Added by this change: f32 variant with trailing arguments 512, 512; all other template
// arguments are the same as the neighbouring f32 lines.
6272+ template [[host_name(" kernel_flash_attn_ext_f32_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 512 , 512 >;
62726273template [[host_name(" kernel_flash_attn_ext_f32_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 576 , 512 >;
62736274
// First line of the f16 group: uses FA_TYPES, half4x4 and dequantize_f16 instead of the f32 set.
62746275template [[host_name(" kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 32 , 32 >;
@@ -6284,6 +6285,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at
// f16 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, half4x4, dequantize_f16).
// Each is exported under its [[host_name(...)]] string; the two trailing integers match
// the dk/dv numbers in that name.
62846285template [[host_name(" kernel_flash_attn_ext_f16_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 128 >;
62856286template [[host_name(" kernel_flash_attn_ext_f16_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 , 256 >;
62866287template [[host_name(" kernel_flash_attn_ext_f16_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 320 , 256 >;
// Added by this change: f16 variant with trailing arguments 512, 512.
6288+ template [[host_name(" kernel_flash_attn_ext_f16_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 512 , 512 >;
62876289template [[host_name(" kernel_flash_attn_ext_f16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
62886290
62896291#if defined(GGML_METAL_HAS_BF16)
@@ -6300,6 +6302,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at
// bf16 explicit instantiations of kernel_flash_attn_ext (FA_TYPES_BF, bfloat4x4,
// dequantize_bf16). The surrounding lines show these sit between
// `#if defined(GGML_METAL_HAS_BF16)` and `#endif`, so they are compiled only when that
// macro is defined.
63006302template [[host_name(" kernel_flash_attn_ext_bf16_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
63016303template [[host_name(" kernel_flash_attn_ext_bf16_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
63026304template [[host_name(" kernel_flash_attn_ext_bf16_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 320 , 256 >;
// Added by this change: bf16 variant with trailing arguments 512, 512.
6305+ template [[host_name(" kernel_flash_attn_ext_bf16_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 512 , 512 >;
63036306template [[host_name(" kernel_flash_attn_ext_bf16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
63046307#endif
63056308
@@ -6316,6 +6319,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at
// q4_0 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, block_q4_0,
// dequantize_q4_0). Unlike the f32/f16/bf16 lines, the integer following each block type
// is 2 rather than 1. The two trailing integers match the dk/dv numbers in the host name.
63166319template [[host_name(" kernel_flash_attn_ext_q4_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 128 >;
63176320template [[host_name(" kernel_flash_attn_ext_q4_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 , 256 >;
63186321template [[host_name(" kernel_flash_attn_ext_q4_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 320 , 256 >;
// Added by this change: q4_0 variant with trailing arguments 512, 512.
6322+ template [[host_name(" kernel_flash_attn_ext_q4_0_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 512 , 512 >;
63196323template [[host_name(" kernel_flash_attn_ext_q4_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 576 , 512 >;
63206324
// First line of the q4_1 group (block_q4_1 / dequantize_q4_1).
63216325template [[host_name(" kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 32 , 32 >;
@@ -6331,6 +6335,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at
// q4_1 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, block_q4_1,
// dequantize_q4_1). Each is exported under its [[host_name(...)]] string; the two trailing
// integers match the dk/dv numbers in that name.
63316335template [[host_name(" kernel_flash_attn_ext_q4_1_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 128 >;
63326336template [[host_name(" kernel_flash_attn_ext_q4_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 , 256 >;
63336337template [[host_name(" kernel_flash_attn_ext_q4_1_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 320 , 256 >;
// Added by this change: q4_1 variant with trailing arguments 512, 512.
6338+ template [[host_name(" kernel_flash_attn_ext_q4_1_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 512 , 512 >;
63346339template [[host_name(" kernel_flash_attn_ext_q4_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 576 , 512 >;
63356340
// First line of the q5_0 group (block_q5_0 / dequantize_q5_0).
63366341template [[host_name(" kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 32 , 32 >;
@@ -6346,6 +6351,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at
// q5_0 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, block_q5_0,
// dequantize_q5_0). Each is exported under its [[host_name(...)]] string; the two trailing
// integers match the dk/dv numbers in that name.
63466351template [[host_name(" kernel_flash_attn_ext_q5_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 128 >;
63476352template [[host_name(" kernel_flash_attn_ext_q5_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 , 256 >;
63486353template [[host_name(" kernel_flash_attn_ext_q5_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 320 , 256 >;
// Added by this change: q5_0 variant with trailing arguments 512, 512.
6354+ template [[host_name(" kernel_flash_attn_ext_q5_0_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 512 , 512 >;
63496355template [[host_name(" kernel_flash_attn_ext_q5_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 576 , 512 >;
63506356
// First line of the q5_1 group (block_q5_1 / dequantize_q5_1).
63516357template [[host_name(" kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 32 , 32 >;
@@ -6361,6 +6367,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at
// q5_1 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, block_q5_1,
// dequantize_q5_1). Each is exported under its [[host_name(...)]] string; the two trailing
// integers match the dk/dv numbers in that name.
63616367template [[host_name(" kernel_flash_attn_ext_q5_1_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 128 >;
63626368template [[host_name(" kernel_flash_attn_ext_q5_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 256 , 256 >;
63636369template [[host_name(" kernel_flash_attn_ext_q5_1_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 320 , 256 >;
// Added by this change: q5_1 variant with trailing arguments 512, 512.
6370+ template [[host_name(" kernel_flash_attn_ext_q5_1_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 512 , 512 >;
63646371template [[host_name(" kernel_flash_attn_ext_q5_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 576 , 512 >;
63656372
// First line of the q8_0 group (block_q8_0 / dequantize_q8_0).
63666373template [[host_name(" kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 32 , 32 >;
@@ -6376,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at
// q8_0 explicit instantiations of kernel_flash_attn_ext (FA_TYPES, block_q8_0,
// dequantize_q8_0). Each is exported under its [[host_name(...)]] string; the two trailing
// integers match the dk/dv numbers in that name.
63766383template [[host_name(" kernel_flash_attn_ext_q8_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 128 >;
63776384template [[host_name(" kernel_flash_attn_ext_q8_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 256 , 256 >;
63786385template [[host_name(" kernel_flash_attn_ext_q8_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 320 , 256 >;
// Added by this change: q8_0 variant with trailing arguments 512, 512.
6386+ template [[host_name(" kernel_flash_attn_ext_q8_0_dk512_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 512 , 512 >;
63796387template [[host_name(" kernel_flash_attn_ext_q8_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
63806388
63816389#undef FA_TYPES
@@ -6957,6 +6965,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas
// Explicit instantiations of kernel_flash_attn_ext_vec, each exported under its
// [[host_name(...)]] string. Compared with the non-vec kernel lines, these use 4-wide
// element types (float4/half4/bfloat4), the *_t4 dequantize functions, 8 (not 2) after the
// quantized block types, and one extra trailing integer after the dk/dv pair.
69576965template [[host_name(" kernel_flash_attn_ext_vec_q5_1_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 320 , 256 , 2 >;
69586966template [[host_name(" kernel_flash_attn_ext_vec_q8_0_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 320 , 256 , 2 >;
69596967
// Added by this change: the full dk512/dv512 set for the vec kernel (f32, f16, bf16,
// q4_0, q4_1, q5_0, q5_1, q8_0).
// NOTE(review): the final template argument is 1 on every new line, while the visible
// dk320_dv256 and dk576_dv512 lines pass 2. The parameter's meaning is not shown in this
// view — confirm 1 is the intended value for 512/512.
6968+ template [[host_name(" kernel_flash_attn_ext_vec_f32_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 512 , 512 , 1 >;
6969+ template [[host_name(" kernel_flash_attn_ext_vec_f16_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 512 , 512 , 1 >;
6970+ #if defined(GGML_METAL_HAS_BF16)
// Compiled only when GGML_METAL_HAS_BF16 is defined.
// NOTE(review): this bf16 line passes FA_TYPES, whereas the non-vec bf16 instantiations
// pass FA_TYPES_BF. The macro definitions in effect here are outside this view — confirm
// this matches the other vec bf16 lines.
6971+ template [[host_name(" kernel_flash_attn_ext_vec_bf16_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1 , dequantize_bf16_t4, bfloat4, 1 , dequantize_bf16_t4, 512 , 512 , 1 >;
6972+ #endif
6973+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8 , dequantize_q4_0_t4, block_q4_0, 8 , dequantize_q4_0_t4, 512 , 512 , 1 >;
6974+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8 , dequantize_q4_1_t4, block_q4_1, 8 , dequantize_q4_1_t4, 512 , 512 , 1 >;
6975+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8 , dequantize_q5_0_t4, block_q5_0, 8 , dequantize_q5_0_t4, 512 , 512 , 1 >;
6976+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 512 , 512 , 1 >;
6977+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_dk512_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 512 , 512 , 1 >;
6978+
// Pre-existing dk576/dv512 vec instantiations (final template argument 2).
69606979template [[host_name(" kernel_flash_attn_ext_vec_f32_dk576_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 576 , 512 , 2 >;
69616980template [[host_name(" kernel_flash_attn_ext_vec_f16_dk576_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 576 , 512 , 2 >;
69626981#if defined(GGML_METAL_HAS_BF16)
0 commit comments