Skip to content

Commit 65ed372

Browse files
seanrasch and claude committed
fix: turbo4 SET_ROWS corruption, tail-block truncation, constant coupling (Issue PrismML-Eng#29)
Three bugs from the block-size-32 refactor:

1. kernel_set_rows_turbo hardcoded turbo3 packing for turbo4 — split into separate kernel_set_rows_turbo3 and kernel_set_rows_turbo4 kernels. turbo4 now correctly does 3-bit PolarQuant + QJL residual correction.

2. Integer division in n_groups = nk0 / blocks_per_group silently dropped tail blocks for non-128-aligned head dims (e.g. dk=192). Added ceiling division with tail-group bounds checking in turbo3, and GGML_ASSERT in WHT dispatch to catch non-128-aligned tensors.

3. TURBO_D constant was semantically coupled to QK_TURBO4 — replaced with TURBO_ROT_DIM (= QK_TURBO3_GROUP) and added static_assert that QK_TURBO4 == QK_TURBO3_GROUP to guard against future drift.

Closes PrismML-Eng#29

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5811aa5 commit 65ed372

3 files changed

Lines changed: 160 additions & 56 deletions

File tree

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1659,6 +1659,7 @@ int ggml_metal_op_turbo_wht(ggml_metal_op_t ctx, int idx) {
16591659
memcpy(&direction, op->op_params, sizeof(int));
16601660

16611661
const int64_t n_elements = ggml_nelements(op->src[0]);
1662+
GGML_ASSERT(n_elements % 128 == 0 && "TURBO_WHT requires head_dim to be a multiple of 128");
16621663
const int64_t n_groups = n_elements / 128;
16631664

16641665
auto pipeline = ggml_metal_library_get_pipeline_turbo_wht(lib);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 132 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -452,7 +452,7 @@ constant float turbo_mid_2bit[3] = { -0.086728f, 0.0f, 0.086728f };
452452
constant float turbo_mid_3bit[7] = { -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f };
453453

454454
// Quantize 32 elements into one block_turbo3_0 (NO rotation — rotation happens
455-
// at the 128-element group level in kernel_set_rows_turbo)
455+
// at the 128-element group level in kernel_set_rows_turbo3)
456456
void quantize_turbo3_0(device const float * src, device block_turbo3_0 & dst) {
457457
#pragma METAL fp math_mode(safe)
458458
// Compute norm for this 32-element sub-block
@@ -9489,12 +9489,11 @@ kernel void kernel_set_rows_q32(
94899489
}
94909490
}
94919491

9492-
// TurboQuant set_rows kernel — block size 128 (QK_TURBO3/QK_TURBO4)
9493-
// TurboQuant SET_ROWS kernel — processes QK_TURBO3_GROUP (128) elements per iteration,
9492+
// TurboQuant3 SET_ROWS kernel — processes QK_TURBO3_GROUP (128) elements per iteration,
94949493
// writes up to QK_TURBO3_GROUP/QK_TURBO3 (4) blocks per iteration (fewer for a tail group).
94959494
// The rotation operates on 128 elements, then results are split into 32-element blocks.
9496-
template<typename TI, typename block_q, int QK, void (*quantize_func)(device const float *, device block_q &)>
9497-
kernel void kernel_set_rows_turbo(
9495+
template<typename TI>
9496+
kernel void kernel_set_rows_turbo3(
94989497
constant ggml_metal_kargs_set_rows & args,
94999498
device const void * src0,
95009499
device const void * src1,
@@ -9512,44 +9511,48 @@ kernel void kernel_set_rows_turbo(
95129511
const int32_t i10 = i01;
95139512
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
95149513

9515-
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
9516-
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
9514+
device block_turbo3_0 * dst_row = ( device block_turbo3_0 *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
9515+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
95179516

9518-
// Process in groups of 4 blocks (128 elements) for rotation
9519-
const int blocks_per_group = QK_TURBO3_GROUP / QK; // 128/32 = 4
9520-
const int n_groups = args.nk0 / blocks_per_group;
9517+
// Process in groups of 4 blocks (128 elements) for rotation.
9518+
// Use ceiling division so tail blocks for non-128-aligned head dims are not dropped.
9519+
const int blocks_per_group = QK_TURBO3_GROUP / QK_TURBO3; // 128/32 = 4
9520+
const int n_groups = (args.nk0 + blocks_per_group - 1) / blocks_per_group;
95219521

95229522
for (int grp = tiitg%tptg.x; grp < n_groups; grp += tptg.x) {
95239523
const device float * grp_src = src_row + QK_TURBO3_GROUP * grp;
95249524

9525-
// Normalize and rotate the full 128-element group
9525+
// How many blocks are valid in this group (may be < 4 for tail group)
9526+
const int grp_start_block = grp * blocks_per_group;
9527+
const int grp_blocks = min(blocks_per_group, (int)args.nk0 - grp_start_block);
9528+
const int grp_elems = grp_blocks * QK_TURBO3;
9529+
9530+
// Normalize the valid elements, zero-pad the rest for WHT
95269531
float norm_sq = 0.0f;
9527-
for (int j = 0; j < QK_TURBO3_GROUP; j++) norm_sq += grp_src[j] * grp_src[j];
9532+
for (int j = 0; j < grp_elems; j++) norm_sq += grp_src[j] * grp_src[j];
95289533
float grp_norm = sqrt(norm_sq);
95299534
float inv_norm = grp_norm > 1e-10f ? 1.0f / grp_norm : 0.0f;
95309535

95319536
float x[128];
9532-
for (int j = 0; j < 128; j++) x[j] = grp_src[j] * inv_norm;
9537+
for (int j = 0; j < grp_elems; j++) x[j] = grp_src[j] * inv_norm;
9538+
for (int j = grp_elems; j < 128; j++) x[j] = 0.0f; // zero-pad tail
95339539
turbo_rotate_forward(x, turbo_wht_signs1, turbo_wht_signs2);
95349540

9535-
// Split into 4 blocks of 32 elements each
9536-
// All blocks store the SAME group norm — centroids are in normalized space
9541+
// Split into blocks (may be fewer than 4 for tail group)
95379542
// Norm correction (ported from @spiritbuun's CUDA implementation):
9538-
// Accumulate ||centroid_vector||^2 across all 128 elements, then store
9539-
// grp_norm / ||centroid_vector|| instead of raw grp_norm. This makes
9540-
// dequantized vectors have the exact original L2 norm at zero decode cost.
9543+
// Store grp_norm / ||centroid_vector|| so dequant has exact original L2 norm.
95419544
float recon_norm_sq = 0.0f;
95429545

9543-
for (int b = 0; b < blocks_per_group; b++) {
9544-
device block_q & blk = dst_row[grp * blocks_per_group + b];
9545-
const int off = b * QK;
9546+
for (int b = 0; b < grp_blocks; b++) {
9547+
device block_turbo3_0 & blk = dst_row[grp_start_block + b];
9548+
const int off = b * QK_TURBO3;
95469549

9547-
for (int j = 0; j < QK / 4; j++) blk.qs[j] = 0;
9548-
for (int j = 0; j < QK / 8; j++) blk.signs[j] = 0;
9550+
for (int j = 0; j < QK_TURBO3 / 4; j++) blk.qs[j] = 0;
9551+
for (int j = 0; j < QK_TURBO3 / 8; j++) blk.signs[j] = 0;
95499552

9550-
// Quantize rotated values to 3-bit centroids
9551-
for (int j = 0; j < QK; j++) {
9552-
float rv = x[off + j]; // rotated, normalized value
9553+
// Quantize rotated values to 3-bit centroids (split: 2-bit low in qs, 1-bit high in signs)
9554+
for (int j = 0; j < QK_TURBO3; j++) {
9555+
float rv = x[off + j];
95539556
uint8_t idx;
95549557
if (rv < turbo_mid_3bit[0]) idx = 0;
95559558
else if (rv < turbo_mid_3bit[1]) idx = 1;
@@ -9563,18 +9566,110 @@ kernel void kernel_set_rows_turbo(
95639566
blk.qs[j / 4] |= (idx & 0x3) << ((j % 4) * 2);
95649567
if (idx & 0x4) blk.signs[j / 8] |= (1 << (j % 8));
95659568

9566-
// Accumulate centroid reconstruction norm for norm correction
95679569
float c = turbo_centroids_3bit[idx];
95689570
recon_norm_sq += c * c;
95699571
}
95709572
}
95719573

95729574
// Norm correction: store corrected norm so dequant(x) has exact original L2 norm.
9573-
// Zero decode cost — dequant already multiplies by stored norm.
95749575
float recon_norm = sqrt(recon_norm_sq);
95759576
float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm;
9576-
for (int b = 0; b < blocks_per_group; b++) {
9577-
dst_row[grp * blocks_per_group + b].norm = half(corrected_norm);
9577+
for (int b = 0; b < grp_blocks; b++) {
9578+
dst_row[grp_start_block + b].norm = half(corrected_norm);
9579+
}
9580+
}
9581+
}
9582+
9583+
// TurboQuant4 SET_ROWS kernel — processes 128 elements per block (QK_TURBO4).
9584+
// Turbo4 = 3-bit PolarQuant + 1-bit QJL residual correction.
9585+
// Unlike turbo3 which splits 128-element groups into 4x32-element blocks,
9586+
// turbo4 uses a single 128-element block with packed 3-bit indices + QJL signs.
9587+
template<typename TI>
9588+
kernel void kernel_set_rows_turbo4(
9589+
constant ggml_metal_kargs_set_rows & args,
9590+
device const void * src0,
9591+
device const void * src1,
9592+
device float * dst,
9593+
uint3 tgpig[[threadgroup_position_in_grid]],
9594+
uint tiitg[[thread_index_in_threadgroup]],
9595+
uint3 tptg [[threads_per_threadgroup]]) {
9596+
const int32_t i03 = tgpig.z;
9597+
const int32_t i02 = tgpig.y;
9598+
const int32_t i12 = i03%args.ne12;
9599+
const int32_t i11 = i02%args.ne11;
9600+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
9601+
if (i01 >= args.ne01) return;
9602+
9603+
const int32_t i10 = i01;
9604+
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
9605+
9606+
device block_turbo4_0 * dst_row = ( device block_turbo4_0 *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
9607+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
9608+
9609+
// Each block is one 128-element group
9610+
const int n_blocks = args.nk0; // nk0 = ne0 / QK_TURBO4, already in block units
9611+
9612+
for (int blk_idx = tiitg%tptg.x; blk_idx < n_blocks; blk_idx += tptg.x) {
9613+
const device float * blk_src = src_row + QK_TURBO4 * blk_idx;
9614+
device block_turbo4_0 & blk = dst_row[blk_idx];
9615+
9616+
// Step 1: Compute norm + normalize
9617+
float norm_sq = 0.0f;
9618+
for (int j = 0; j < QK_TURBO4; j++) norm_sq += blk_src[j] * blk_src[j];
9619+
float grp_norm = sqrt(norm_sq);
9620+
float inv_norm = grp_norm > 1e-10f ? 1.0f / grp_norm : 0.0f;
9621+
blk.norm = half(grp_norm);
9622+
9623+
float x[128];
9624+
for (int j = 0; j < 128; j++) x[j] = blk_src[j] * inv_norm;
9625+
float normalized[128];
9626+
for (int j = 0; j < 128; j++) normalized[j] = x[j];
9627+
9628+
// Step 2: WHT rotate in-place
9629+
turbo_rotate_forward(x, turbo_wht_signs1, turbo_wht_signs2);
9630+
9631+
// Step 3: 3-bit PolarQuant quantization — packed 3-bit indices
9632+
for (int j = 0; j < QK_TURBO4 * 3 / 8; j++) blk.qs[j] = 0;
9633+
for (int j = 0; j < QK_TURBO4 / 8; j++) blk.signs[j] = 0;
9634+
9635+
float recon[128];
9636+
for (int j = 0; j < 128; j++) {
9637+
float val = x[j];
9638+
uint8_t idx;
9639+
if (val < turbo_mid_3bit[0]) idx = 0;
9640+
else if (val < turbo_mid_3bit[1]) idx = 1;
9641+
else if (val < turbo_mid_3bit[2]) idx = 2;
9642+
else if (val < turbo_mid_3bit[3]) idx = 3;
9643+
else if (val < turbo_mid_3bit[4]) idx = 4;
9644+
else if (val < turbo_mid_3bit[5]) idx = 5;
9645+
else if (val < turbo_mid_3bit[6]) idx = 6;
9646+
else idx = 7;
9647+
recon[j] = turbo_centroids_3bit[idx];
9648+
9649+
// Pack 3-bit index (may span byte boundary)
9650+
int bit_offset = j * 3;
9651+
int byte_idx = bit_offset / 8;
9652+
int bit_pos = bit_offset % 8;
9653+
blk.qs[byte_idx] |= (uint8_t)((idx & 0x7) << bit_pos);
9654+
if (bit_pos > 5 && byte_idx + 1 < QK_TURBO4 * 3 / 8) {
9655+
blk.qs[byte_idx + 1] |= (uint8_t)((idx & 0x7) >> (8 - bit_pos));
9656+
}
9657+
}
9658+
9659+
// Step 4: Compute residual and its norm
9660+
float rnorm_sq = 0.0f;
9661+
for (int j = 0; j < 128; j++) {
9662+
x[j] = normalized[j] - recon[j]; // residual in x buffer
9663+
rnorm_sq += x[j] * x[j];
9664+
}
9665+
blk.rnorm = half(sqrt(rnorm_sq));
9666+
9667+
// Step 5: QJL — WHT rotate residual, store sign bits
9668+
turbo_rotate_forward(x, turbo_qjl_wht_signs1, turbo_qjl_wht_signs2);
9669+
for (int i = 0; i < 128; i++) {
9670+
if (x[i] >= 0.0f) {
9671+
blk.signs[i / 8] |= (1 << (i % 8));
9672+
}
95789673
}
95799674
}
95809675
}
@@ -10381,13 +10476,14 @@ template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kerne
1038110476
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
1038210477
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
1038310478

10384-
// TurboQuant set_rows instantiations (block size 128)
10385-
typedef decltype(kernel_set_rows_turbo<int64_t, block_turbo3_0, QK_TURBO3, quantize_turbo3_0>) set_rows_turbo_t;
10479+
// TurboQuant set_rows instantiations — separate turbo3 and turbo4 kernels
10480+
typedef decltype(kernel_set_rows_turbo3<int64_t>) set_rows_turbo3_t;
10481+
typedef decltype(kernel_set_rows_turbo4<int64_t>) set_rows_turbo4_t;
1038610482

10387-
template [[host_name("kernel_set_rows_turbo3_i64")]] kernel set_rows_turbo_t kernel_set_rows_turbo<int64_t, block_turbo3_0, QK_TURBO3, quantize_turbo3_0>;
10388-
template [[host_name("kernel_set_rows_turbo3_i32")]] kernel set_rows_turbo_t kernel_set_rows_turbo<int32_t, block_turbo3_0, QK_TURBO3, quantize_turbo3_0>;
10389-
template [[host_name("kernel_set_rows_turbo4_i64")]] kernel set_rows_turbo_t kernel_set_rows_turbo<int64_t, block_turbo4_0, QK_TURBO4, quantize_turbo4_0>;
10390-
template [[host_name("kernel_set_rows_turbo4_i32")]] kernel set_rows_turbo_t kernel_set_rows_turbo<int32_t, block_turbo4_0, QK_TURBO4, quantize_turbo4_0>;
10483+
template [[host_name("kernel_set_rows_turbo3_i64")]] kernel set_rows_turbo3_t kernel_set_rows_turbo3<int64_t>;
10484+
template [[host_name("kernel_set_rows_turbo3_i32")]] kernel set_rows_turbo3_t kernel_set_rows_turbo3<int32_t>;
10485+
template [[host_name("kernel_set_rows_turbo4_i64")]] kernel set_rows_turbo4_t kernel_set_rows_turbo4<int64_t>;
10486+
template [[host_name("kernel_set_rows_turbo4_i32")]] kernel set_rows_turbo4_t kernel_set_rows_turbo4<int32_t>;
1039110487

1039210488
//
1039310489
// matrix-matrix multiplication

ggml/src/ggml-turbo-quant.c

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,16 @@
1919

2020
#define TURBO_SEED_ROTATION 42
2121
#define TURBO_SEED_QJL 1042
22-
#define TURBO_D 128 /* rotation group size = head_dim (independent of block size) */
2322
#define TURBO_QJL_CONST 1.2533141373155003f /* sqrt(pi/2) */
2423

24+
/* Rotation group size = QK_TURBO3_GROUP (from ggml-common.h), NOT a separate constant.
25+
* turbo4 block size (QK_TURBO4) happens to equal the rotation group size today,
26+
* but they are semantically different. Assert they match so turbo4 code can safely
27+
* use QK_TURBO4 for both array sizing and loop bounds. */
28+
static_assert(QK_TURBO4 == QK_TURBO3_GROUP,
29+
"turbo4 block size must equal rotation group size (both 128)");
30+
#define TURBO_ROT_DIM QK_TURBO3_GROUP
31+
2532
/* Optimal centroids from paper (scaled by 1/sqrt(d)) */
2633
/* 1-bit: ±sqrt(2/(pi*d)) */
2734
static const float CENTROIDS_1BIT[2] = { -0.070711f, 0.070711f }; /* for d=128 */
@@ -37,8 +44,8 @@ static const float CENTROIDS_3BIT[8] = {
3744

3845
/* ---------- rotation matrix (lazy init) ---------- */
3946

40-
static float turbo_rotation[TURBO_D * TURBO_D];
41-
static float turbo_rotation_t[TURBO_D * TURBO_D]; /* transpose */
47+
static float turbo_rotation[TURBO_ROT_DIM * TURBO_ROT_DIM];
48+
static float turbo_rotation_t[TURBO_ROT_DIM * TURBO_ROT_DIM]; /* transpose */
4249
static int turbo_rotation_initialized = 0;
4350

4451
/* Simple LCG PRNG for deterministic rotation generation */
@@ -61,11 +68,11 @@ static double turbo_prng_normal(void) {
6168
static void turbo_init_rotation(void) {
6269
if (turbo_rotation_initialized) return;
6370

64-
const int d = TURBO_D;
71+
const int d = TURBO_ROT_DIM;
6572

6673
/* Generate random Gaussian matrix */
6774
turbo_prng_seed(TURBO_SEED_ROTATION);
68-
float G[TURBO_D * TURBO_D];
75+
float G[TURBO_ROT_DIM * TURBO_ROT_DIM];
6976
for (int i = 0; i < d * d; i++) {
7077
G[i] = (float)turbo_prng_normal();
7178
}
@@ -111,14 +118,14 @@ static void turbo_init_rotation(void) {
111118

112119
/* ---------- QJL projection matrix (lazy init, seed-based) ---------- */
113120

114-
static float turbo_qjl_matrix[TURBO_D * TURBO_D];
115-
static float turbo_qjl_matrix_t[TURBO_D * TURBO_D];
121+
static float turbo_qjl_matrix[TURBO_ROT_DIM * TURBO_ROT_DIM];
122+
static float turbo_qjl_matrix_t[TURBO_ROT_DIM * TURBO_ROT_DIM];
116123
static int turbo_qjl_initialized = 0;
117124

118125
static void turbo_init_qjl(void) {
119126
if (turbo_qjl_initialized) return;
120127

121-
const int d = TURBO_D;
128+
const int d = TURBO_ROT_DIM;
122129
turbo_prng_seed(TURBO_SEED_QJL);
123130

124131
for (int i = 0; i < d * d; i++) {
@@ -235,7 +242,7 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G
235242
float norm = sqrtf(norm_sq);
236243

237244
/* Normalize */
238-
float normalized[TURBO_D];
245+
float normalized[TURBO_ROT_DIM];
239246
if (norm > 1e-10f) {
240247
const float inv = 1.0f / norm;
241248
for (int i = 0; i < d; i++) normalized[i] = src[i] * inv;
@@ -244,31 +251,31 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G
244251
}
245252

246253
/* Step 2: Rotate */
247-
float rotated[TURBO_D];
254+
float rotated[TURBO_ROT_DIM];
248255
matvec(turbo_rotation, normalized, rotated, d);
249256

250257
/* Step 3: 3-bit quantization */
251-
uint8_t indices[TURBO_D];
258+
uint8_t indices[TURBO_ROT_DIM];
252259
for (int i = 0; i < d; i++) {
253260
indices[i] = (uint8_t)nearest_centroid_3bit(rotated[i]);
254261
}
255262

256263
/* Step 4: Residual */
257-
float reconstructed[TURBO_D];
264+
float reconstructed[TURBO_ROT_DIM];
258265
for (int i = 0; i < d; i++) {
259266
reconstructed[i] = CENTROIDS_3BIT[indices[i]];
260267
}
261-
float mse_recon[TURBO_D];
268+
float mse_recon[TURBO_ROT_DIM];
262269
matvec(turbo_rotation_t, reconstructed, mse_recon, d);
263270

264-
float residual[TURBO_D];
271+
float residual[TURBO_ROT_DIM];
265272
for (int i = 0; i < d; i++) {
266273
residual[i] = normalized[i] - mse_recon[i];
267274
}
268275

269276

270277
/* Step 5: QJL */
271-
float projected[TURBO_D];
278+
float projected[TURBO_ROT_DIM];
272279
matvec(turbo_qjl_matrix, residual, projected, d);
273280

274281
/* Pack */
@@ -310,7 +317,7 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM
310317
float norm = GGML_FP16_TO_FP32(x[block].norm);
311318

312319
/* Unpack 3-bit indices */
313-
uint8_t indices[TURBO_D];
320+
uint8_t indices[TURBO_ROT_DIM];
314321
for (int i = 0; i < d; i++) {
315322
int bit_offset = i * 3;
316323
int byte_idx = bit_offset / 8;
@@ -323,7 +330,7 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM
323330
}
324331

325332
/* Unpack signs */
326-
float signs[TURBO_D];
333+
float signs[TURBO_ROT_DIM];
327334
for (int i = 0; i < d; i++) {
328335
signs[i] = (x[block].signs[i / 8] & (1 << (i % 8))) ? 1.0f : -1.0f;
329336
}
@@ -332,15 +339,15 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM
332339
const float qjl_scale = TURBO_QJL_CONST / (float)d * rnorm;
333340

334341
/* PolarQuant dequant */
335-
float rotated_recon[TURBO_D];
342+
float rotated_recon[TURBO_ROT_DIM];
336343
for (int i = 0; i < d; i++) {
337344
rotated_recon[i] = CENTROIDS_3BIT[indices[i]];
338345
}
339-
float mse_recon[TURBO_D];
346+
float mse_recon[TURBO_ROT_DIM];
340347
matvec(turbo_rotation_t, rotated_recon, mse_recon, d);
341348

342349
/* QJL dequant */
343-
float qjl_recon[TURBO_D];
350+
float qjl_recon[TURBO_ROT_DIM];
344351
matvec(turbo_qjl_matrix_t, signs, qjl_recon, d);
345352
for (int i = 0; i < d; i++) {
346353
qjl_recon[i] *= qjl_scale;

0 commit comments

Comments
 (0)