@@ -9275,12 +9275,18 @@ void quantize_turbo3_1(device const float * src, device block_turbo3_1 & dst) {
92759275 float norm = sqrt (sum2 + 1e-12f );
92769276 dst.norm = half (norm);
92779277 float inv_norm = 1 .0f / norm;
9278+ float recon_sq = 0 .0f ;
92789279 for (int i = 0 ; i < 16 ; i++) dst.qs [i] = 0 ;
92799280 for (int i = 0 ; i < 64 ; i++) {
92809281 float val = src[i] * inv_norm;
92819282 int idx = turbo_nearest_centroid_m<4 >(val, TURBO_CENTROIDS_2BIT_M );
9283+ recon_sq += TURBO_CENTROIDS_2BIT_M [idx] * TURBO_CENTROIDS_2BIT_M [idx];
92829284 turbo_pack_bits (dst.qs , i * 2 , 2 , idx);
92839285 }
9286+ float recon_norm = sqrt (recon_sq);
9287+ if (recon_norm > 1e-10f ) {
9288+ dst.norm = half (norm / recon_norm);
9289+ }
92849290}
92859291
92869292void quantize_turbo4_1 (device const float * src, device block_turbo4_1 & dst) {
@@ -9289,12 +9295,18 @@ void quantize_turbo4_1(device const float * src, device block_turbo4_1 & dst) {
92899295 float norm = sqrt (sum2 + 1e-12f );
92909296 dst.norm = half (norm);
92919297 float inv_norm = 1 .0f / norm;
9298+ float recon_sq = 0 .0f ;
92929299 for (int i = 0 ; i < 24 ; i++) dst.qs [i] = 0 ;
92939300 for (int i = 0 ; i < 64 ; i++) {
92949301 float val = src[i] * inv_norm;
92959302 int idx = turbo_nearest_centroid_m<8 >(val, TURBO_CENTROIDS_3BIT_M );
9303+ recon_sq += TURBO_CENTROIDS_3BIT_M [idx] * TURBO_CENTROIDS_3BIT_M [idx];
92969304 turbo_pack_bits (dst.qs , i * 3 , 3 , idx);
92979305 }
9306+ float recon_norm = sqrt (recon_sq);
9307+ if (recon_norm > 1e-10f ) {
9308+ dst.norm = half (norm / recon_norm);
9309+ }
92989310}
92999311
93009312void quantize_turbo5_1 (device const float * src, device block_turbo5_1 & dst) {
@@ -9303,12 +9315,18 @@ void quantize_turbo5_1(device const float * src, device block_turbo5_1 & dst) {
93039315 float norm = sqrt (sum2 + 1e-12f );
93049316 dst.norm = half (norm);
93059317 float inv_norm = 1 .0f / norm;
9318+ float recon_sq = 0 .0f ;
93069319 for (int i = 0 ; i < 32 ; i++) dst.qs [i] = 0 ;
93079320 for (int i = 0 ; i < 64 ; i++) {
93089321 float val = src[i] * inv_norm;
93099322 int idx = turbo_nearest_centroid_m<16 >(val, TURBO_CENTROIDS_4BIT_M );
9323+ recon_sq += TURBO_CENTROIDS_4BIT_M [idx] * TURBO_CENTROIDS_4BIT_M [idx];
93109324 turbo_pack_bits (dst.qs , i * 4 , 4 , idx);
93119325 }
9326+ float recon_norm = sqrt (recon_sq);
9327+ if (recon_norm > 1e-10f ) {
9328+ dst.norm = half (norm / recon_norm);
9329+ }
93129330}
93139331
93149332void quantize_turbo6_1 (device const float * src, device block_turbo6_1 & dst) {
@@ -9317,12 +9335,18 @@ void quantize_turbo6_1(device const float * src, device block_turbo6_1 & dst) {
93179335 float norm = sqrt (sum2 + 1e-12f );
93189336 dst.norm = half (norm);
93199337 float inv_norm = 1 .0f / norm;
9338+ float recon_sq = 0 .0f ;
93209339 for (int i = 0 ; i < 40 ; i++) dst.qs [i] = 0 ;
93219340 for (int i = 0 ; i < 64 ; i++) {
93229341 float val = src[i] * inv_norm;
93239342 int idx = turbo_nearest_centroid_m<32 >(val, TURBO_CENTROIDS_5BIT_M );
9343+ recon_sq += TURBO_CENTROIDS_5BIT_M [idx] * TURBO_CENTROIDS_5BIT_M [idx];
93249344 turbo_pack_bits (dst.qs , i * 5 , 5 , idx);
93259345 }
9346+ float recon_norm = sqrt (recon_sq);
9347+ if (recon_norm > 1e-10f ) {
9348+ dst.norm = half (norm / recon_norm);
9349+ }
93269350}
93279351
93289352// RotorQuant GPU quantize functions (with Clifford rotor rotation matching CPU path)
@@ -9332,6 +9356,7 @@ void quantize_rq3_1(device const float * src, device block_rq3_1 & dst) {
93329356 float norm = sqrt (sum2 + 1e-12f );
93339357 dst.norm = half (norm);
93349358 float inv_norm = 1 .0f / norm;
9359+ float recon_sq = 0 .0f ;
93359360 float u[64 ];
93369361 for (int i = 0 ; i < 64 ; i++) u[i] = src[i] * inv_norm;
93379362 // Apply forward rotor per group of 3
@@ -9345,8 +9370,13 @@ void quantize_rq3_1(device const float * src, device block_rq3_1 & dst) {
93459370 for (int i = 0 ; i < 16 ; i++) dst.qs [i] = 0 ;
93469371 for (int i = 0 ; i < 64 ; i++) {
93479372 int idx = turbo_nearest_centroid_m<4 >(rotated[i], TURBO_CENTROIDS_2BIT_M );
9373+ recon_sq += TURBO_CENTROIDS_2BIT_M [idx] * TURBO_CENTROIDS_2BIT_M [idx];
93489374 turbo_pack_bits (dst.qs , i * 2 , 2 , idx);
93499375 }
9376+ float recon_norm = sqrt (recon_sq);
9377+ if (recon_norm > 1e-10f ) {
9378+ dst.norm = half (norm / recon_norm);
9379+ }
93509380}
93519381
93529382void quantize_rq4_1 (device const float * src, device block_rq4_1 & dst) {
@@ -9355,6 +9385,7 @@ void quantize_rq4_1(device const float * src, device block_rq4_1 & dst) {
93559385 float norm = sqrt (sum2 + 1e-12f );
93569386 dst.norm = half (norm);
93579387 float inv_norm = 1 .0f / norm;
9388+ float recon_sq = 0 .0f ;
93589389 float u[64 ];
93599390 for (int i = 0 ; i < 64 ; i++) u[i] = src[i] * inv_norm;
93609391 float rotated[64 ];
@@ -9367,8 +9398,13 @@ void quantize_rq4_1(device const float * src, device block_rq4_1 & dst) {
93679398 for (int i = 0 ; i < 24 ; i++) dst.qs [i] = 0 ;
93689399 for (int i = 0 ; i < 64 ; i++) {
93699400 int idx = turbo_nearest_centroid_m<8 >(rotated[i], TURBO_CENTROIDS_3BIT_M );
9401+ recon_sq += TURBO_CENTROIDS_3BIT_M [idx] * TURBO_CENTROIDS_3BIT_M [idx];
93709402 turbo_pack_bits (dst.qs , i * 3 , 3 , idx);
93719403 }
9404+ float recon_norm = sqrt (recon_sq);
9405+ if (recon_norm > 1e-10f ) {
9406+ dst.norm = half (norm / recon_norm);
9407+ }
93729408}
93739409
93749410void quantize_rq5_1 (device const float * src, device block_rq5_1 & dst) {
@@ -9377,6 +9413,7 @@ void quantize_rq5_1(device const float * src, device block_rq5_1 & dst) {
93779413 float norm = sqrt (sum2 + 1e-12f );
93789414 dst.norm = half (norm);
93799415 float inv_norm = 1 .0f / norm;
9416+ float recon_sq = 0 .0f ;
93809417 float u[64 ];
93819418 for (int i = 0 ; i < 64 ; i++) u[i] = src[i] * inv_norm;
93829419 float rotated[64 ];
@@ -9389,8 +9426,13 @@ void quantize_rq5_1(device const float * src, device block_rq5_1 & dst) {
93899426 for (int i = 0 ; i < 32 ; i++) dst.qs [i] = 0 ;
93909427 for (int i = 0 ; i < 64 ; i++) {
93919428 int idx = turbo_nearest_centroid_m<16 >(rotated[i], TURBO_CENTROIDS_4BIT_M );
9429+ recon_sq += TURBO_CENTROIDS_4BIT_M [idx] * TURBO_CENTROIDS_4BIT_M [idx];
93929430 turbo_pack_bits (dst.qs , i * 4 , 4 , idx);
93939431 }
9432+ float recon_norm = sqrt (recon_sq);
9433+ if (recon_norm > 1e-10f ) {
9434+ dst.norm = half (norm / recon_norm);
9435+ }
93949436}
93959437
93969438void quantize_rq6_1 (device const float * src, device block_rq6_1 & dst) {
@@ -9399,6 +9441,7 @@ void quantize_rq6_1(device const float * src, device block_rq6_1 & dst) {
93999441 float norm = sqrt (sum2 + 1e-12f );
94009442 dst.norm = half (norm);
94019443 float inv_norm = 1 .0f / norm;
9444+ float recon_sq = 0 .0f ;
94029445 float u[64 ];
94039446 for (int i = 0 ; i < 64 ; i++) u[i] = src[i] * inv_norm;
94049447 float rotated[64 ];
@@ -9411,8 +9454,13 @@ void quantize_rq6_1(device const float * src, device block_rq6_1 & dst) {
94119454 for (int i = 0 ; i < 40 ; i++) dst.qs [i] = 0 ;
94129455 for (int i = 0 ; i < 64 ; i++) {
94139456 int idx = turbo_nearest_centroid_m<32 >(rotated[i], TURBO_CENTROIDS_5BIT_M );
9457+ recon_sq += TURBO_CENTROIDS_5BIT_M [idx] * TURBO_CENTROIDS_5BIT_M [idx];
94149458 turbo_pack_bits (dst.qs , i * 5 , 5 , idx);
94159459 }
9460+ float recon_norm = sqrt (recon_sq);
9461+ if (recon_norm > 1e-10f ) {
9462+ dst.norm = half (norm / recon_norm);
9463+ }
94169464}
94179465
94189466template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
0 commit comments