77
88#include < algorithm>
99#include < array>
10+ #include < cctype>
1011#include < cinttypes>
1112#include < cstdint>
1213#include < cstdlib>
@@ -1061,7 +1062,82 @@ static bool llama_model_quant_wht_type_supported(ggml_type type) {
10611062 type == GGML_TYPE_Q4_K ||
10621063 type == GGML_TYPE_Q5_K ||
10631064 type == GGML_TYPE_Q6_K ||
1064- type == GGML_TYPE_Q8_0;
1065+ type == GGML_TYPE_Q8_0 ||
1066+ type == GGML_TYPE_IQ1_S ||
1067+ type == GGML_TYPE_IQ1_M ||
1068+ type == GGML_TYPE_IQ2_XXS ||
1069+ type == GGML_TYPE_IQ2_XS ||
1070+ type == GGML_TYPE_IQ2_S ||
1071+ type == GGML_TYPE_IQ3_XXS ||
1072+ type == GGML_TYPE_IQ3_S ||
1073+ type == GGML_TYPE_IQ4_NL ||
1074+ type == GGML_TYPE_IQ4_XS;
1075+ }
1076+
1077+ static const char * llama_model_quant_wht_type_name (ggml_type type) {
1078+ switch (type) {
1079+ case GGML_TYPE_Q2_K: return " Q2_K" ;
1080+ case GGML_TYPE_Q3_K: return " Q3_K" ;
1081+ case GGML_TYPE_Q4_K: return " Q4_K" ;
1082+ case GGML_TYPE_Q5_K: return " Q5_K" ;
1083+ case GGML_TYPE_Q6_K: return " Q6_K" ;
1084+ case GGML_TYPE_Q8_0: return " Q8_0" ;
1085+ case GGML_TYPE_IQ1_S: return " IQ1_S" ;
1086+ case GGML_TYPE_IQ1_M: return " IQ1_M" ;
1087+ case GGML_TYPE_IQ2_XXS: return " IQ2_XXS" ;
1088+ case GGML_TYPE_IQ2_XS: return " IQ2_XS" ;
1089+ case GGML_TYPE_IQ2_S: return " IQ2_S" ;
1090+ case GGML_TYPE_IQ3_XXS: return " IQ3_XXS" ;
1091+ case GGML_TYPE_IQ3_S: return " IQ3_S" ;
1092+ case GGML_TYPE_IQ4_NL: return " IQ4_NL" ;
1093+ case GGML_TYPE_IQ4_XS: return " IQ4_XS" ;
1094+ default : return nullptr ;
1095+ }
1096+ }
1097+
// Canonicalizes one skip-list token: removes every whitespace character
// (anywhere in the token, not only at the ends) and upper-cases the rest.
//
// The argument is taken by value and edited in place, then returned.
// Characters are passed to <cctype> as unsigned char, because std::isspace and
// std::toupper have undefined behavior for negative values other than EOF.
static std::string llama_model_quant_wht_normalize_type_token(std::string token) {
    token.erase(
        std::remove_if(token.begin(), token.end(),
                       [](unsigned char c) { return std::isspace(c) != 0; }),
        token.end());
    std::transform(token.begin(), token.end(), token.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return token;
}
1103+
1104+ static ggml_type llama_model_quant_wht_parse_type_token (const std::string & token) {
1105+ const std::string name = llama_model_quant_wht_normalize_type_token (token);
1106+ for (int i = 0 ; i < GGML_TYPE_COUNT; ++i) {
1107+ const ggml_type type = (ggml_type) i;
1108+ const char * type_name = llama_model_quant_wht_type_name (type);
1109+ if (type_name != nullptr && name == type_name) {
1110+ return type;
1111+ }
1112+ }
1113+ return GGML_TYPE_COUNT;
1114+ }
1115+
1116+ static bool llama_model_quant_wht_skip_list_has (const std::string & skip_types, ggml_type type) {
1117+ size_t start = 0 ;
1118+ while (start <= skip_types.size ()) {
1119+ const size_t end = skip_types.find (' ,' , start);
1120+ const std::string token = skip_types.substr (start, end == std::string::npos ? std::string::npos : end - start);
1121+ if (!llama_model_quant_wht_normalize_type_token (token).empty ()) {
1122+ const ggml_type parsed = llama_model_quant_wht_parse_type_token (token);
1123+ if (parsed == GGML_TYPE_COUNT || !llama_model_quant_wht_type_supported (parsed)) {
1124+ throw std::runtime_error (format (" unsupported general.quant_wht.skip_types entry: %s" , token.c_str ()));
1125+ }
1126+ if (parsed == type) {
1127+ return true ;
1128+ }
1129+ }
1130+ if (end == std::string::npos) {
1131+ break ;
1132+ }
1133+ start = end + 1 ;
1134+ }
1135+ return false ;
1136+ }
1137+
1138+ static bool llama_model_quant_wht_type_enabled (ggml_type type, const std::string & skip_types) {
1139+ return llama_model_quant_wht_type_supported (type) &&
1140+ !llama_model_quant_wht_skip_list_has (skip_types, type);
10651141}
10661142
10671143static bool llama_model_quant_wht_backend_supported (ggml_backend_dev_t dev) {
@@ -1174,12 +1250,12 @@ struct ggml_tensor * llama_model_loader::create_tensor(
11741250 const bool quant_wht_tensor =
11751251 hparams.quant_wht_enabled &&
11761252 (op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) &&
1177- llama_model_quant_wht_type_supported (t_meta->type ) &&
1253+ llama_model_quant_wht_type_enabled (t_meta->type , hparams. quant_wht_skip_types ) &&
11781254 llama_model_quant_wht_name_supported (tn);
11791255
11801256 if (hparams.quant_wht_enabled &&
11811257 (op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) &&
1182- llama_model_quant_wht_type_supported (t_meta->type ) &&
1258+ llama_model_quant_wht_type_enabled (t_meta->type , hparams. quant_wht_skip_types ) &&
11831259 llama_model_quant_wht_name_supported (tn) &&
11841260 t_meta->ne [0 ] % 256 != 0 ) {
11851261 throw std::runtime_error (format (" general.quant_wht tensor %s has unsupported reduction dimension %" PRId64,
@@ -1191,8 +1267,9 @@ struct ggml_tensor * llama_model_loader::create_tensor(
11911267 if (getenv (" GGML_CUDA_LOG_QUANT_WHT" ) != nullptr ) {
11921268 static int n_logged = 0 ;
11931269 if (n_logged < 8 ) {
1194- LLAMA_LOG_INFO (" %s: quant_wht tensor flagged: %s type=%s dim=%" PRId64 " \n " ,
1195- __func__, tn.str ().c_str (), ggml_type_name (t_meta->type ), t_meta->ne [0 ]);
1270+ LLAMA_LOG_INFO (" %s: quant_wht tensor flagged: %s type=%s dim=%" PRId64 " skip_types=%s\n " ,
1271+ __func__, tn.str ().c_str (), ggml_type_name (t_meta->type ), t_meta->ne [0 ],
1272+ hparams.quant_wht_skip_types [0 ] == ' \0 ' ? " <none>" : hparams.quant_wht_skip_types );
11961273 ++n_logged;
11971274 }
11981275 }
0 commit comments