Skip to content

Commit f73287f

Browse files
committed
quant-wht: add IQ support and configurable skip list
Extend GGUF-level quant_wht to IQ tensor types and make WHT skipping metadata-driven. --quant-wht now uses the default low-precision skip list, --quant-wht-full rotates all eligible tensors, and --quant-wht-skip-type allows overriding skipped GGML tensor types. Persist general.quant_wht.skip_types in GGUF and make model loading/decode honor that list; missing skip_types keeps full WHT behavior for compatibility.
1 parent 36ef042 commit f73287f

8 files changed

Lines changed: 339 additions & 37 deletions

File tree

ggml/src/ggml-cuda/quantize.cu

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,16 @@ static bool ggml_cuda_quant_wht_type_supported(const ggml_type type) {
5656
type == GGML_TYPE_Q4_K ||
5757
type == GGML_TYPE_Q5_K ||
5858
type == GGML_TYPE_Q6_K ||
59-
type == GGML_TYPE_Q8_0;
59+
type == GGML_TYPE_Q8_0 ||
60+
type == GGML_TYPE_IQ1_S ||
61+
type == GGML_TYPE_IQ1_M ||
62+
type == GGML_TYPE_IQ2_XXS ||
63+
type == GGML_TYPE_IQ2_XS ||
64+
type == GGML_TYPE_IQ2_S ||
65+
type == GGML_TYPE_IQ3_XXS ||
66+
type == GGML_TYPE_IQ3_S ||
67+
type == GGML_TYPE_IQ4_NL ||
68+
type == GGML_TYPE_IQ4_XS;
6069
}
6170

6271
static void ggml_cuda_quant_wht_log_once(const ggml_type type, const char * path) {

include/llama.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ extern "C" {
411411
bool pure; // quantize all tensors to the default type
412412
bool keep_split; // quantize to the same number of shards
413413
bool dry_run; // calculate and show the final quantization size without performing quantization
414-
bool quant_wht; // store eligible Q_K tensors in WHT-rotated domain
414+
bool quant_wht; // store eligible Q_K/Q8_0/IQ tensors in WHT-rotated domain
415+
bool quant_wht_full; // rotate every eligible tensor instead of using the skip list
416+
const char * quant_wht_skip_types; // comma-separated GGML tensor types to leave unrotated when quant_wht_full is false
415417
uint32_t quant_wht_dim; // WHT dimension, currently only 256 is supported
416418
const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data
417419
const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ struct llama_hparams {
4040
bool use_par_res;
4141
bool swin_norm;
4242
bool quant_wht_enabled = false;
43+
char quant_wht_skip_types[512] = {};
4344

4445
uint32_t n_ctx_train; // context size the model was trained on
4546
uint32_t n_embd;

src/llama-model-loader.cpp

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <algorithm>
99
#include <array>
10+
#include <cctype>
1011
#include <cinttypes>
1112
#include <cstdint>
1213
#include <cstdlib>
@@ -1061,7 +1062,82 @@ static bool llama_model_quant_wht_type_supported(ggml_type type) {
10611062
type == GGML_TYPE_Q4_K ||
10621063
type == GGML_TYPE_Q5_K ||
10631064
type == GGML_TYPE_Q6_K ||
1064-
type == GGML_TYPE_Q8_0;
1065+
type == GGML_TYPE_Q8_0 ||
1066+
type == GGML_TYPE_IQ1_S ||
1067+
type == GGML_TYPE_IQ1_M ||
1068+
type == GGML_TYPE_IQ2_XXS ||
1069+
type == GGML_TYPE_IQ2_XS ||
1070+
type == GGML_TYPE_IQ2_S ||
1071+
type == GGML_TYPE_IQ3_XXS ||
1072+
type == GGML_TYPE_IQ3_S ||
1073+
type == GGML_TYPE_IQ4_NL ||
1074+
type == GGML_TYPE_IQ4_XS;
1075+
}
1076+
1077+
static const char * llama_model_quant_wht_type_name(ggml_type type) {
1078+
switch (type) {
1079+
case GGML_TYPE_Q2_K: return "Q2_K";
1080+
case GGML_TYPE_Q3_K: return "Q3_K";
1081+
case GGML_TYPE_Q4_K: return "Q4_K";
1082+
case GGML_TYPE_Q5_K: return "Q5_K";
1083+
case GGML_TYPE_Q6_K: return "Q6_K";
1084+
case GGML_TYPE_Q8_0: return "Q8_0";
1085+
case GGML_TYPE_IQ1_S: return "IQ1_S";
1086+
case GGML_TYPE_IQ1_M: return "IQ1_M";
1087+
case GGML_TYPE_IQ2_XXS: return "IQ2_XXS";
1088+
case GGML_TYPE_IQ2_XS: return "IQ2_XS";
1089+
case GGML_TYPE_IQ2_S: return "IQ2_S";
1090+
case GGML_TYPE_IQ3_XXS: return "IQ3_XXS";
1091+
case GGML_TYPE_IQ3_S: return "IQ3_S";
1092+
case GGML_TYPE_IQ4_NL: return "IQ4_NL";
1093+
case GGML_TYPE_IQ4_XS: return "IQ4_XS";
1094+
default: return nullptr;
1095+
}
1096+
}
1097+
1098+
static std::string llama_model_quant_wht_normalize_type_token(std::string token) {
1099+
token.erase(std::remove_if(token.begin(), token.end(), [](unsigned char c) { return std::isspace(c) != 0; }), token.end());
1100+
std::transform(token.begin(), token.end(), token.begin(), [](unsigned char c) { return (char) std::toupper(c); });
1101+
return token;
1102+
}
1103+
1104+
static ggml_type llama_model_quant_wht_parse_type_token(const std::string & token) {
1105+
const std::string name = llama_model_quant_wht_normalize_type_token(token);
1106+
for (int i = 0; i < GGML_TYPE_COUNT; ++i) {
1107+
const ggml_type type = (ggml_type) i;
1108+
const char * type_name = llama_model_quant_wht_type_name(type);
1109+
if (type_name != nullptr && name == type_name) {
1110+
return type;
1111+
}
1112+
}
1113+
return GGML_TYPE_COUNT;
1114+
}
1115+
1116+
static bool llama_model_quant_wht_skip_list_has(const std::string & skip_types, ggml_type type) {
1117+
size_t start = 0;
1118+
while (start <= skip_types.size()) {
1119+
const size_t end = skip_types.find(',', start);
1120+
const std::string token = skip_types.substr(start, end == std::string::npos ? std::string::npos : end - start);
1121+
if (!llama_model_quant_wht_normalize_type_token(token).empty()) {
1122+
const ggml_type parsed = llama_model_quant_wht_parse_type_token(token);
1123+
if (parsed == GGML_TYPE_COUNT || !llama_model_quant_wht_type_supported(parsed)) {
1124+
throw std::runtime_error(format("unsupported general.quant_wht.skip_types entry: %s", token.c_str()));
1125+
}
1126+
if (parsed == type) {
1127+
return true;
1128+
}
1129+
}
1130+
if (end == std::string::npos) {
1131+
break;
1132+
}
1133+
start = end + 1;
1134+
}
1135+
return false;
1136+
}
1137+
1138+
static bool llama_model_quant_wht_type_enabled(ggml_type type, const std::string & skip_types) {
1139+
return llama_model_quant_wht_type_supported(type) &&
1140+
!llama_model_quant_wht_skip_list_has(skip_types, type);
10651141
}
10661142

10671143
static bool llama_model_quant_wht_backend_supported(ggml_backend_dev_t dev) {
@@ -1174,12 +1250,12 @@ struct ggml_tensor * llama_model_loader::create_tensor(
11741250
const bool quant_wht_tensor =
11751251
hparams.quant_wht_enabled &&
11761252
(op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) &&
1177-
llama_model_quant_wht_type_supported(t_meta->type) &&
1253+
llama_model_quant_wht_type_enabled(t_meta->type, hparams.quant_wht_skip_types) &&
11781254
llama_model_quant_wht_name_supported(tn);
11791255

11801256
if (hparams.quant_wht_enabled &&
11811257
(op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) &&
1182-
llama_model_quant_wht_type_supported(t_meta->type) &&
1258+
llama_model_quant_wht_type_enabled(t_meta->type, hparams.quant_wht_skip_types) &&
11831259
llama_model_quant_wht_name_supported(tn) &&
11841260
t_meta->ne[0] % 256 != 0) {
11851261
throw std::runtime_error(format("general.quant_wht tensor %s has unsupported reduction dimension %" PRId64,
@@ -1191,8 +1267,9 @@ struct ggml_tensor * llama_model_loader::create_tensor(
11911267
if (getenv("GGML_CUDA_LOG_QUANT_WHT") != nullptr) {
11921268
static int n_logged = 0;
11931269
if (n_logged < 8) {
1194-
LLAMA_LOG_INFO("%s: quant_wht tensor flagged: %s type=%s dim=%" PRId64 "\n",
1195-
__func__, tn.str().c_str(), ggml_type_name(t_meta->type), t_meta->ne[0]);
1270+
LLAMA_LOG_INFO("%s: quant_wht tensor flagged: %s type=%s dim=%" PRId64 " skip_types=%s\n",
1271+
__func__, tn.str().c_str(), ggml_type_name(t_meta->type), t_meta->ne[0],
1272+
hparams.quant_wht_skip_types[0] == '\0' ? "<none>" : hparams.quant_wht_skip_types);
11961273
++n_logged;
11971274
}
11981275
}

src/llama-model.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <algorithm>
2323
#include <cassert>
2424
#include <cfloat>
25+
#include <cstdio>
2526
#include <cstdint>
2627
#include <cstring>
2728
#include <cmath>
@@ -731,14 +732,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
731732
if (!ml.get_key("general.quant_wht.version", quant_wht_version, false)) {
732733
throw std::runtime_error("general.quant_wht.enabled=true but general.quant_wht.version is missing");
733734
}
735+
quant_wht_skip_types.clear();
736+
ml.get_key("general.quant_wht.skip_types", quant_wht_skip_types, false);
734737
if (quant_wht_dim != 256 || quant_wht_scheme != "pqk_rht_v1" || quant_wht_version != 1) {
735738
throw std::runtime_error(format("unsupported general.quant_wht metadata: dim=%u scheme=%s version=%u",
736739
quant_wht_dim, quant_wht_scheme.c_str(), quant_wht_version));
737740
}
738741
hparams.quant_wht_dim = quant_wht_dim;
739742
hparams.quant_wht_version = quant_wht_version;
740-
LLAMA_LOG_WARN("%s: WARNING: experimental WHT-rotated Q_K GGUF detected (dim=%u, scheme=%s, version=%u)\n",
741-
__func__, quant_wht_dim, quant_wht_scheme.c_str(), quant_wht_version);
743+
if (quant_wht_skip_types.size() >= sizeof(hparams.quant_wht_skip_types)) {
744+
throw std::runtime_error("general.quant_wht.skip_types is too long");
745+
}
746+
snprintf(hparams.quant_wht_skip_types, sizeof(hparams.quant_wht_skip_types), "%s", quant_wht_skip_types.c_str());
747+
LLAMA_LOG_WARN("%s: WARNING: experimental WHT-rotated Q_K/Q8_0/IQ GGUF detected (dim=%u, scheme=%s, version=%u, skip_types=%s)\n",
748+
__func__, quant_wht_dim, quant_wht_scheme.c_str(), quant_wht_version,
749+
quant_wht_skip_types.empty() ? "<none>" : quant_wht_skip_types.c_str());
742750
}
743751

744752
// everything past this point is not vocab-related
@@ -8185,6 +8193,7 @@ void llama_model::print_info() const {
81858193
if (quant_wht_enabled) {
81868194
LLAMA_LOG_INFO("%s: quant_wht_dim = %u\n", __func__, quant_wht_dim);
81878195
LLAMA_LOG_INFO("%s: quant_wht_scheme = %s\n", __func__, quant_wht_scheme.c_str());
8196+
LLAMA_LOG_INFO("%s: quant_wht_skip_types = %s\n", __func__, quant_wht_skip_types.empty() ? "<none>" : quant_wht_skip_types.c_str());
81888197
}
81898198

81908199
if (!hparams.vocab_only) {

src/llama-model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ struct llama_model {
566566
uint32_t quant_wht_dim = 0;
567567
uint32_t quant_wht_version = 0;
568568
std::string quant_wht_scheme;
569+
std::string quant_wht_skip_types;
569570

570571
// list of devices used in this model
571572
std::vector<llama_device> devices;

0 commit comments

Comments
 (0)