
Commit 9e7cbb4

TheTom and claude authored and committed
vulkan: add SET_ROWS support for turbo2_0 and turbo4_0 (ggml-org#50)
Mirror of @apollosenvy's turbo3_0 Vulkan SET_ROWS port (PR ggml-org#33 + ggml-org#87) to the other two turbo types.

Reported by @dpblnt in ggml-org#50 with a clean matrix on RX 9060 XT showing turbo3 V works on Vulkan but turbo2/turbo4 V abort with:

    pre-allocated tensor (cache_v_l*) in a buffer (Vulkan0) that cannot run the operation (SET_ROWS)

at llama_context::sched_reserve() time, before any compute runs.

Mechanical port across 4 files:

- vulkan-shaders/types.glsl: block_turbo2_0 + block_turbo4_0 struct declarations matching the C side (ggml-common.h).
- vulkan-shaders/copy_to_quant.comp: SET_ROWS quantize main() blocks for turbo2 (4 centroids, 2-bit pack, no signs byte) and turbo4 (16 centroids, 4-bit nibble pack, no signs byte). WHT setup and reduction structure identical to turbo3 (QK = 128 across all three). Centroid + midpoint tables ported from CENTROIDS_2BIT and CENTROIDS_4BIT in ggml-turbo-quant.c.
- vulkan-shaders/vulkan-shaders-gen.cpp: turbo2_0 and turbo4_0 added to the set_rows iteration list at line ~789.
- ggml-vulkan.cpp: SET_ROWS pipeline registrations + supports_op switch + dispatch element-count all extended with TURBO2_0 and TURBO4_0 cases.

## Verified on llvmpipe Vulkan (CPU software, AMD MI300X cloud droplet)

Patched ggml-vulkan.cpp temporarily during repro to allow llvmpipe (normally filtered out as eCpu); patch reverted before commit. The SET_ROWS abort is a backend-capability check at graph build time so it fires regardless of GPU vs CPU Vulkan backend.

| ctk / ctv     | tg16 (t/s) | status         |
|---------------|-----------:|----------------|
| q4_0 / q4_0   |      17.68 | baseline       |
| q4_0 / turbo3 |       5.91 | already worked |
| q4_0 / turbo4 |       6.14 | was aborting   |
| q4_0 / turbo2 |       5.65 | was aborting   |

llvmpipe perf numbers are not meaningful (CPU-emulated Vulkan); they are reported here only to confirm the abort is gone and the kernels run end-to-end without divergence.

## Needs GPU validation

Cannot validate GPU shader correctness on the droplet (MI300X SR-IOV VF does not expose itself to RADV/amdvlk on cloud). Specifically:

- Subgroup shuffle / ballot behavior on real GPU subgroup sizes
- Shader compilation under non-llvmpipe Vulkan drivers
- PPL / quality on the actual quantization math

@dpblnt @apollosenvy if either of you has cycles, would appreciate a quick rebuild on RDNA Vulkan (gfx1100/gfx1200) to confirm:

1. The SET_ROWS abort that triggered ggml-org#50 is gone
2. Output coherence on turbo4 V (not garbage tokens)
3. PPL stays in the expected ballpark vs the CUDA / Metal implementations of the same quants

Closes ggml-org#50.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 89be49e commit 9e7cbb4

4 files changed

Lines changed: 278 additions & 3 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 1 deletion
@@ -4566,7 +4566,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
         ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
         ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO2_0], "set_rows_turbo2_0" #itype, set_rows_turbo2_0 ## itype ## _len, set_rows_turbo2_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
         ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO4_0], "set_rows_turbo4_0" #itype, set_rows_turbo4_0 ## itype ## _len, set_rows_turbo4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
         ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TQ4_1S], "set_rows_tq4_1s" #itype, set_rows_tq4_1s ## itype ## _len, set_rows_tq4_1s ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
 
     SET_ROWS(_i32)
@@ -10360,7 +10362,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SET_ROWS:
         {
             uint32_t ne = ggml_nelements(src0);
-            if (dst->type == GGML_TYPE_TURBO3_0) {
+            if (dst->type == GGML_TYPE_TURBO2_0 ||
+                dst->type == GGML_TYPE_TURBO3_0 ||
+                dst->type == GGML_TYPE_TURBO4_0) {
                 ne = ne / 128;
             } else if (dst->type == GGML_TYPE_TQ4_1S) {
                 ne = ne / 32;
@@ -15830,7 +15834,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_Q8_0:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_TURBO2_0:
                     case GGML_TYPE_TURBO3_0:
+                    case GGML_TYPE_TURBO4_0:
                     case GGML_TYPE_TQ4_1S:
                         return true;
                     default:
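
The dispatch change in this file means one workgroup handles one 128-element block for all three turbo types, and the shader rebuilds its linear block index as `z * 262144 + y * 512 + x`. The sketch below works through that arithmetic on the host. All names (`QuantKind`, `set_rows_groups`, `split_group_index`, `join_group_index`) are hypothetical, the fallback branch is a placeholder because the hunk does not show it, and the 512 x 512 grid split is inferred from the shader expression rather than from host code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the ggml type tags touched by this commit.
enum class QuantKind { Turbo2_0, Turbo3_0, Turbo4_0, Tq4_1s, Other };

// Work-item count for SET_ROWS with `ne` source elements, following the
// branch shown in the hunk: turbo types use 128-element blocks and
// tq4_1s uses 32-element blocks.
inline uint32_t set_rows_groups(QuantKind kind, uint32_t ne) {
    switch (kind) {
        case QuantKind::Turbo2_0:
        case QuantKind::Turbo3_0:
        case QuantKind::Turbo4_0: return ne / 128u;
        case QuantKind::Tq4_1s:   return ne / 32u;
        // Placeholder: the real fallback is outside the shown hunk.
        default:                  return ne;
    }
}

struct GroupId { uint32_t x, y, z; };

// Assumed split of a linear block index over a 512 x 512 x N grid.
inline GroupId split_group_index(uint32_t g) {
    return { g % 512u, (g / 512u) % 512u, g / 262144u };
}

// Recombination as written in the shader: z * 262144 + y * 512 + x.
inline uint32_t join_group_index(GroupId id) {
    return id.z * 262144u + id.y * 512u + id.x;
}
```

For example, a 1024-element source maps to 8 workgroups for turbo2_0 and 32 for tq4_1s, and splitting then joining a block index returns the original value.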

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 240 additions & 1 deletion
@@ -5,7 +5,7 @@
 #extension GL_KHR_shader_subgroup_shuffle : enable
 #include "types.glsl"
 
-#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
+#if defined(SET_ROWS) && (defined(DATA_A_TURBO2_0) || defined(DATA_A_TURBO3_0) || defined(DATA_A_TURBO4_0))
 layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
 const uint BLOCK_SIZE = 128;
 #elif defined(SET_ROWS) && QUANT_K == 1
@@ -469,6 +469,245 @@ void main() {
         data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm);
     }
 }
+#elif defined(SET_ROWS) && defined(DATA_A_TURBO2_0)
+// Mirror of the TURBO3_0 block above, adapted for turbo2 (4 centroids,
+// 2-bit pack, no signs byte). WHT tables and reduction structure are
+// identical (QK = 128 for both).
+const float TS1_T2[128] = float[128](
+    -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1,
+    1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1,
+    -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1,
+    1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1,
+    -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1,
+    1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1,
+    -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1,
+    1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1
+);
+const float TS2_T2[128] = float[128](
+    1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1,
+    1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1,
+    1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1,
+    1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1,
+    1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1,
+    -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1,
+    1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1,
+    -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1
+);
+const float TINV_T2 = 0.08838834764831845; // 1 / sqrt(128)
+// Lloyd-Max centroids for N(0, 1/128), 4 levels (matches CENTROIDS_2BIT in C ref)
+const float TC2[4] = float[4](-0.133462, -0.039994, 0.039994, 0.133462);
+// Midpoints between adjacent centroids
+const float TM2[3] = float[3](-0.086728, 0.0, 0.086728);
+
+shared float wht_t2[128];
+shared float sg_acc_t2[16];
+shared float gnrm_t2;
+
+void main() {
+    const uint t = gl_LocalInvocationID.x;
+    const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint gpr = p.ne00 / 128;
+
+    if (gpr == 0) return;
+    if (g >= p.ne / 128) return;
+
+    uint tmp = g;
+    const uint ig = tmp % gpr; tmp /= gpr;
+    const uint i01 = tmp % p.ne01; tmp /= p.ne01;
+    const uint i02 = tmp % p.ne12;
+    const uint i03 = tmp / p.ne12;
+
+    const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset();
+    const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE;
+    const uint db = dst_idx(ig, i1, i02, i03) + get_doffset();
+
+    wht_t2[t] = data_s[sb + t];
+    barrier();
+
+    float v2 = wht_t2[t] * wht_t2[t];
+    v2 = subgroupAdd(v2);
+    if (gl_SubgroupInvocationID == 0) sg_acc_t2[gl_SubgroupID] = v2;
+    barrier();
+    if (t == 0) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w];
+        gnrm_t2 = sqrt(total);
+    }
+    barrier();
+
+    wht_t2[t] *= (gnrm_t2 > 1e-10) ? (1.0 / gnrm_t2) : 0.0;
+    barrier();
+
+    wht_t2[t] *= TS1_T2[t];
+    barrier();
+
+    [[unroll]] for (uint h = 1; h < 128; h *= 2) {
+        if ((t % (2 * h)) < h) {
+            float a = wht_t2[t];
+            float b = wht_t2[t + h];
+            wht_t2[t] = a + b;
+            wht_t2[t + h] = a - b;
+        }
+        barrier();
+    }
+
+    float rv = wht_t2[t] * TINV_T2 * TS2_T2[t];
+
+    // Quantize to nearest of 4 centroids (2-bit index, no signs byte)
+    uint idx = rv < TM2[0] ? 0u : rv < TM2[1] ? 1u : rv < TM2[2] ? 2u : 3u;
+
+    // Pack qs: 4 elements per byte (full 2-bit each, no high bit)
+    uint sg_lane = gl_SubgroupInvocationID;
+    uint qs_byte = 0u;
+    [[unroll]] for (uint k = 0; k < 4; k++) {
+        uint contrib = subgroupShuffle(idx & 0x3u, (sg_lane & ~3u) + k);
+        qs_byte |= contrib << (k * 2u);
+    }
+    if (sg_lane % 4u == 0u) {
+        data_q[db].qs[t / 4u] = uint8_t(qs_byte);
+    }
+
+    // Reconstruction norm via subgroup reduction
+    float rc = TC2[idx] * TC2[idx];
+    rc = subgroupAdd(rc);
+    if (sg_lane == 0u) sg_acc_t2[gl_SubgroupID] = rc;
+    barrier();
+    if (t == 0u) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w];
+        float rn = sqrt(total);
+        data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t2 / rn) : gnrm_t2);
+    }
+}
+
+#elif defined(SET_ROWS) && defined(DATA_A_TURBO4_0)
+// Mirror of the TURBO3_0 block above, adapted for turbo4 (16 centroids,
+// 4-bit nibble pack, no signs byte). WHT tables and reduction structure
+// are identical (QK = 128 for both). The block struct keeps a reserved
+// rnorm field for ABI parity with the legacy 3-bit + QJL layout.
+const float TS1_T4[128] = float[128](
+    -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1,
+    1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1,
+    -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1,
+    1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1,
+    -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1,
+    1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1,
+    -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1,
+    1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1
+);
+const float TS2_T4[128] = float[128](
+    1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1,
+    1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1,
+    1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1,
+    1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1,
+    1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1,
+    -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1,
+    1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1,
+    -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1
+);
+const float TINV_T4 = 0.08838834764831845; // 1 / sqrt(128)
+// Lloyd-Max centroids for N(0, 1/128), 16 levels (matches CENTROIDS_4BIT in C ref)
+const float TC4[16] = float[16](
+    -0.173926, -0.117195, -0.089527, -0.068756,
+    -0.051262, -0.035597, -0.020989, -0.006938,
+    0.006938, 0.020989, 0.035597, 0.051262,
+    0.068756, 0.089527, 0.117195, 0.173926
+);
+// 15 midpoints between adjacent centroids
+const float TM4[15] = float[15](
+    -0.145561, -0.103361, -0.079142, -0.060009,
+    -0.043430, -0.028293, -0.013964, 0.0,
+    0.013964, 0.028293, 0.043430, 0.060009,
+    0.079142, 0.103361, 0.145561
+);
+
+shared float wht_t4[128];
+shared float sg_acc_t4[16];
+shared float gnrm_t4;
+
+void main() {
+    const uint t = gl_LocalInvocationID.x;
+    const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint gpr = p.ne00 / 128;
+
+    if (gpr == 0) return;
+    if (g >= p.ne / 128) return;
+
+    uint tmp = g;
+    const uint ig = tmp % gpr; tmp /= gpr;
+    const uint i01 = tmp % p.ne01; tmp /= p.ne01;
+    const uint i02 = tmp % p.ne12;
+    const uint i03 = tmp / p.ne12;
+
+    const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset();
+    const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE;
+    const uint db = dst_idx(ig, i1, i02, i03) + get_doffset();
+
+    wht_t4[t] = data_s[sb + t];
+    barrier();
+
+    float v2 = wht_t4[t] * wht_t4[t];
+    v2 = subgroupAdd(v2);
+    if (gl_SubgroupInvocationID == 0) sg_acc_t4[gl_SubgroupID] = v2;
+    barrier();
+    if (t == 0) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w];
+        gnrm_t4 = sqrt(total);
+    }
+    barrier();
+
+    wht_t4[t] *= (gnrm_t4 > 1e-10) ? (1.0 / gnrm_t4) : 0.0;
+    barrier();
+
+    wht_t4[t] *= TS1_T4[t];
+    barrier();
+
+    [[unroll]] for (uint h = 1; h < 128; h *= 2) {
+        if ((t % (2 * h)) < h) {
+            float a = wht_t4[t];
+            float b = wht_t4[t + h];
+            wht_t4[t] = a + b;
+            wht_t4[t + h] = a - b;
+        }
+        barrier();
+    }
+
+    float rv = wht_t4[t] * TINV_T4 * TS2_T4[t];
+
+    // Quantize to nearest of 16 centroids (4-bit index, no signs byte)
+    uint idx = 0u;
+    [[unroll]] for (uint i = 0; i < 15; i++) {
+        if (rv >= TM4[i]) idx = i + 1u;
+    }
+
+    // Pack qs: 2 elements per byte (4-bit nibble each)
+    uint sg_lane = gl_SubgroupInvocationID;
+    uint pair_low = subgroupShuffle(idx & 0xFu, sg_lane & ~1u);
+    uint pair_high = subgroupShuffle(idx & 0xFu, (sg_lane & ~1u) + 1u);
+    uint qs_byte = pair_low | (pair_high << 4u);
+    if (sg_lane % 2u == 0u) {
+        data_q[db].qs[t / 2u] = uint8_t(qs_byte);
+    }
+
+    // Reset rnorm field (reserved in 4-bit mode)
+    if (t == 0u) {
+        data_q[db].rnorm = float16_t(0.0);
+    }
+
+    // Reconstruction norm via subgroup reduction
+    float rc = TC4[idx] * TC4[idx];
+    rc = subgroupAdd(rc);
+    if (sg_lane == 0u) sg_acc_t4[gl_SubgroupID] = rc;
+    barrier();
+    if (t == 0u) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w];
+        float rn = sqrt(total);
+        data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t4 / rn) : gnrm_t4);
+    }
+}
+
 #elif defined(SET_ROWS) && defined(DATA_A_TQ4_1S)
 
 void main() {
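
Since the commit notes that GPU correctness is still unvalidated, a scalar CPU transcription of the turbo2 shader path may be useful for comparing outputs. The following is a minimal sketch, not the reference in ggml-turbo-quant.c. The names `Turbo2Ref`, `wht128`, `nearest_centroid` and `quantize_block` are hypothetical, the two sign tables are passed in as parameters instead of being copied here, and the packing assumes lane `t % 4` lines up with `gl_SubgroupInvocationID % 4` (subgroup size a multiple of 4 with linear lane mapping).

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

constexpr int   kBlock        = 128;
constexpr float kInvSqrtBlock = 0.08838834764831845f;  // 1 / sqrt(128)
// Centroid and midpoint values copied from TC2 / TM2 in the shader.
constexpr std::array<float, 4> kCentroids = {-0.133462f, -0.039994f, 0.039994f, 0.133462f};
constexpr std::array<float, 3> kMidpoints = {-0.086728f, 0.0f, 0.086728f};

// Hypothetical host-side result; the shader stores norm as float16_t.
struct Turbo2Ref {
    float norm;
    std::array<uint8_t, 32> qs;
};

// Unnormalized Walsh-Hadamard transform. Pairs (t, t + h) are disjoint
// within a stage, so this serial loop matches the barrier-separated
// parallel butterflies in the shader.
inline void wht128(std::array<float, kBlock>& v) {
    for (int h = 1; h < kBlock; h *= 2) {
        for (int t = 0; t < kBlock; ++t) {
            if ((t % (2 * h)) < h) {
                const float a = v[t];
                const float b = v[t + h];
                v[t]     = a + b;
                v[t + h] = a - b;
            }
        }
    }
}

// Same midpoint comparison chain as the shader's 2-bit quantizer.
inline uint32_t nearest_centroid(float rv) {
    return rv < kMidpoints[0] ? 0u : rv < kMidpoints[1] ? 1u : rv < kMidpoints[2] ? 2u : 3u;
}

inline Turbo2Ref quantize_block(const std::array<float, kBlock>& src,
                                const std::array<float, kBlock>& signs1,
                                const std::array<float, kBlock>& signs2) {
    float sum_sq = 0.0f;
    for (float x : src) sum_sq += x * x;
    const float gnrm = std::sqrt(sum_sq);
    const float inv  = (gnrm > 1e-10f) ? (1.0f / gnrm) : 0.0f;

    std::array<float, kBlock> v{};
    for (int t = 0; t < kBlock; ++t) v[t] = src[t] * inv * signs1[t];
    wht128(v);

    Turbo2Ref out{};  // qs starts zeroed so the ORs below are safe
    float recon_sq = 0.0f;
    for (int t = 0; t < kBlock; ++t) {
        const float rv     = v[t] * kInvSqrtBlock * signs2[t];
        const uint32_t idx = nearest_centroid(rv);
        assert(idx < 4u);
        // Element t lands in bits (t % 4) * 2 of byte t / 4.
        out.qs[t / 4] |= static_cast<uint8_t>(idx << ((t % 4) * 2));
        recon_sq += kCentroids[idx] * kCentroids[idx];
    }
    const float rn = std::sqrt(recon_sq);
    out.norm = (rn > 1e-10f) ? (gnrm / rn) : gnrm;
    return out;
}
```

With all-ones sign tables, a unit impulse at index 0 transforms to a constant vector of 1/sqrt(128), which sits just above the top midpoint, so every element quantizes to centroid 3 and every packed byte is 0xFF.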

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 30 additions & 0 deletions
@@ -1747,6 +1747,36 @@ struct block_turbo3_0
 #define A_TYPE block_turbo3_0
 #endif
 
+#define QUANT_K_TURBO2_0 128
+#define QUANT_R_TURBO2_0 1
+struct block_turbo2_0
+{
+    float16_t norm;
+    uint8_t qs[32];    // 2-bit centroid indices (4 per byte), 128/4 = 32 bytes
+};
+#if defined(DATA_A_TURBO2_0)
+#define QUANT_K QUANT_K_TURBO2_0
+#define QUANT_R QUANT_R_TURBO2_0
+#define QUANT_AUXF 1
+#define A_TYPE block_turbo2_0
+#endif
+
+#define QUANT_K_TURBO4_0 128
+#define QUANT_R_TURBO4_0 1
+struct block_turbo4_0
+{
+    float16_t norm;
+    float16_t rnorm;   // reserved in 4-bit mode (kept for ABI parity with legacy)
+    uint8_t qs[64];    // 4-bit centroid indices, nibble-packed (2 per byte), 128/2 = 64 bytes
+};
+#if defined(DATA_A_TURBO4_0)
+#define QUANT_K QUANT_K_TURBO4_0
+#define QUANT_R QUANT_R_TURBO4_0
+#define QUANT_AUXF 1
+#define A_TYPE block_turbo4_0
+#endif
+
+
 #define QUANT_K_TQ4_1S 32
 #define QUANT_R_TQ4_1S 1
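
The struct comments above give the byte budgets (128/4 = 32 and 128/2 = 64), which a host-side mirror can check at compile time. This is a sketch under assumptions: `BlockTurbo2`, `BlockTurbo4`, `store_nibble` and `load_nibble` are hypothetical names, `float16_t` is represented as raw `uint16_t` bits, and the real C-side definitions in ggml-common.h are not reproduced here. The nibble helpers follow the shader's order: even element in the low nibble, odd element in the high nibble.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host mirrors of the GLSL structs (half stored as raw bits).
struct BlockTurbo2 {
    uint16_t norm;
    uint8_t  qs[32];  // 128 elements * 2 bits / 8
};
struct BlockTurbo4 {
    uint16_t norm;
    uint16_t rnorm;   // reserved in 4-bit mode
    uint8_t  qs[64];  // 128 elements * 4 bits / 8
};

// 34 bytes per 128 elements is 2.125 bits per element;
// 68 bytes per 128 elements is 4.25 bits per element.
static_assert(sizeof(BlockTurbo2) == 34, "unexpected padding in BlockTurbo2");
static_assert(sizeof(BlockTurbo4) == 68, "unexpected padding in BlockTurbo4");
static_assert(offsetof(BlockTurbo4, qs) == 4, "qs should follow norm and rnorm");

// Write the 4-bit index of element i without disturbing its neighbour.
inline void store_nibble(BlockTurbo4& b, unsigned i, uint8_t idx) {
    assert(i < 128 && idx < 16);
    uint8_t& byte = b.qs[i / 2];
    if (i % 2 == 0) {
        byte = static_cast<uint8_t>((byte & 0xF0u) | idx);
    } else {
        byte = static_cast<uint8_t>((byte & 0x0Fu) | (idx << 4));
    }
}

// Read back the 4-bit index of element i.
inline uint8_t load_nibble(const BlockTurbo4& b, unsigned i) {
    assert(i < 128);
    const uint8_t byte = b.qs[i / 2];
    return static_cast<uint8_t>((i % 2 == 0) ? (byte & 0x0Fu) : (byte >> 4));
}
```

Storing index 5 at element 0 and index 12 at element 1 yields the byte 0xC5, matching `pair_low | (pair_high << 4u)` in the shader.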

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
@@ -783,7 +783,7 @@ void process_shaders() {
     // tq4_1s copy-from-quant only; copy-to-quant requires WHT forward (handled in SET_ROWS path)
     string_to_spv("cpy_tq4_1s_f32", "copy_from_quant.comp", {{"DATA_A_TQ4_1S", "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
-    for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0", "tq4_1s"}) {
+    for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo2_0", "turbo3_0", "turbo4_0", "tq4_1s"}) {
         string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
         string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
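
Adding a type string to this loop is what produces the `set_rows_turbo2_0_i32`-style symbols that the pipeline registrations in ggml-vulkan.cpp expect, along with the `DATA_A_*` define that selects the shader branch. The naming scheme can be sketched in isolation as follows. `Variant`, `upper` and `set_rows_variants` are hypothetical helpers written for illustration, and `upper` only stands in for the generator's own `to_uppercase`.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

// Hypothetical record of one generated shader variant.
struct Variant {
    std::string name;    // e.g. "set_rows_turbo2_0_i32"
    std::string define;  // e.g. "DATA_A_TURBO2_0"
};

// Stand-in for the generator's to_uppercase helper.
inline std::string upper(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return s;
}

// Each type yields an _i32 and an _i64 variant, mirroring the two
// string_to_spv calls in the loop body.
inline std::vector<Variant> set_rows_variants(const std::vector<std::string>& types) {
    std::vector<Variant> out;
    for (const std::string& t : types) {
        for (const char* suffix : {"_i32", "_i64"}) {
            out.push_back({"set_rows_" + t + suffix, "DATA_A_" + upper(t)});
        }
    }
    return out;
}
```

Calling `set_rows_variants({"turbo2_0", "turbo4_0"})` gives four entries, one per index width for each new type.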
