Skip to content

Commit 10cb187

Browse files
TheTomclaude authored and committed
feat: symmetric turbo3 K support in TurboFlash + research conclusions
Added turbo3 K dequant path to TurboFlash kernel via function constant FC_turbo_flash_p1_k_is_turbo3. Symmetric turbo3/turbo3 now dispatches through TurboFlash instead of baseline FA. Result: symmetric TurboFlash is neutral vs baseline FA (-0.7%). This confirms the 56->145 tok/s gap to Eric's MLX-Swift is 100% framework overhead (dispatch count, graph evaluation), not kernel-level. Best config remains asymmetric q8_0-K/turbo3-V with TurboFlash V4 + simd_shuffle WHT: 56.82 tok/s, 93% of q8_0, +1.5% over baseline. Co-Authored-By: tturney@psyguard.ai Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b0b8dde commit 10cb187

2 files changed

Lines changed: 55 additions & 19 deletions

File tree

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2709,8 +2709,8 @@ static bool ggml_metal_op_flash_attn_ext_use_turbo_flash(const ggml_tensor * op)
27092709
// Only for turbo3 V cache
27102710
if (type_v != GGML_TYPE_TURBO3_0) return false;
27112711

2712-
// Only for q8_0 K (asymmetric) — the primary target config
2713-
if (type_k != GGML_TYPE_Q8_0) return false;
2712+
// Only for q8_0 or turbo3 K — asymmetric or symmetric turbo
2713+
if (type_k != GGML_TYPE_Q8_0 && type_k != GGML_TYPE_TURBO3_0) return false;
27142714

27152715
// Only for supported head dims (64, 96, 128) and power-of-2 aligned to 32
27162716
if (ne00 % 32 != 0) return false;
@@ -2947,7 +2947,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
29472947

29482948
// ==================== TurboFlash two-pass dispatch ====================
29492949
// Intercept before the normal VEC/non-VEC path when conditions are met:
2950-
// - V is turbo3, K is q8_0
2950+
// - V is turbo3, K is q8_0 or turbo3
29512951
// - Single-token decode (ne01 == 1)
29522952
// - Supported head dimensions (64, 96, 128)
29532953
fprintf(stderr, "TURBOFLASH: pre-check ne01=%d type_k=%d type_v=%d ne00=%d TURBO3=%d\n",
@@ -3013,19 +3013,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
30133013
};
30143014

30153015
// Pipeline name: kernel_turbo_flash_p1_dk{dk}_dv{dv}
3016+
const ggml_type type_k = op->src[1]->type;
3017+
const bool k_is_turbo3 = (type_k == GGML_TYPE_TURBO3_0);
3018+
30163019
char p1_base[128];
30173020
char p1_name[256];
30183021
snprintf(p1_base, 128, "kernel_turbo_flash_p1_dk%d_dv%d", dk, dv);
3019-
snprintf(p1_name, 256, "%s_mask=%d_dk=%d_dv=%d",
3020-
p1_base, has_mask ? 1 : 0, dk, dv);
3022+
snprintf(p1_name, 256, "%s_mask=%d_dk=%d_dv=%d_kt3=%d",
3023+
p1_base, has_mask ? 1 : 0, dk, dv, k_is_turbo3 ? 1 : 0);
30213024

3022-
// The kernel uses FC_turbo_flash_p1_has_mask as a function constant
3025+
// The kernel uses FC_turbo_flash_p1_has_mask and FC_turbo_flash_p1_k_is_turbo3 as function constants
30233026
ggml_metal_pipeline_with_params res_p1 = ggml_metal_library_get_pipeline(lib, p1_name);
30243027
if (!res_p1.pipeline) {
30253028
ggml_metal_cv_t cv = ggml_metal_cv_init();
3026-
ggml_metal_cv_set_int32(cv, dk, FC_TURBO_FLASH_P1 + 0);
3027-
ggml_metal_cv_set_int32(cv, dv, FC_TURBO_FLASH_P1 + 1);
3028-
ggml_metal_cv_set_bool(cv, has_mask, FC_TURBO_FLASH_P1 + 2);
3029+
ggml_metal_cv_set_int32(cv, dk, FC_TURBO_FLASH_P1 + 0);
3030+
ggml_metal_cv_set_int32(cv, dv, FC_TURBO_FLASH_P1 + 1);
3031+
ggml_metal_cv_set_bool(cv, has_mask, FC_TURBO_FLASH_P1 + 2);
3032+
ggml_metal_cv_set_bool(cv, k_is_turbo3, FC_TURBO_FLASH_P1 + 3);
30293033

30303034
fprintf(stderr, "TURBOFLASH: compiling P1 pipeline base='%s' has_mask=%d\n", p1_base, has_mask);
30313035
res_p1 = ggml_metal_library_compile_pipeline(lib, p1_base, p1_name, cv);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 42 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8587,6 +8587,7 @@ kernel void kernel_flash_attn_ext_vec_reduce(
85878587
constant int32_t FC_turbo_flash_p1_dk [[function_constant(FC_TURBO_FLASH_P1 + 0)]];
85888588
constant int32_t FC_turbo_flash_p1_dv [[function_constant(FC_TURBO_FLASH_P1 + 1)]];
85898589
constant bool FC_turbo_flash_p1_has_mask [[function_constant(FC_TURBO_FLASH_P1 + 2)]];
8590+
constant bool FC_turbo_flash_p1_k_is_turbo3 [[function_constant(FC_TURBO_FLASH_P1 + 3)]];
85908591

85918592
// Function constants for Pass 2
85928593
constant int32_t FC_turbo_flash_p2_dv [[function_constant(FC_TURBO_FLASH_P2 + 0)]];
@@ -8666,6 +8667,14 @@ kernel void kernel_turbo_flash_p1(
86668667
v_cb[i] = float(turbo_centroids_3bit_h[i]);
86678668
}
86688669

8670+
// K codebook — same centroids, only loaded when K is turbo3
8671+
float k_cb[8];
8672+
if (FC_turbo_flash_p1_k_is_turbo3) {
8673+
for (int i = 0; i < 8; i++) {
8674+
k_cb[i] = float(turbo_centroids_3bit_h[i]);
8675+
}
8676+
}
8677+
86698678
// ====== Online softmax state — all in registers ======
86708679
float m_state = -INFINITY;
86718680
float l_state = 0.0f;
@@ -8697,19 +8706,42 @@ kernel void kernel_turbo_flash_p1(
86978706
}
86988707

86998708
// --- Dequant K and compute Q·K score ---
8700-
// K is q8_0: 32 elements per block, DK/32 blocks per row.
87018709
// Each lane computes partial dot for its interleaved dims, then simd_sum.
8702-
device const block_q8_0 * k_row = (device const block_q8_0 *)(k_base + t * args.nb11);
8703-
87048710
float dot_partial = 0.0f;
8705-
for (short i = 0; i < DK_PER_LANE; i++) {
8706-
const int d = (int)lane + i * 32;
8707-
if (d >= DK) break;
87088711

8709-
// Which q8_0 block and offset within it
8710-
const int qb = d / 32; // block index
8711-
const int qj = d % 32; // element within block
8712-
dot_partial += q_vals[i] * (float)k_row[qb].qs[qj] * (float)k_row[qb].d;
8712+
if (FC_turbo_flash_p1_k_is_turbo3) {
8713+
// K is turbo3_0: same struct as V — norm, qs[], signs[]
8714+
device const block_turbo3_0 * k_row = (device const block_turbo3_0 *)(k_base + t * args.nb11);
8715+
const float k_norm = float(k_row[0].norm);
8716+
8717+
for (short i = 0; i < DK_PER_LANE; i++) {
8718+
const int d = (int)lane + i * 32;
8719+
if (d >= DK) break;
8720+
8721+
const int qs_byte = d / 4;
8722+
const int qs_shift = (d % 4) * 2;
8723+
const uint8_t q_idx = (k_row[0].qs[qs_byte] >> qs_shift) & 0x03;
8724+
8725+
const int sign_byte = d / 8;
8726+
const int sign_bit = d % 8;
8727+
const uint8_t s_bit = (k_row[0].signs[sign_byte] >> sign_bit) & 1;
8728+
8729+
const uint8_t centroid_idx = q_idx | (s_bit << 2);
8730+
dot_partial += q_vals[i] * k_cb[centroid_idx] * k_norm;
8731+
}
8732+
} else {
8733+
// K is q8_0: 32 elements per block, DK/32 blocks per row.
8734+
device const block_q8_0 * k_row = (device const block_q8_0 *)(k_base + t * args.nb11);
8735+
8736+
for (short i = 0; i < DK_PER_LANE; i++) {
8737+
const int d = (int)lane + i * 32;
8738+
if (d >= DK) break;
8739+
8740+
// Which q8_0 block and offset within it
8741+
const int qb = d / 32; // block index
8742+
const int qj = d % 32; // element within block
8743+
dot_partial += q_vals[i] * (float)k_row[qb].qs[qj] * (float)k_row[qb].d;
8744+
}
87138745
}
87148746
float score = simd_sum(dot_partial) * args.scale + mask_val;
87158747

0 commit comments

Comments (0)