Skip to content

Commit f1748ac

Browse files
committed
Add mixed-type Metal FA kernels for auto-asymmetric K/V
turbo4_1 and turbo3_1 auto-promote K by 1 bit (e.g. turbo4_1 runs as K=turbo5_1/V=turbo4_1). Previously this mixed K/V case fell back to CPU scalar attention (47 t/s). With mixed-type Metal flash attention kernels it now reaches 73 t/s (+53%).

Changes:
- ggml-metal.metal: 8 new FA kernel instantiations for mixed K/V (4 batched + 4 vec, for turbo and rq auto-asymmetric pairs)
- ggml-metal-device.cpp: pipeline naming includes V type when K!=V
- ggml-metal-device.m: allow mixed turbo/rq types in supports_op
- ggml-metal-ops.cpp: disable the K==V type assertion so mixed turbo/rq K/V types are accepted

Results (gpt-oss-120b, M3 Ultra):
- turbo4_1: 47→73 t/s (+53%), correct output
- turbo3_1: 47→75 t/s (+59%), marginal quality
- turbo5_1: 76 t/s (unchanged, symmetric)
- q8_0: 80 t/s (baseline)
1 parent 94a8ba9 commit f1748ac

4 files changed

Lines changed: 49 additions & 12 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,11 +1321,19 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
13211321
// do bounds checks for the mask?
13221322
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
13231323

1324-
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1325-
"flash_attn_ext",
1326-
ggml_type_name(op->src[1]->type),
1327-
dk,
1328-
dv);
1324+
// Mixed K/V types (turbo/rq auto-asymmetric): embed the V type in the kernel name as "_v<type>" so the matching mixed-type kernel is selected
1325+
if (op->src[1]->type != op->src[2]->type) {
1326+
snprintf(base, 256, "kernel_%s_%s_v%s_dk%d_dv%d",
1327+
"flash_attn_ext",
1328+
ggml_type_name(op->src[1]->type),
1329+
ggml_type_name(op->src[2]->type),
1330+
dk, dv);
1331+
} else {
1332+
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1333+
"flash_attn_ext",
1334+
ggml_type_name(op->src[1]->type),
1335+
dk, dv);
1336+
}
13291337

13301338
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
13311339
base,
@@ -1384,11 +1392,18 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
13841392
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
13851393
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
13861394

1387-
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1388-
"flash_attn_ext_vec",
1389-
ggml_type_name(op->src[1]->type),
1390-
dk,
1391-
dv);
1395+
if (op->src[1]->type != op->src[2]->type) {
1396+
snprintf(base, 256, "kernel_%s_%s_v%s_dk%d_dv%d",
1397+
"flash_attn_ext_vec",
1398+
ggml_type_name(op->src[1]->type),
1399+
ggml_type_name(op->src[2]->type),
1400+
dk, dv);
1401+
} else {
1402+
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1403+
"flash_attn_ext_vec",
1404+
ggml_type_name(op->src[1]->type),
1405+
dk, dv);
1406+
}
13921407

13931408
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
13941409
base,

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,12 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
11571157
return false;
11581158
}
11591159
if (op->src[1]->type != op->src[2]->type) {
1160-
return false;
1160+
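// Allow mixed K/V types only when both fall in the turbo/rq type range (auto-asymmetric K/V)
// NOTE(review): this accepts every K/V pair in that range, but this commit instantiates mixed kernels for only four pairs at dk64/dv64 - confirm other pairs cannot reach the Metal path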
1161+
const bool k_is_turbo = (op->src[1]->type >= GGML_TYPE_TURBO3_1 && op->src[1]->type <= GGML_TYPE_RQ6_1);
1162+
const bool v_is_turbo = (op->src[2]->type >= GGML_TYPE_TURBO3_1 && op->src[2]->type <= GGML_TYPE_RQ6_1);
1163+
if (!(k_is_turbo && v_is_turbo)) {
1164+
return false;
1165+
}
11611166
}
11621167
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
11631168
case GGML_OP_SSM_CONV:

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2633,7 +2633,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
26332633
GGML_ASSERT(ne00 % 4 == 0);
26342634

26352635
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2636-
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
2636+
// K and V types may differ (mixed turbo/rq auto-asymmetric pairs), so the K==V type assertion below is disabled for all types
2637+
// GGML_ASSERT(op->src[1]->type == op->src[2]->type);
26372638

26382639
//GGML_ASSERT(ggml_are_same_shape (src1, src2));
26392640
GGML_ASSERT(ne11 == ne21);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6702,6 +6702,16 @@ template [[host_name("kernel_flash_attn_ext_rq4_1_dk64_dv64")]] kernel flash_att
67026702
template [[host_name("kernel_flash_attn_ext_rq5_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_rq5_1, 4, dequantize_rq5_1, block_rq5_1, 4, dequantize_rq5_1, 64, 64>;
67036703
template [[host_name("kernel_flash_attn_ext_rq6_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_rq6_1, 4, dequantize_rq6_1, block_rq6_1, 4, dequantize_rq6_1, 64, 64>;
67046704

6705+
// Mixed K/V type flash attention kernels (auto-asymmetric: K gets 1 more bit than V)
6706+
// turbo4_1 → K=turbo5_1, V=turbo4_1
6707+
template [[host_name("kernel_flash_attn_ext_turbo5_1_vturbo4_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_turbo5_1, 4, dequantize_turbo5_1, block_turbo4_1, 4, dequantize_turbo4_1, 64, 64>;
6708+
// turbo3_1 → K=turbo4_1, V=turbo3_1
6709+
template [[host_name("kernel_flash_attn_ext_turbo4_1_vturbo3_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_turbo4_1, 4, dequantize_turbo4_1, block_turbo3_1, 4, dequantize_turbo3_1, 64, 64>;
6710+
// rq4_1 → K=rq5_1, V=rq4_1
6711+
template [[host_name("kernel_flash_attn_ext_rq5_1_vrq4_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_rq5_1, 4, dequantize_rq5_1, block_rq4_1, 4, dequantize_rq4_1, 64, 64>;
6712+
// rq3_1 → K=rq4_1, V=rq3_1
6713+
template [[host_name("kernel_flash_attn_ext_rq4_1_vrq3_1_dk64_dv64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_rq4_1, 4, dequantize_rq4_1, block_rq3_1, 4, dequantize_rq3_1, 64, 64>;
6714+
67056715
#undef FA_TYPES
67066716
#undef FA_TYPES_BF
67076717
#undef FA_TYPES_F32
@@ -7320,6 +7330,12 @@ template [[host_name("kernel_flash_attn_ext_vec_rq4_1_dk64_dv64")]] kernel flash
73207330
template [[host_name("kernel_flash_attn_ext_vec_rq5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_rq5_1, 16, dequantize_rq5_1_t4, block_rq5_1, 16, dequantize_rq5_1_t4, 64, 64, 2>;
73217331
template [[host_name("kernel_flash_attn_ext_vec_rq6_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_rq6_1, 16, dequantize_rq6_1_t4, block_rq6_1, 16, dequantize_rq6_1_t4, 64, 64, 2>;
73227332

7333+
// Mixed K/V vec flash attention kernels (auto-asymmetric)
7334+
template [[host_name("kernel_flash_attn_ext_vec_turbo5_1_vturbo4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_turbo5_1, 16, dequantize_turbo5_1_t4, block_turbo4_1, 16, dequantize_turbo4_1_t4, 64, 64, 2>;
7335+
template [[host_name("kernel_flash_attn_ext_vec_turbo4_1_vturbo3_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_turbo4_1, 16, dequantize_turbo4_1_t4, block_turbo3_1, 16, dequantize_turbo3_1_t4, 64, 64, 2>;
7336+
template [[host_name("kernel_flash_attn_ext_vec_rq5_1_vrq4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_rq5_1, 16, dequantize_rq5_1_t4, block_rq4_1, 16, dequantize_rq4_1_t4, 64, 64, 2>;
7337+
template [[host_name("kernel_flash_attn_ext_vec_rq4_1_vrq3_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_rq4_1, 16, dequantize_rq4_1_t4, block_rq3_1, 16, dequantize_rq3_1_t4, 64, 64, 2>;
7338+
73237339
#undef FA_TYPES
73247340
#undef FA_TYPES_F32
73257341

0 commit comments

Comments
 (0)