@@ -202,24 +202,37 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
202202 const int64_t n_v_heads = hparams.ssm_dt_rank;
203203 const int64_t key_dim = head_k_dim * n_k_heads;
204204 const int64_t value_dim = head_v_dim * n_v_heads;
205- const int64_t head_ratio = n_v_heads / n_k_heads;
206- if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
207- GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
208- return std::vector<int64_t>(2 + head_ratio, key_dim);
209- }
210- if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
211- return std::vector<int64_t>(head_ratio, key_dim);
212- }
213- if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
214- std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
215- return std::vector<int64_t>(head_ratio, n_k_heads);
216- }
217- if (std::regex_match(tensor_name, pattern_r_cache)) {
218- return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
219- }
220- if (std::regex_match(tensor_name, pattern_s_cache)) {
221- return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
205+
206+ // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different:
207+ // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern)
208+ // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern)
209+ if (ud->model->arch == LLM_ARCH_QWEN3NEXT) {
210+ if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
211+ GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
212+ return {key_dim, key_dim, value_dim};
213+ }
214+ } else {
215+ const int64_t head_ratio = n_v_heads / n_k_heads;
216+ if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
217+ GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
218+ return std::vector<int64_t>(2 + head_ratio, key_dim);
219+ }
220+ if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
221+ return std::vector<int64_t>(head_ratio, key_dim);
222+ }
223+ if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
224+ std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
225+ return std::vector<int64_t>(head_ratio, n_k_heads);
226+ }
227+ if (std::regex_match(tensor_name, pattern_r_cache)) {
228+ return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
229+ }
230+ if (std::regex_match(tensor_name, pattern_s_cache)) {
231+ return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
232+ }
222233 }
234+
235+ // the FFN is the same for Qwen 3 Next and Qwen 3.5:
223236 if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) {
224237 const int64_t n_ff_exp = hparams.n_ff_exp;
225238 GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp);
@@ -249,13 +262,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
249262 const int64_t head_dim = hparams.ssm_d_state;
250263 const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
251264 if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
252- std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
265+ std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
253266 return std::vector<int64_t>(segments.size(), granularity_qkv);
254267 }
255- if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
256- std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
268+ if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
269+ std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
257270 return std::vector<int64_t>(segments.size(), granularity_qkv / head_dim);
258271 }
272+ if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
273+ return std::vector<int64_t>(segments.size(), 2 * (granularity_qkv / head_dim));
274+ }
259275 if (std::regex_match(tensor_name, pattern_r_cache)) {
260276 return std::vector<int64_t>(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1));
261277 }
@@ -300,7 +316,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
300316
301317 // FFN
302318 if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
303- std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
319+ std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
304320 GGML_ASSERT(segments.size() <= 2);
305321 return std::vector<int64_t>(segments.size(), blck_size);
306322 }
0 commit comments