Skip to content

Commit 06938ac

Browse files
authored
tests : add support for qwen3 SSM archs (ggml-org#24031)
* tests : add support for qwen3 SSM archs
* arch : add LLM_KV_ATTENTION_RECURRENT_LAYERS
* cont : naming + TODOs
1 parent d545a2a commit 06938ac

25 files changed

Lines changed: 109 additions & 83 deletions

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
247247
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
248248
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
249249
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
250+
{ LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" },
250251

251252
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
252253
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,7 @@ enum llm_kv {
251251
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
252252
LLM_KV_ATTENTION_INDEXER_TOP_K,
253253
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
254+
LLM_KV_ATTENTION_RECURRENT_LAYERS,
254255

255256
LLM_KV_ROPE_DIMENSION_COUNT,
256257
LLM_KV_ROPE_DIMENSION_COUNT_SWA,

src/llama-hparams.cpp

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,31 @@
88
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
99
if (dense_first) {
1010
for (uint32_t il = 0; il < n_layer; ++il) {
11-
swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
11+
is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0);
1212
}
1313
} else {
1414
for (uint32_t il = 0; il < n_layer; ++il) {
15-
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
15+
is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
1616
}
1717
}
1818
}
1919

20+
// TODO: implement
21+
//void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) {
22+
// if (dense_first) {
23+
// for (uint32_t il = 0; il < n_layer; ++il) {
24+
// is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0);
25+
// }
26+
// } else {
27+
// for (uint32_t il = 0; il < n_layer; ++il) {
28+
// is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
29+
// }
30+
// }
31+
//}
32+
2033
bool llama_hparams::is_swa_any() const {
2134
for (uint32_t il = 0; il < n_layer; ++il) {
22-
if (swa_layers[il]) {
35+
if (is_swa_impl[il]) {
2336
return true;
2437
}
2538
}
@@ -193,9 +206,9 @@ uint32_t llama_hparams::n_embd_s() const {
193206
return ssm_d_state * ssm_d_inner;
194207
}
195208

196-
bool llama_hparams::is_recurrent(uint32_t il) const {
209+
bool llama_hparams::is_recr(uint32_t il) const {
197210
if (il < n_layer) {
198-
return recurrent_layer_arr[il];
211+
return is_recr_impl[il];
199212
}
200213

201214
GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
@@ -207,7 +220,7 @@ uint32_t llama_hparams::n_pos_per_embd() const {
207220

208221
bool llama_hparams::is_swa(uint32_t il) const {
209222
if (il < n_layer) {
210-
return swa_layers[il];
223+
return is_swa_impl[il];
211224
}
212225

213226
GGML_ABORT("fatal error");

src/llama-hparams.h

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,9 @@ struct llama_hparams_convnext {
3737
};
3838

3939
struct llama_hparams {
40+
// note: use the `_impl` suffix to avoid name conflict between members and getters
41+
// for example: n_embd_out() vs n_embd_out_impl
42+
4043
bool vocab_only;
4144
bool no_alloc;
4245
bool rope_finetuned;
@@ -46,7 +49,7 @@ struct llama_hparams {
4649
uint32_t n_ctx_train; // context size the model was trained on
4750
uint32_t n_embd;
4851
uint32_t n_layer;
49-
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
52+
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
5053
uint32_t n_expert = 0;
5154
uint32_t n_expert_used = 0;
5255
uint32_t n_rel_attn_bkts = 0;
@@ -137,11 +140,15 @@ struct llama_hparams {
137140
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
138141
// the size of the sliding window (0 - no SWA)
139142
uint32_t n_swa = 0;
140-
// if swa_layers[il] == 1, then layer il is SWA
141-
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
143+
144+
// if is_swa_impl[il] == 1, then layer il is SWA
145+
// if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA)
142146
// by default, all layers are dense
143147
// note: using uint32_t type for compatibility reason
144-
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
148+
std::array<uint32_t, LLAMA_MAX_LAYERS> is_swa_impl;
149+
150+
// for hybrid state space models
151+
std::array<uint32_t, LLAMA_MAX_LAYERS> is_recr_impl;
145152

146153
// for State Space Models
147154
uint32_t ssm_d_conv = 0;
@@ -153,9 +160,6 @@ struct llama_hparams {
153160
// for Kimi Linear KDA
154161
uint32_t n_embd_head_kda = 0;
155162

156-
// for hybrid state space models
157-
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
158-
159163
bool ssm_dt_b_c_rms = false;
160164

161165
float f_clamp_kqv = 0.0f;
@@ -266,6 +270,14 @@ struct llama_hparams {
266270
// return true if one of the layers is SWA
267271
bool is_swa_any() const;
268272

273+
bool is_swa(uint32_t il) const;
274+
275+
// TODO: implement
276+
//void set_recr_pattern(uint32_t n_pattern, bool dense_first = false);
277+
278+
// whether or not the given layer is recurrent (for hybrid models)
279+
bool is_recr(uint32_t il) const;
280+
269281
uint32_t n_head(uint32_t il = 0) const;
270282

271283
uint32_t n_head_kv(uint32_t il = 0) const;
@@ -307,13 +319,8 @@ struct llama_hparams {
307319
// dimension of the recurrent state embeddings
308320
uint32_t n_embd_s() const;
309321

310-
// whether or not the given layer is recurrent (for hybrid models)
311-
bool is_recurrent(uint32_t il) const;
312-
313322
uint32_t n_pos_per_embd() const;
314323

315-
bool is_swa(uint32_t il) const;
316-
317324
// note: currently only support if either all or none of the layers are MLA
318325
bool is_mla() const;
319326

src/llama-memory-hybrid-iswa.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
4444
n_ubatch,
4545
n_pad,
4646
filter_attn == nullptr ?
47-
[&](int32_t il) { return !hparams.is_recurrent(il); }
47+
[&](int32_t il) { return !hparams.is_recr(il); }
4848
: filter_attn,
4949
nullptr
5050
)),
@@ -57,7 +57,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
5757
n_seq_max,
5858
n_rs_seq,
5959
filter_recr == nullptr ?
60-
[&](int32_t il) { return hparams.is_recurrent(il); }
60+
[&](int32_t il) { return hparams.is_recr(il); }
6161
: filter_recr
6262
)) {}
6363

src/llama-memory-hybrid.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ llama_memory_hybrid::llama_memory_hybrid(
4545
n_swa,
4646
swa_type,
4747
filter_attn == nullptr ?
48-
[&](int32_t il) { return !hparams.is_recurrent(il); }
48+
[&](int32_t il) { return !hparams.is_recr(il); }
4949
: filter_attn,
5050
nullptr
5151
)),
@@ -58,7 +58,7 @@ llama_memory_hybrid::llama_memory_hybrid(
5858
n_seq_max,
5959
n_rs_seq,
6060
filter_recr == nullptr ?
61-
[&](int32_t il) { return hparams.is_recurrent(il); }
61+
[&](int32_t il) { return hparams.is_recr(il); }
6262
: filter_recr
6363
)) {}
6464

src/llama-model-loader.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ namespace GGUFMeta {
146146
const enum gguf_type arr_type = gguf_get_arr_type(ctx, k);
147147
return ArrayInfo {
148148
arr_type,
149-
size_t(gguf_get_arr_n(ctx, k)),
149+
gguf_get_arr_n(ctx, k),
150150
arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k),
151151
};
152152
}
@@ -445,7 +445,7 @@ namespace GGUFMeta {
445445
}
446446

447447
if (n > N_MAX) {
448-
throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
448+
throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str()));
449449
}
450450

451451
if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) {
@@ -502,9 +502,9 @@ namespace GGUFMeta {
502502
}
503503

504504
// TODO: this is not very clever - figure out something better
505-
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
505+
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>> (enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
506506
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
507-
template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required);
507+
template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required);
508508

509509

510510
llama_model_loader::llama_model_loader(

src/llama-model-saver.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,6 @@
1414

1515
bool llama_model_saver_supports_arch(llm_arch arch) {
1616
switch (arch) {
17-
case LLM_ARCH_QWEN3NEXT:
18-
case LLM_ARCH_QWEN35:
19-
case LLM_ARCH_QWEN35MOE:
2017
case LLM_ARCH_PLAMO3:
2118
case LLM_ARCH_GEMMA3:
2219
case LLM_ARCH_GEMMA3N:
@@ -107,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c
107104
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
108105
} else if (std::is_same<typename Container::value_type, uint32_t>::value) {
109106
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
107+
} else if (std::is_same<typename Container::value_type, bool>::value) {
108+
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values);
110109
} else if (std::is_same<typename Container::value_type, int32_t>::value) {
111110
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
112111
} else if (std::is_same<typename Container::value_type, float>::value) {
@@ -245,7 +244,7 @@ void llama_model_saver::add_kv_from_model() {
245244
add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
246245
add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count);
247246
add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
248-
// add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???);
247+
// add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead
249248

250249
add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
251250
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
@@ -279,6 +278,7 @@ void llama_model_saver::add_kv_from_model() {
279278
add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head);
280279
add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size);
281280
add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k);
281+
add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true);
282282

283283
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
284284

src/llama-model.cpp

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -373,10 +373,10 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
373373
// count only the same type of previous layers to avoid this
374374
auto get_il_eff = [&](const size_t il){
375375
size_t ret = 0;
376-
const bool il_is_recurrent = hparams.is_recurrent(il);
377-
const bool il_is_swa = hparams.is_swa(il);
376+
const bool il_is_recr = hparams.is_recr(il);
377+
const bool il_is_swa = hparams.is_swa(il);
378378
for (size_t il_prev = 0; il_prev < il; il_prev++) {
379-
ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa;
379+
ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa;
380380
}
381381
return ret;
382382
};
@@ -553,7 +553,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
553553
};
554554

555555
auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> {
556-
if (hparams.is_recurrent(il)) {
556+
if (hparams.is_recr(il)) {
557557
// linear attention
558558
const int64_t head_dim = hparams.ssm_d_state;
559559
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
@@ -1076,18 +1076,16 @@ void llama_model_base::load_hparams(llama_model_loader & ml) {
10761076
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
10771077
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
10781078
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
1079-
std::fill(
1080-
hparams.recurrent_layer_arr.begin(),
1081-
hparams.recurrent_layer_arr.end(),
1082-
llm_arch_is_recurrent(ml.get_arch()));
10831079

10841080
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
1085-
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
1081+
std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0);
1082+
std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0);
10861083

10871084
std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f);
10881085
std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f);
1089-
std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f);
1090-
std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f);
1086+
std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f);
1087+
std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f);
1088+
10911089
std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f);
10921090
std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f);
10931091

@@ -2040,18 +2038,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
20402038
filter_recr = [&](int32_t) { return true; };
20412039
} else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) {
20422040
filter_attn = [&](int32_t il) {
2043-
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
2041+
return !hparams.is_recr(il) && hparams.n_ff(il) == 0;
20442042
};
20452043
filter_recr = [&](int32_t il) {
2046-
return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
2044+
return hparams.is_recr(il) && hparams.n_ff(il) == 0;
20472045
};
20482046
} else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) {
20492047
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
20502048
filter_attn = [&, n_main](int32_t il) {
2051-
return (uint32_t)il < n_main && !hparams.is_recurrent(il);
2049+
return (uint32_t)il < n_main && !hparams.is_recr(il);
20522050
};
20532051
filter_recr = [&, n_main](int32_t il) {
2054-
return (uint32_t)il < n_main && hparams.is_recurrent(il);
2052+
return (uint32_t)il < n_main && hparams.is_recr(il);
20552053
};
20562054
}
20572055

src/models/falcon-h1.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) {
1111
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
1212
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
1313

14-
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
14+
std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true);
1515

1616
switch (hparams.n_layer) {
1717
case 36:

0 commit comments

Comments
 (0)