diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3a21c725223..cedd7132cfa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4502,7 +4502,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO2_0], "set_rows_turbo2_0" #itype, set_rows_turbo2_0 ## itype ## _len, set_rows_turbo2_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO4_0], "set_rows_turbo4_0" #itype, set_rows_turbo4_0 ## itype ## _len, set_rows_turbo4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TQ4_1S], "set_rows_tq4_1s" #itype, set_rows_tq4_1s ## itype ## _len, set_rows_tq4_1s ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); SET_ROWS(_i32) @@ -10258,7 +10260,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SET_ROWS: { uint32_t ne = ggml_nelements(src0); - if (dst->type == GGML_TYPE_TURBO3_0) { + if (dst->type == GGML_TYPE_TURBO2_0 || + dst->type == GGML_TYPE_TURBO3_0 || + dst->type == GGML_TYPE_TURBO4_0) { ne = ne / 128; } else if (dst->type == GGML_TYPE_TQ4_1S) { ne = ne / 32; @@ -15653,7 +15657,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_TURBO2_0: case GGML_TYPE_TURBO3_0: + case GGML_TYPE_TURBO4_0: case GGML_TYPE_TQ4_1S: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index f734c75c4a9..480de55fb85 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -5,7 +5,7 @@ #extension GL_KHR_shader_subgroup_shuffle : enable #include "types.glsl" -#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0) +#if defined(SET_ROWS) && (defined(DATA_A_TURBO2_0) || defined(DATA_A_TURBO3_0) || defined(DATA_A_TURBO4_0)) layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; const uint BLOCK_SIZE = 128; #elif defined(SET_ROWS) && QUANT_K == 1 @@ -469,6 +469,245 @@ void main() { data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm); } } +#elif defined(SET_ROWS) && defined(DATA_A_TURBO2_0) +// Mirror of the TURBO3_0 block above, adapted for turbo2 (4 centroids, +// 2-bit pack, no signs byte). WHT tables and reduction structure are +// identical (QK = 128 for both). +const float TS1_T2[128] = float[128]( + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 +); +const float TS2_T2[128] = float[128]( + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 +); +const float TINV_T2 = 0.08838834764831845; // 1 / sqrt(128) +// Lloyd-Max centroids for N(0, 1/128), 4 levels (matches CENTROIDS_2BIT in C ref) +const float TC2[4] = float[4](-0.133462, -0.039994, 0.039994, 0.133462); +// Midpoints between adjacent centroids +const float TM2[3] = float[3](-0.086728, 0.0, 0.086728); + +shared float wht_t2[128]; +shared float sg_acc_t2[16]; +shared float gnrm_t2; + +void main() { + const uint t = gl_LocalInvocationID.x; + const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint gpr = p.ne00 / 128; + + if (gpr == 0) return; + if (g >= p.ne / 128) return; + + uint tmp = g; + const uint ig = tmp % gpr; tmp /= gpr; + const uint i01 = tmp % p.ne01; tmp /= p.ne01; + const uint i02 = tmp % p.ne12; + const uint i03 = tmp / p.ne12; + + const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset(); + const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE; + const uint db = dst_idx(ig, i1, i02, i03) + get_doffset(); + + wht_t2[t] = data_s[sb + t]; + barrier(); + + float v2 = wht_t2[t] * wht_t2[t]; + v2 = subgroupAdd(v2); + if (gl_SubgroupInvocationID == 0) sg_acc_t2[gl_SubgroupID] = v2; + barrier(); + if (t == 0) { + float total = 0.0; + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w]; + gnrm_t2 = sqrt(total); + } + barrier(); + + wht_t2[t] *= (gnrm_t2 > 1e-10) ? (1.0 / gnrm_t2) : 0.0; + barrier(); + + wht_t2[t] *= TS1_T2[t]; + barrier(); + + [[unroll]] for (uint h = 1; h < 128; h *= 2) { + if ((t % (2 * h)) < h) { + float a = wht_t2[t]; + float b = wht_t2[t + h]; + wht_t2[t] = a + b; + wht_t2[t + h] = a - b; + } + barrier(); + } + + float rv = wht_t2[t] * TINV_T2 * TS2_T2[t]; + + // Quantize to nearest of 4 centroids (2-bit index, no signs byte) + uint idx = rv < TM2[0] ? 0u : rv < TM2[1] ? 1u : rv < TM2[2] ? 2u : 3u; + + // Pack qs: 4 elements per byte (full 2-bit each, no high bit) + uint sg_lane = gl_SubgroupInvocationID; + uint qs_byte = 0u; + [[unroll]] for (uint k = 0; k < 4; k++) { + uint contrib = subgroupShuffle(idx & 0x3u, (sg_lane & ~3u) + k); + qs_byte |= contrib << (k * 2u); + } + if (sg_lane % 4u == 0u) { + data_q[db].qs[t / 4u] = uint8_t(qs_byte); + } + + // Reconstruction norm via subgroup reduction + float rc = TC2[idx] * TC2[idx]; + rc = subgroupAdd(rc); + if (sg_lane == 0u) sg_acc_t2[gl_SubgroupID] = rc; + barrier(); + if (t == 0u) { + float total = 0.0; + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w]; + float rn = sqrt(total); + data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t2 / rn) : gnrm_t2); + } +} + +#elif defined(SET_ROWS) && defined(DATA_A_TURBO4_0) +// Mirror of the TURBO3_0 block above, adapted for turbo4 (16 centroids, +// 4-bit nibble pack, no signs byte). WHT tables and reduction structure +// are identical (QK = 128 for both). The block struct keeps a reserved +// rnorm field for ABI parity with the legacy 3-bit + QJL layout. +const float TS1_T4[128] = float[128]( + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 +); +const float TS2_T4[128] = float[128]( + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 +); +const float TINV_T4 = 0.08838834764831845; // 1 / sqrt(128) +// Lloyd-Max centroids for N(0, 1/128), 16 levels (matches CENTROIDS_4BIT in C ref) +const float TC4[16] = float[16]( + -0.173926, -0.117195, -0.089527, -0.068756, + -0.051262, -0.035597, -0.020989, -0.006938, + 0.006938, 0.020989, 0.035597, 0.051262, + 0.068756, 0.089527, 0.117195, 0.173926 +); +// 15 midpoints between adjacent centroids +const float TM4[15] = float[15]( + -0.145561, -0.103361, -0.079142, -0.060009, + -0.043430, -0.028293, -0.013964, 0.0, + 0.013964, 0.028293, 0.043430, 0.060009, + 0.079142, 0.103361, 0.145561 +); + +shared float wht_t4[128]; +shared float sg_acc_t4[16]; +shared float gnrm_t4; + +void main() { + const uint t = gl_LocalInvocationID.x; + const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint gpr = p.ne00 / 128; + + if (gpr == 0) return; + if (g >= p.ne / 128) return; + + uint tmp = g; + const uint ig = tmp % gpr; tmp /= gpr; + const uint i01 = tmp % p.ne01; tmp /= p.ne01; + const uint i02 = tmp % p.ne12; + const uint i03 = tmp / p.ne12; + + const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset(); + const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE; + const uint db = dst_idx(ig, i1, i02, i03) + get_doffset(); + + wht_t4[t] = data_s[sb + t]; + barrier(); + + float v2 = wht_t4[t] * wht_t4[t]; + v2 = subgroupAdd(v2); + if (gl_SubgroupInvocationID == 0) sg_acc_t4[gl_SubgroupID] = v2; + barrier(); + if (t == 0) { + float total = 0.0; + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w]; + gnrm_t4 = sqrt(total); + } + barrier(); + + wht_t4[t] *= (gnrm_t4 > 1e-10) ? (1.0 / gnrm_t4) : 0.0; + barrier(); + + wht_t4[t] *= TS1_T4[t]; + barrier(); + + [[unroll]] for (uint h = 1; h < 128; h *= 2) { + if ((t % (2 * h)) < h) { + float a = wht_t4[t]; + float b = wht_t4[t + h]; + wht_t4[t] = a + b; + wht_t4[t + h] = a - b; + } + barrier(); + } + + float rv = wht_t4[t] * TINV_T4 * TS2_T4[t]; + + // Quantize to nearest of 16 centroids (4-bit index, no signs byte) + uint idx = 0u; + [[unroll]] for (uint i = 0; i < 15; i++) { + if (rv >= TM4[i]) idx = i + 1u; + } + + // Pack qs: 2 elements per byte (4-bit nibble each) + uint sg_lane = gl_SubgroupInvocationID; + uint pair_low = subgroupShuffle(idx & 0xFu, sg_lane & ~1u); + uint pair_high = subgroupShuffle(idx & 0xFu, (sg_lane & ~1u) + 1u); + uint qs_byte = pair_low | (pair_high << 4u); + if (sg_lane % 2u == 0u) { + data_q[db].qs[t / 2u] = uint8_t(qs_byte); + } + + // Reset rnorm field (reserved in 4-bit mode) + if (t == 0u) { + data_q[db].rnorm = float16_t(0.0); + } + + // Reconstruction norm via subgroup reduction + float rc = TC4[idx] * TC4[idx]; + rc = subgroupAdd(rc); + if (sg_lane == 0u) sg_acc_t4[gl_SubgroupID] = rc; + barrier(); + if (t == 0u) { + float total = 0.0; + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w]; + float rn = sqrt(total); + data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t4 / rn) : gnrm_t4); + } +} + #elif defined(SET_ROWS) && defined(DATA_A_TQ4_1S) void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 10f079d2e42..c386d300841 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1747,6 +1747,36 @@ struct block_turbo3_0 #define A_TYPE block_turbo3_0 #endif +#define QUANT_K_TURBO2_0 128 +#define QUANT_R_TURBO2_0 1 +struct block_turbo2_0 +{ + float16_t norm; + uint8_t qs[32]; // 2-bit centroid indices (4 per byte), 128/4 = 32 bytes +}; +#if defined(DATA_A_TURBO2_0) +#define QUANT_K QUANT_K_TURBO2_0 +#define QUANT_R QUANT_R_TURBO2_0 +#define QUANT_AUXF 1 +#define A_TYPE block_turbo2_0 +#endif + +#define QUANT_K_TURBO4_0 128 +#define QUANT_R_TURBO4_0 1 +struct block_turbo4_0 +{ + float16_t norm; + float16_t rnorm; // reserved in 4-bit mode (kept for ABI parity with legacy) + uint8_t qs[64]; // 4-bit centroid indices, nibble-packed (2 per byte), 128/2 = 64 bytes +}; +#if defined(DATA_A_TURBO4_0) +#define QUANT_K QUANT_K_TURBO4_0 +#define QUANT_R QUANT_R_TURBO4_0 +#define QUANT_AUXF 1 +#define A_TYPE block_turbo4_0 +#endif + + #define QUANT_K_TQ4_1S 32 #define QUANT_R_TQ4_1S 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a7a2fc70f53..e3e7952b93a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -786,7 +786,7 @@ void process_shaders() { // tq4_1s copy-from-quant only; copy-to-quant requires WHT forward (handled in SET_ROWS path) string_to_spv("cpy_tq4_1s_f32", "copy_from_quant.comp", {{"DATA_A_TQ4_1S", "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0", "tq4_1s"}) { + for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo2_0", "turbo3_0", "turbo4_0", "tq4_1s"}) { string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); }