Skip to content

Commit 8590cbf

Browse files
authored
Merge pull request #62 from Titaniumtown/pr/turboquant-vulkan-work
vulkan: fix and complete turbo3 KV cache support
2 parents 157cb85 + 6a29b58 commit 8590cbf

9 files changed

Lines changed: 488 additions & 93 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,7 @@ struct vk_device_struct {
824824
vk_pipeline pipeline_timestep_embedding_f32;
825825
vk_pipeline pipeline_conv_transpose_1d_f32;
826826
vk_pipeline pipeline_pool2d_f32;
827+
vk_pipeline pipeline_turbo_wht;
827828
vk_pipeline pipeline_rwkv_wkv6_f32;
828829
vk_pipeline pipeline_rwkv_wkv7_f32;
829830
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
@@ -3447,11 +3448,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
34473448
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
34483449
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
34493450
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3451+
CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, )
34503452
} else {
34513453
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
34523454
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
34533455
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
34543456
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
3457+
CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, _fp32)
34553458
}
34563459
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
34573460
if (device->coopmat1_fa_support) {
@@ -4187,7 +4190,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
41874190
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
41884191
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
41894192
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4190-
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4193+
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
4194+
4195+
// TurboQuant WHT
4196+
ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
41914197

41924198
// get_rows
41934199
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -4307,15 +4313,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
43074313
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43084314
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43094315
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4310-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43114316
} else {
43124317
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43134318
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43144319
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43154320
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43164321
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43174322
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4318-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
43194323
}
43204324

43214325
#define SET_ROWS(itype, rte) \
@@ -7278,6 +7282,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
72787282
case GGML_TYPE_Q5_1:
72797283
case GGML_TYPE_Q8_0:
72807284
case GGML_TYPE_IQ4_NL:
7285+
case GGML_TYPE_TURBO3_0:
72817286
return ctx->device->pipeline_cpy_quant_f32[src->type];
72827287
default:
72837288
break;
@@ -10063,7 +10068,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1006310068
case GGML_OP_SET_ROWS:
1006410069
{
1006510070
uint32_t ne = ggml_nelements(src0);
10066-
if (ggml_is_quantized(dst->type)) {
10071+
if (dst->type == GGML_TYPE_TURBO3_0) {
10072+
ne = ne / 128;
10073+
} else if (ggml_is_quantized(dst->type)) {
1006710074
// quants run 32 threads each doing QUANT_K elements
1006810075
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
1006910076
} else {
@@ -10834,6 +10841,32 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
1083410841
});
1083510842
}
1083610843

10844+
static void ggml_vk_turbo_wht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10845+
int direction, group_size;
10846+
memcpy(&direction, dst->op_params + 0, sizeof(int));
10847+
memcpy(&group_size, dst->op_params + sizeof(int), sizeof(int));
10848+
struct { uint32_t ne; uint32_t direction; uint32_t group_size; } pc = {
10849+
(uint32_t)ggml_nelements(src0), (uint32_t)direction, (uint32_t)group_size,
10850+
};
10851+
vk_pipeline pipeline = ctx->device->pipeline_turbo_wht;
10852+
GGML_ASSERT(pipeline != nullptr);
10853+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10854+
vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0, false);
10855+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
10856+
// Spread workgroups across Y/Z to stay within maxComputeWorkGroupCount[0].
10857+
// elements[0] / group_size = wg0; each row of 512 workgroups uses one Y slice.
10858+
const uint32_t n_groups = pc.ne / (uint32_t)group_size;
10859+
std::array<uint32_t, 3> elements;
10860+
if (n_groups > 262144) {
10861+
elements = { 512 * (uint32_t)group_size, 512, CEIL_DIV(n_groups, 262144) };
10862+
} else if (n_groups > 512) {
10863+
elements = { 512 * (uint32_t)group_size, CEIL_DIV(n_groups, 512), 1 };
10864+
} else {
10865+
elements = { pc.ne, 1, 1 };
10866+
}
10867+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, elements);
10868+
}
10869+
1083710870
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1083810871
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
1083910872
}
@@ -13015,6 +13048,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1301513048
case GGML_OP_SET_ROWS:
1301613049
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);
1301713050

13051+
break;
13052+
case GGML_OP_TURBO_WHT:
13053+
ggml_vk_turbo_wht(ctx, compute_ctx, src0, node);
13054+
1301813055
break;
1301913056
case GGML_OP_SILU_BACK:
1302013057
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
@@ -15338,6 +15375,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1533815375
case GGML_TYPE_F32:
1533915376
case GGML_TYPE_Q4_0:
1534015377
case GGML_TYPE_Q8_0:
15378+
case GGML_TYPE_TURBO3_0:
1534115379
// supported in scalar and coopmat2 paths
1534215380
break;
1534315381
case GGML_TYPE_Q4_1:
@@ -15441,7 +15479,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1544115479
case GGML_TYPE_Q5_1:
1544215480
case GGML_TYPE_Q8_0:
1544315481
case GGML_TYPE_IQ4_NL:
15444-
case GGML_TYPE_TURBO3_0:
1544515482
return true;
1544615483
default:
1544715484
break;
@@ -15710,6 +15747,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1571015747
ggml_is_contiguous(op->src[1]) &&
1571115748
ggml_is_contiguous(op));
1571215749
}
15750+
case GGML_OP_TURBO_WHT:
15751+
return op->src[0]->type == GGML_TYPE_F32 && op->src[0]->ne[0] % 128 == 0;
1571315752
default:
1571415753
return false;
1571515754
}

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 157 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
#version 450
22

3+
#extension GL_KHR_shader_subgroup_arithmetic : enable
4+
#extension GL_KHR_shader_subgroup_ballot : enable
5+
#extension GL_KHR_shader_subgroup_shuffle : enable
36
#include "rte.glsl"
47
#include "types.glsl"
58

6-
#if defined(SET_ROWS) && QUANT_K == 1
9+
#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
10+
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
11+
const uint BLOCK_SIZE = 128;
12+
#elif defined(SET_ROWS) && QUANT_K == 1
713
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
814
const uint BLOCK_SIZE = 512;
915
#else
@@ -185,62 +191,67 @@ void quantize(uint dst_idx, uint src_idx)
185191
#endif
186192

187193
#if defined(DATA_A_TURBO3_0)
188-
void quantize(uint dst_idx, uint src_idx)
189-
{
190-
const float centroids[8] = float[8](
191-
-0.190685, -0.117832, -0.065717, -0.021460,
192-
0.021460, 0.065717, 0.117832, 0.190685
193-
);
194-
const float midpoints[7] = float[7](
195-
-0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259
196-
);
197-
198-
// Compute L2 norm
199-
float norm_sq = 0.0;
200-
[[unroll]] for (int j = 0; j < 32; ++j) {
201-
float v = data_s[src_idx + j];
202-
norm_sq += v * v;
203-
}
204-
float norm = sqrt(norm_sq);
205-
float inv_norm = (norm > 1e-10) ? (1.0 / norm) : 0.0;
206-
207-
// Clear output
208-
[[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0);
209-
[[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0);
210-
211-
// Accumulate centroid reconstruction norm for correction
212-
float recon_norm_sq = 0.0;
213-
214-
// Quantize each element
215-
[[unroll]] for (int j = 0; j < 32; ++j) {
216-
float val = data_s[src_idx + j] * inv_norm;
217-
218-
// Find nearest centroid via midpoint comparison
219-
uint idx = 0;
220-
if (val < midpoints[0]) idx = 0;
221-
else if (val < midpoints[1]) idx = 1;
222-
else if (val < midpoints[2]) idx = 2;
223-
else if (val < midpoints[3]) idx = 3;
224-
else if (val < midpoints[4]) idx = 4;
225-
else if (val < midpoints[5]) idx = 5;
226-
else if (val < midpoints[6]) idx = 6;
227-
else idx = 7;
228-
229-
recon_norm_sq += centroids[idx] * centroids[idx];
230-
231-
// Pack: low 2 bits to qs, high 1 bit to signs
232-
uint low2 = idx & 0x3;
233-
uint hi1 = (idx >> 2) & 0x1;
234-
data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2));
235-
data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8));
236-
}
194+
const float TS1[128] = float[128](
195+
-1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1,
196+
1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1,
197+
-1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1,
198+
1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1,
199+
-1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1,
200+
1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1,
201+
-1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1,
202+
1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1
203+
);
204+
205+
const float TS2[128] = float[128](
206+
1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1,
207+
1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1,
208+
1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1,
209+
1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1,
210+
1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1,
211+
-1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1,
212+
1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1,
213+
-1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1
214+
);
215+
216+
const float TINV = 0.08838834764831845; // 1 / sqrt(128)
217+
218+
const float TC[8] = float[8](
219+
-0.190685, -0.117832, -0.065717, -0.021460,
220+
0.021460, 0.065717, 0.117832, 0.190685
221+
);
222+
223+
const float TM[7] = float[7](
224+
-0.154259, -0.091775, -0.043589,
225+
0.0,
226+
0.043589, 0.091775, 0.154259
227+
);
237228

238-
// Norm correction: scale so reconstruction matches original norm
239-
float recon_norm = sqrt(recon_norm_sq);
240-
float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm;
241-
data_q[dst_idx].norm = float16_t(corrected_norm);
229+
#if defined(SET_ROWS)
230+
231+
shared float wht[128];
232+
shared float sg_acc[16];
233+
shared float gnrm;
234+
235+
void quantize_block(uint b, uint o) {
236+
[[unroll]] for (int j = 0; j < 32; ++j) data_q[b].qs[j] = uint8_t(0);
237+
[[unroll]] for (int j = 0; j < 16; ++j) data_q[b].signs[j] = uint8_t(0);
238+
float rs = 0.0;
239+
[[unroll]] for (int j = 0; j < 128; ++j) {
240+
float v = wht[o + j];
241+
uint i = v < TM[0] ? 0 : v < TM[1] ? 1 : v < TM[2] ? 2 : v < TM[3] ? 3 :
242+
v < TM[4] ? 4 : v < TM[5] ? 5 : v < TM[6] ? 6 : 7;
243+
rs += TC[i] * TC[i];
244+
uint low2 = i & 0x3;
245+
uint hi1 = (i >> 2) & 0x1;
246+
data_q[b].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2));
247+
data_q[b].signs[j / 8] |= uint8_t(hi1 << (j % 8));
248+
}
249+
float rn = sqrt(rs);
250+
data_q[b].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm);
242251
}
243-
#endif
252+
253+
#endif // defined(SET_ROWS)
254+
#endif // defined(DATA_A_TURBO3_0)
244255

245256
#if defined(DATA_A_IQ4_NL)
246257
uint best_index(float x) {
@@ -304,7 +315,97 @@ void quantize(uint dst_idx, uint src_idx)
304315
}
305316
#endif
306317

307-
#if defined(SET_ROWS)
318+
#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
319+
void main() {
320+
const uint t = gl_LocalInvocationID.x;
321+
const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
322+
const uint gpr = p.ne00 / 128;
323+
324+
if (gpr == 0) return;
325+
if (g >= p.ne / 128) return;
326+
327+
uint tmp = g;
328+
const uint ig = tmp % gpr; tmp /= gpr;
329+
const uint i01 = tmp % p.ne01; tmp /= p.ne01;
330+
const uint i02 = tmp % p.ne12;
331+
const uint i03 = tmp / p.ne12;
332+
333+
const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset();
334+
const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE;
335+
const uint db = dst_idx(ig, i1, i02, i03) + get_doffset();
336+
337+
// Step 1: load into shared memory
338+
wht[t] = data_s[sb + t];
339+
barrier();
340+
341+
// Step 2: L2 norm via subgroup reduction
342+
float v2 = wht[t] * wht[t];
343+
v2 = subgroupAdd(v2);
344+
if (gl_SubgroupInvocationID == 0) sg_acc[gl_SubgroupID] = v2;
345+
barrier();
346+
if (t == 0) {
347+
float total = 0.0;
348+
for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
349+
gnrm = sqrt(total);
350+
}
351+
barrier();
352+
353+
// Step 3: normalize, then apply forward WHT: signs1 -> butterfly -> signs2
354+
wht[t] *= (gnrm > 1e-10) ? (1.0 / gnrm) : 0.0;
355+
barrier();
356+
357+
wht[t] *= TS1[t];
358+
barrier();
359+
360+
[[unroll]] for (uint h = 1; h < 128; h *= 2) {
361+
if ((t % (2 * h)) < h) {
362+
float a = wht[t];
363+
float b = wht[t + h];
364+
wht[t] = a + b;
365+
wht[t + h] = a - b;
366+
}
367+
barrier();
368+
}
369+
370+
// Step 5: apply signs2 + scaling
371+
float rv = wht[t] * TINV * TS2[t];
372+
373+
// Step 6: quantize -- all 128 threads participate
374+
uint idx = rv < TM[0] ? 0u : rv < TM[1] ? 1u : rv < TM[2] ? 2u : rv < TM[3] ? 3u :
375+
rv < TM[4] ? 4u : rv < TM[5] ? 5u : rv < TM[6] ? 6u : 7u;
376+
377+
// Pack qs: 4 elements per byte via subgroup shuffle
378+
uint sg_lane = gl_SubgroupInvocationID;
379+
uint my_low2 = idx & 0x3u;
380+
uint qs_byte = 0u;
381+
[[unroll]] for (uint k = 0; k < 4; k++) {
382+
uint contrib = subgroupShuffle(my_low2, (sg_lane & ~3u) + k);
383+
qs_byte |= contrib << (k * 2u);
384+
}
385+
if (sg_lane % 4u == 0u) {
386+
data_q[db].qs[t / 4u] = uint8_t(qs_byte);
387+
}
388+
389+
// Pack signs: 8 elements per byte via subgroup ballot
390+
uvec4 ballot = subgroupBallot(((idx >> 2u) & 1u) != 0u);
391+
if (sg_lane % 8u == 0u) {
392+
uint local_byte = sg_lane / 8u;
393+
data_q[db].signs[t / 8u] = uint8_t((ballot.x >> (local_byte * 8u)) & 0xFFu);
394+
}
395+
396+
// Step 7: reconstruction norm via subgroup reduction
397+
float rc = TC[idx] * TC[idx];
398+
rc = subgroupAdd(rc);
399+
if (sg_lane == 0u) sg_acc[gl_SubgroupID] = rc;
400+
barrier();
401+
if (t == 0u) {
402+
float total = 0.0;
403+
for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
404+
float rn = sqrt(total);
405+
data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm);
406+
}
407+
}
408+
#elif defined(SET_ROWS)
308409

309410
void main() {
310411
#ifdef NEEDS_INIT_IQ_SHMEM

0 commit comments

Comments
 (0)