@@ -452,7 +452,7 @@ constant float turbo_mid_2bit[3] = { -0.086728f, 0.0f, 0.086728f };
// Seven ascending decision thresholds for the 3-bit quantizer. The set_rows
// kernels pick the first entry k for which (value < turbo_mid_3bit[k]) and use
// k as the index; a value that is not below any entry gets index 7.
// NOTE(review): presumably the midpoints between adjacent turbo_centroids_3bit
// entries — confirm against the centroid table.
constant float turbo_mid_3bit[7] = { -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f };
453453
454454// Quantize 32 elements into one block_turbo3_0 (NO rotation — rotation happens
455- // at the 128-element group level in kernel_set_rows_turbo )
455+ // at the 128-element group level in kernel_set_rows_turbo3 )
456456void quantize_turbo3_0 (device const float * src, device block_turbo3_0 & dst) {
457457#pragma METAL fp math_mode(safe)
458458 // Compute norm for this 32-element sub-block
@@ -9489,12 +9489,11 @@ kernel void kernel_set_rows_q32(
94899489 }
94909490}
94919491
9492- // TurboQuant set_rows kernel — block size 128 (QK_TURBO3/QK_TURBO4)
9493- // TurboQuant SET_ROWS kernel — processes QK_TURBO3_GROUP (128) elements per iteration,
9492+ // TurboQuant3 SET_ROWS kernel — processes QK_TURBO3_GROUP (128) elements per iteration,
94949493// writes QK_TURBO3_GROUP/QK_TURBO3 (4) blocks per iteration.
94959494// The rotation operates on 128 elements, then results are split into 32-element blocks.
9496- template <typename TI , typename block_q, int QK , void (*quantize_func)(device const float *, device block_q &) >
9497- kernel void kernel_set_rows_turbo (
9495+ template <typename TI >
9496+ kernel void kernel_set_rows_turbo3 (
94989497 constant ggml_metal_kargs_set_rows & args,
94999498 device const void * src0,
95009499 device const void * src1,
@@ -9512,44 +9511,48 @@ kernel void kernel_set_rows_turbo(
95129511 const int32_t i10 = i01;
95139512 const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12 ))[0 ];
95149513
9515- device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3 );
9516- const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03 );
9514+ device block_turbo3_0 * dst_row = ( device block_turbo3_0 *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3 );
9515+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03 );
95179516
9518- // Process in groups of 4 blocks (128 elements) for rotation
9519- const int blocks_per_group = QK_TURBO3_GROUP / QK ; // 128/32 = 4
9520- const int n_groups = args.nk0 / blocks_per_group;
9517+ // Process in groups of 4 blocks (128 elements) for rotation.
9518+ // Use ceiling division so tail blocks for non-128-aligned head dims are not dropped.
9519+ const int blocks_per_group = QK_TURBO3_GROUP / QK_TURBO3 ; // 128/32 = 4
9520+ const int n_groups = (args.nk0 + blocks_per_group - 1 ) / blocks_per_group;
95219521
95229522 for (int grp = tiitg%tptg.x ; grp < n_groups; grp += tptg.x ) {
95239523 const device float * grp_src = src_row + QK_TURBO3_GROUP * grp;
95249524
9525- // Normalize and rotate the full 128-element group
9525+ // How many blocks are valid in this group (may be < 4 for tail group)
9526+ const int grp_start_block = grp * blocks_per_group;
9527+ const int grp_blocks = min (blocks_per_group, (int )args.nk0 - grp_start_block);
9528+ const int grp_elems = grp_blocks * QK_TURBO3 ;
9529+
9530+ // Normalize the valid elements, zero-pad the rest for WHT
95269531 float norm_sq = 0 .0f ;
9527- for (int j = 0 ; j < QK_TURBO3_GROUP ; j++) norm_sq += grp_src[j] * grp_src[j];
9532+ for (int j = 0 ; j < grp_elems ; j++) norm_sq += grp_src[j] * grp_src[j];
95289533 float grp_norm = sqrt (norm_sq);
95299534 float inv_norm = grp_norm > 1e-10f ? 1 .0f / grp_norm : 0 .0f ;
95309535
95319536 float x[128 ];
9532- for (int j = 0 ; j < 128 ; j++) x[j] = grp_src[j] * inv_norm;
9537+ for (int j = 0 ; j < grp_elems; j++) x[j] = grp_src[j] * inv_norm;
9538+ for (int j = grp_elems; j < 128 ; j++) x[j] = 0 .0f ; // zero-pad tail
95339539 turbo_rotate_forward (x, turbo_wht_signs1, turbo_wht_signs2);
95349540
9535- // Split into 4 blocks of 32 elements each
9536- // All blocks store the SAME group norm — centroids are in normalized space
9541+ // Split into blocks (may be fewer than 4 for tail group)
95379542 // Norm correction (ported from @spiritbuun's CUDA implementation):
9538- // Accumulate ||centroid_vector||^2 across all 128 elements, then store
9539- // grp_norm / ||centroid_vector|| instead of raw grp_norm. This makes
9540- // dequantized vectors have the exact original L2 norm at zero decode cost.
9543+ // Store grp_norm / ||centroid_vector|| so dequant has exact original L2 norm.
95419544 float recon_norm_sq = 0 .0f ;
95429545
9543- for (int b = 0 ; b < blocks_per_group ; b++) {
9544- device block_q & blk = dst_row[grp * blocks_per_group + b];
9545- const int off = b * QK ;
9546+ for (int b = 0 ; b < grp_blocks ; b++) {
9547+ device block_turbo3_0 & blk = dst_row[grp_start_block + b];
9548+ const int off = b * QK_TURBO3 ;
95469549
9547- for (int j = 0 ; j < QK / 4 ; j++) blk.qs [j] = 0 ;
9548- for (int j = 0 ; j < QK / 8 ; j++) blk.signs [j] = 0 ;
9550+ for (int j = 0 ; j < QK_TURBO3 / 4 ; j++) blk.qs [j] = 0 ;
9551+ for (int j = 0 ; j < QK_TURBO3 / 8 ; j++) blk.signs [j] = 0 ;
95499552
9550- // Quantize rotated values to 3-bit centroids
9551- for (int j = 0 ; j < QK ; j++) {
9552- float rv = x[off + j]; // rotated, normalized value
9553+ // Quantize rotated values to 3-bit centroids (split: 2-bit low in qs, 1-bit high in signs)
9554+ for (int j = 0 ; j < QK_TURBO3 ; j++) {
9555+ float rv = x[off + j];
95539556 uint8_t idx;
95549557 if (rv < turbo_mid_3bit[0 ]) idx = 0 ;
95559558 else if (rv < turbo_mid_3bit[1 ]) idx = 1 ;
@@ -9563,18 +9566,110 @@ kernel void kernel_set_rows_turbo(
95639566 blk.qs [j / 4 ] |= (idx & 0x3 ) << ((j % 4 ) * 2 );
95649567 if (idx & 0x4 ) blk.signs [j / 8 ] |= (1 << (j % 8 ));
95659568
9566- // Accumulate centroid reconstruction norm for norm correction
95679569 float c = turbo_centroids_3bit[idx];
95689570 recon_norm_sq += c * c;
95699571 }
95709572 }
95719573
95729574 // Norm correction: store corrected norm so dequant(x) has exact original L2 norm.
9573- // Zero decode cost — dequant already multiplies by stored norm.
95749575 float recon_norm = sqrt (recon_norm_sq);
95759576 float corrected_norm = (recon_norm > 1e-10f ) ? grp_norm / recon_norm : grp_norm;
9576- for (int b = 0 ; b < blocks_per_group; b++) {
9577- dst_row[grp * blocks_per_group + b].norm = half (corrected_norm);
9577+ for (int b = 0 ; b < grp_blocks; b++) {
9578+ dst_row[grp_start_block + b].norm = half (corrected_norm);
9579+ }
9580+ }
9581+ }
9582+
// TurboQuant4 SET_ROWS kernel: quantizes float source rows into block_turbo4_0
// blocks, one block per QK_TURBO4 (128) consecutive elements.
// Scheme as described by the original author: 3-bit PolarQuant plus a 1-bit QJL
// residual correction. Unlike turbo3, which splits a 128-element group into
// 4 blocks of 32 elements, turbo4 keeps the whole group in a single block.
//
// Fields written per block:
//   norm  - L2 norm of the 128 source values (stored as half)
//   qs    - 128 three-bit indices, bit-packed low-bits-first (QK_TURBO4*3/8 bytes)
//   rnorm - L2 norm of the quantization residual (stored as half)
//   signs - one bit per element, set when the rotated residual is >= 0
//
// Work split: the source row is i01 = tgpig.x*tptg.y + tiitg/tptg.x. Threads that
// share a row start at block (tiitg % tptg.x) and advance by tptg.x blocks.
// The destination row index is read from src1 with index type TI.
template <typename TI>
kernel void kernel_set_rows_turbo4(
        constant ggml_metal_kargs_set_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    // Length of the vector handed to turbo_rotate_forward.
    constexpr int WHT_DIM = 128;

    const int n_qs_bytes   = QK_TURBO4 * 3 / 8;
    const int n_sign_bytes = QK_TURBO4 / 8;

    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i12 = i03%args.ne12;
    const int32_t i11 = i02%args.ne11;
    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
    if (i01 >= args.ne01) {
        return;
    }

    // Destination row index comes from src1 (same row coordinate as the source row).
    const device char * row_id_ptr = (const device char *) src1 + i01*args.nb10 + i11*args.nb11 + i12*args.nb12;
    const TI i1 = ((const device TI *) row_id_ptr)[0];

    device char * out_bytes = (device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3;
    device block_turbo4_0 * out_blocks = (device block_turbo4_0 *) out_bytes;

    const device char * in_bytes = (const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03;
    const device float * in_row = (const device float *) in_bytes;

    // args.nk0 is used directly as the number of blocks in the row.
    // NOTE(review): assumes nk0 = ne0 / QK_TURBO4 — confirm against the host-side dispatch.
    const int total_blocks = args.nk0;

    for (int ib = tiitg%tptg.x; ib < total_blocks; ib += tptg.x) {
        const device float * in = in_row + QK_TURBO4 * ib;
        device block_turbo4_0 & out = out_blocks[ib];

        // 1) L2 norm of the block; a (near-)zero block is scaled by 0 instead of 1/norm.
        float sum_sq = 0.0f;
        for (int j = 0; j < QK_TURBO4; j++) {
            sum_sq += in[j] * in[j];
        }
        const float block_norm = sqrt(sum_sq);
        float scale = 0.0f;
        if (block_norm > 1e-10f) {
            scale = 1.0f / block_norm;
        }
        out.norm = half(block_norm);

        // unit: normalized input (kept for the residual); rot: working copy that gets rotated.
        float unit[WHT_DIM];
        float rot[WHT_DIM];
        for (int j = 0; j < WHT_DIM; j++) {
            const float u = in[j] * scale;
            rot[j]  = u;
            unit[j] = u;
        }

        // 2) Rotate the working copy in place.
        turbo_rotate_forward(rot, turbo_wht_signs1, turbo_wht_signs2);

        // 3) Clear the packed outputs, then quantize each rotated value to a 3-bit index.
        for (int j = 0; j < n_qs_bytes; j++) {
            out.qs[j] = 0;
        }
        for (int j = 0; j < n_sign_bytes; j++) {
            out.signs[j] = 0;
        }

        float res_sq = 0.0f;
        for (int j = 0; j < WHT_DIM; j++) {
            const float val = rot[j];

            // Index = position of the first threshold that val is below (7 if none).
            uint8_t idx = 0;
            while (idx < 7 && !(val < turbo_mid_3bit[idx])) {
                idx++;
            }

            // Index j occupies bits [3*j, 3*j + 3) of qs; it spills into the next
            // byte only when it starts at bit 6 or 7.
            const int bit_offset = 3 * j;
            const int byte_idx   = bit_offset / 8;
            const int bit_pos    = bit_offset % 8;
            out.qs[byte_idx] |= (uint8_t)(idx << bit_pos);
            if (bit_pos > 5 && byte_idx + 1 < n_qs_bytes) {
                out.qs[byte_idx + 1] |= (uint8_t)(idx >> (8 - bit_pos));
            }

            // 4) Residual, stored back into rot (slot j is no longer needed as input).
            // NOTE(review): the centroid was chosen from the ROTATED value, while unit[j]
            // is the value BEFORE rotation, and no inverse rotation is applied here.
            // Confirm this matches the reference quantizer / dequantizer.
            rot[j] = unit[j] - turbo_centroids_3bit[idx];
            res_sq += rot[j] * rot[j];
        }
        out.rnorm = half(sqrt(res_sq));

        // 5) QJL: rotate the residual with the QJL sign tables and keep one sign bit per element.
        turbo_rotate_forward(rot, turbo_qjl_wht_signs1, turbo_qjl_wht_signs2);
        for (int sb = 0; sb < WHT_DIM / 8; sb++) {
            uint8_t bits = 0;
            for (int k = 0; k < 8; k++) {
                if (rot[8 * sb + k] >= 0.0f) {
                    bits |= (uint8_t)(1 << k);
                }
            }
            out.signs[sb] |= bits;
        }
    }
}
@@ -10381,13 +10476,14 @@ template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kerne
1038110476template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t , block_iq4_nl, quantize_iq4_nl>;
1038210477template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t , block_iq4_nl, quantize_iq4_nl>;
1038310478
10384- // TurboQuant set_rows instantiations (block size 128)
10385- typedef decltype (kernel_set_rows_turbo<int64_t , block_turbo3_0, QK_TURBO3 , quantize_turbo3_0>) set_rows_turbo_t;
10479+ // TurboQuant set_rows instantiations — separate turbo3 and turbo4 kernels
10480+ typedef decltype (kernel_set_rows_turbo3<int64_t >) set_rows_turbo3_t;
10481+ typedef decltype (kernel_set_rows_turbo4<int64_t >) set_rows_turbo4_t;
1038610482
10387- template [[host_name("kernel_set_rows_turbo3_i64")]] kernel set_rows_turbo_t kernel_set_rows_turbo <int64_t , block_turbo3_0, QK_TURBO3 , quantize_turbo3_0 >;
10388- template [[host_name("kernel_set_rows_turbo3_i32")]] kernel set_rows_turbo_t kernel_set_rows_turbo <int32_t , block_turbo3_0, QK_TURBO3 , quantize_turbo3_0 >;
10389- template [[host_name("kernel_set_rows_turbo4_i64")]] kernel set_rows_turbo_t kernel_set_rows_turbo <int64_t , block_turbo4_0, QK_TURBO4 , quantize_turbo4_0 >;
10390- template [[host_name("kernel_set_rows_turbo4_i32")]] kernel set_rows_turbo_t kernel_set_rows_turbo <int32_t , block_turbo4_0, QK_TURBO4 , quantize_turbo4_0 >;
10483+ template [[host_name("kernel_set_rows_turbo3_i64")]] kernel set_rows_turbo3_t kernel_set_rows_turbo3 <int64_t >;
10484+ template [[host_name("kernel_set_rows_turbo3_i32")]] kernel set_rows_turbo3_t kernel_set_rows_turbo3 <int32_t >;
10485+ template [[host_name("kernel_set_rows_turbo4_i64")]] kernel set_rows_turbo4_t kernel_set_rows_turbo4 <int64_t >;
10486+ template [[host_name("kernel_set_rows_turbo4_i32")]] kernel set_rows_turbo4_t kernel_set_rows_turbo4 <int32_t >;
1039110487
1039210488//
1039310489// matrix-matrix multiplication
0 commit comments