Skip to content

Commit 50923d8

Browse files
committed
Fix Metal norm correction parity + add dk64 guards (Codex audit)
Metal quantize functions (quantize_turbo*_1, quantize_rq*_1) now apply the same norm correction as CPU: store original_norm/reconstruction_norm. Previously only CPU had norm correction, causing CPU/GPU mismatch. Added dk64 guards in llama-context.cpp: turbo/rq types now fail init with clear error if n_embd_head_k != 64 or n_embd_head_v != 64. Prevents silent misuse on unsupported head dimensions (e.g. dk128).
1 parent f1748ac commit 50923d8

2 files changed

Lines changed: 74 additions & 0 deletions

File tree

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9275,12 +9275,18 @@ void quantize_turbo3_1(device const float * src, device block_turbo3_1 & dst) {
92759275
float norm = sqrt(sum2 + 1e-12f);
92769276
dst.norm = half(norm);
92779277
float inv_norm = 1.0f / norm;
9278+
float recon_sq = 0.0f;
92789279
for (int i = 0; i < 16; i++) dst.qs[i] = 0;
92799280
for (int i = 0; i < 64; i++) {
92809281
float val = src[i] * inv_norm;
92819282
int idx = turbo_nearest_centroid_m<4>(val, TURBO_CENTROIDS_2BIT_M);
9283+
recon_sq += TURBO_CENTROIDS_2BIT_M[idx] * TURBO_CENTROIDS_2BIT_M[idx];
92829284
turbo_pack_bits(dst.qs, i * 2, 2, idx);
92839285
}
9286+
float recon_norm = sqrt(recon_sq);
9287+
if (recon_norm > 1e-10f) {
9288+
dst.norm = half(norm / recon_norm);
9289+
}
92849290
}
92859291

92869292
void quantize_turbo4_1(device const float * src, device block_turbo4_1 & dst) {
@@ -9289,12 +9295,18 @@ void quantize_turbo4_1(device const float * src, device block_turbo4_1 & dst) {
92899295
float norm = sqrt(sum2 + 1e-12f);
92909296
dst.norm = half(norm);
92919297
float inv_norm = 1.0f / norm;
9298+
float recon_sq = 0.0f;
92929299
for (int i = 0; i < 24; i++) dst.qs[i] = 0;
92939300
for (int i = 0; i < 64; i++) {
92949301
float val = src[i] * inv_norm;
92959302
int idx = turbo_nearest_centroid_m<8>(val, TURBO_CENTROIDS_3BIT_M);
9303+
recon_sq += TURBO_CENTROIDS_3BIT_M[idx] * TURBO_CENTROIDS_3BIT_M[idx];
92969304
turbo_pack_bits(dst.qs, i * 3, 3, idx);
92979305
}
9306+
float recon_norm = sqrt(recon_sq);
9307+
if (recon_norm > 1e-10f) {
9308+
dst.norm = half(norm / recon_norm);
9309+
}
92989310
}
92999311

93009312
void quantize_turbo5_1(device const float * src, device block_turbo5_1 & dst) {
@@ -9303,12 +9315,18 @@ void quantize_turbo5_1(device const float * src, device block_turbo5_1 & dst) {
93039315
float norm = sqrt(sum2 + 1e-12f);
93049316
dst.norm = half(norm);
93059317
float inv_norm = 1.0f / norm;
9318+
float recon_sq = 0.0f;
93069319
for (int i = 0; i < 32; i++) dst.qs[i] = 0;
93079320
for (int i = 0; i < 64; i++) {
93089321
float val = src[i] * inv_norm;
93099322
int idx = turbo_nearest_centroid_m<16>(val, TURBO_CENTROIDS_4BIT_M);
9323+
recon_sq += TURBO_CENTROIDS_4BIT_M[idx] * TURBO_CENTROIDS_4BIT_M[idx];
93109324
turbo_pack_bits(dst.qs, i * 4, 4, idx);
93119325
}
9326+
float recon_norm = sqrt(recon_sq);
9327+
if (recon_norm > 1e-10f) {
9328+
dst.norm = half(norm / recon_norm);
9329+
}
93129330
}
93139331

93149332
void quantize_turbo6_1(device const float * src, device block_turbo6_1 & dst) {
@@ -9317,12 +9335,18 @@ void quantize_turbo6_1(device const float * src, device block_turbo6_1 & dst) {
93179335
float norm = sqrt(sum2 + 1e-12f);
93189336
dst.norm = half(norm);
93199337
float inv_norm = 1.0f / norm;
9338+
float recon_sq = 0.0f;
93209339
for (int i = 0; i < 40; i++) dst.qs[i] = 0;
93219340
for (int i = 0; i < 64; i++) {
93229341
float val = src[i] * inv_norm;
93239342
int idx = turbo_nearest_centroid_m<32>(val, TURBO_CENTROIDS_5BIT_M);
9343+
recon_sq += TURBO_CENTROIDS_5BIT_M[idx] * TURBO_CENTROIDS_5BIT_M[idx];
93249344
turbo_pack_bits(dst.qs, i * 5, 5, idx);
93259345
}
9346+
float recon_norm = sqrt(recon_sq);
9347+
if (recon_norm > 1e-10f) {
9348+
dst.norm = half(norm / recon_norm);
9349+
}
93269350
}
93279351

93289352
// RotorQuant GPU quantize functions (with Clifford rotor rotation matching CPU path)
@@ -9332,6 +9356,7 @@ void quantize_rq3_1(device const float * src, device block_rq3_1 & dst) {
93329356
float norm = sqrt(sum2 + 1e-12f);
93339357
dst.norm = half(norm);
93349358
float inv_norm = 1.0f / norm;
9359+
float recon_sq = 0.0f;
93359360
float u[64];
93369361
for (int i = 0; i < 64; i++) u[i] = src[i] * inv_norm;
93379362
// Apply forward rotor per group of 3
@@ -9345,8 +9370,13 @@ void quantize_rq3_1(device const float * src, device block_rq3_1 & dst) {
93459370
for (int i = 0; i < 16; i++) dst.qs[i] = 0;
93469371
for (int i = 0; i < 64; i++) {
93479372
int idx = turbo_nearest_centroid_m<4>(rotated[i], TURBO_CENTROIDS_2BIT_M);
9373+
recon_sq += TURBO_CENTROIDS_2BIT_M[idx] * TURBO_CENTROIDS_2BIT_M[idx];
93489374
turbo_pack_bits(dst.qs, i * 2, 2, idx);
93499375
}
9376+
float recon_norm = sqrt(recon_sq);
9377+
if (recon_norm > 1e-10f) {
9378+
dst.norm = half(norm / recon_norm);
9379+
}
93509380
}
93519381

93529382
void quantize_rq4_1(device const float * src, device block_rq4_1 & dst) {
@@ -9355,6 +9385,7 @@ void quantize_rq4_1(device const float * src, device block_rq4_1 & dst) {
93559385
float norm = sqrt(sum2 + 1e-12f);
93569386
dst.norm = half(norm);
93579387
float inv_norm = 1.0f / norm;
9388+
float recon_sq = 0.0f;
93589389
float u[64];
93599390
for (int i = 0; i < 64; i++) u[i] = src[i] * inv_norm;
93609391
float rotated[64];
@@ -9367,8 +9398,13 @@ void quantize_rq4_1(device const float * src, device block_rq4_1 & dst) {
93679398
for (int i = 0; i < 24; i++) dst.qs[i] = 0;
93689399
for (int i = 0; i < 64; i++) {
93699400
int idx = turbo_nearest_centroid_m<8>(rotated[i], TURBO_CENTROIDS_3BIT_M);
9401+
recon_sq += TURBO_CENTROIDS_3BIT_M[idx] * TURBO_CENTROIDS_3BIT_M[idx];
93709402
turbo_pack_bits(dst.qs, i * 3, 3, idx);
93719403
}
9404+
float recon_norm = sqrt(recon_sq);
9405+
if (recon_norm > 1e-10f) {
9406+
dst.norm = half(norm / recon_norm);
9407+
}
93729408
}
93739409

93749410
void quantize_rq5_1(device const float * src, device block_rq5_1 & dst) {
@@ -9377,6 +9413,7 @@ void quantize_rq5_1(device const float * src, device block_rq5_1 & dst) {
93779413
float norm = sqrt(sum2 + 1e-12f);
93789414
dst.norm = half(norm);
93799415
float inv_norm = 1.0f / norm;
9416+
float recon_sq = 0.0f;
93809417
float u[64];
93819418
for (int i = 0; i < 64; i++) u[i] = src[i] * inv_norm;
93829419
float rotated[64];
@@ -9389,8 +9426,13 @@ void quantize_rq5_1(device const float * src, device block_rq5_1 & dst) {
93899426
for (int i = 0; i < 32; i++) dst.qs[i] = 0;
93909427
for (int i = 0; i < 64; i++) {
93919428
int idx = turbo_nearest_centroid_m<16>(rotated[i], TURBO_CENTROIDS_4BIT_M);
9429+
recon_sq += TURBO_CENTROIDS_4BIT_M[idx] * TURBO_CENTROIDS_4BIT_M[idx];
93929430
turbo_pack_bits(dst.qs, i * 4, 4, idx);
93939431
}
9432+
float recon_norm = sqrt(recon_sq);
9433+
if (recon_norm > 1e-10f) {
9434+
dst.norm = half(norm / recon_norm);
9435+
}
93949436
}
93959437

93969438
void quantize_rq6_1(device const float * src, device block_rq6_1 & dst) {
@@ -9399,6 +9441,7 @@ void quantize_rq6_1(device const float * src, device block_rq6_1 & dst) {
93999441
float norm = sqrt(sum2 + 1e-12f);
94009442
dst.norm = half(norm);
94019443
float inv_norm = 1.0f / norm;
9444+
float recon_sq = 0.0f;
94029445
float u[64];
94039446
for (int i = 0; i < 64; i++) u[i] = src[i] * inv_norm;
94049447
float rotated[64];
@@ -9411,8 +9454,13 @@ void quantize_rq6_1(device const float * src, device block_rq6_1 & dst) {
94119454
for (int i = 0; i < 40; i++) dst.qs[i] = 0;
94129455
for (int i = 0; i < 64; i++) {
94139456
int idx = turbo_nearest_centroid_m<32>(rotated[i], TURBO_CENTROIDS_5BIT_M);
9457+
recon_sq += TURBO_CENTROIDS_5BIT_M[idx] * TURBO_CENTROIDS_5BIT_M[idx];
94149458
turbo_pack_bits(dst.qs, i * 5, 5, idx);
94159459
}
9460+
float recon_norm = sqrt(recon_sq);
9461+
if (recon_norm > 1e-10f) {
9462+
dst.norm = half(norm / recon_norm);
9463+
}
94169464
}
94179465

94189466
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>

src/llama-context.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,6 +2923,22 @@ llama_context_params llama_context_default_params() {
29232923
llama_context * llama_init_from_model(
29242924
llama_model * model,
29252925
llama_context_params params) {
2926+
auto is_turbo_or_rq_type = [](ggml_type type) {
2927+
switch (type) {
2928+
case GGML_TYPE_TURBO3_1:
2929+
case GGML_TYPE_TURBO4_1:
2930+
case GGML_TYPE_TURBO5_1:
2931+
case GGML_TYPE_TURBO6_1:
2932+
case GGML_TYPE_RQ3_1:
2933+
case GGML_TYPE_RQ4_1:
2934+
case GGML_TYPE_RQ5_1:
2935+
case GGML_TYPE_RQ6_1:
2936+
return true;
2937+
default:
2938+
return false;
2939+
}
2940+
};
2941+
29262942
if (!model) {
29272943
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
29282944
return nullptr;
@@ -2946,6 +2962,11 @@ llama_context * llama_init_from_model(
29462962
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
29472963
const uint32_t blck_size = ggml_blck_size(params.type_k);
29482964
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2965+
if (is_turbo_or_rq_type(params.type_k) && model->hparams.n_embd_head_k(il) != 64) {
2966+
LLAMA_LOG_ERROR("%s: K cache type %s currently supports only n_embd_head_k=64, got %u at layer %u\n",
2967+
__func__, ggml_type_name(params.type_k), model->hparams.n_embd_head_k(il), il);
2968+
return nullptr;
2969+
}
29492970
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
29502971
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
29512972
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
@@ -2957,6 +2978,11 @@ llama_context * llama_init_from_model(
29572978
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
29582979
const uint32_t blck_size = ggml_blck_size(params.type_v);
29592980
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2981+
if (is_turbo_or_rq_type(params.type_v) && model->hparams.n_embd_head_v(il) != 64) {
2982+
LLAMA_LOG_ERROR("%s: V cache type %s currently supports only n_embd_head_v=64, got %u at layer %u\n",
2983+
__func__, ggml_type_name(params.type_v), model->hparams.n_embd_head_v(il), il);
2984+
return nullptr;
2985+
}
29602986
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
29612987
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
29622988
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));

0 commit comments

Comments
 (0)