Skip to content

Commit 865ff06

Browse files
TP: fix Qwen 3 Next data split (ggml-org#21732)
1 parent 2b2cd57 commit 865ff06

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

src/llama-model.cpp

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -202,24 +202,37 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
202202
const int64_t n_v_heads = hparams.ssm_dt_rank;
203203
const int64_t key_dim = head_k_dim * n_k_heads;
204204
const int64_t value_dim = head_v_dim * n_v_heads;
205-
const int64_t head_ratio = n_v_heads / n_k_heads;
206-
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
207-
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
208-
return std::vector<int64_t>(2 + head_ratio, key_dim);
209-
}
210-
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
211-
return std::vector<int64_t>(head_ratio, key_dim);
212-
}
213-
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
214-
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
215-
return std::vector<int64_t>(head_ratio, n_k_heads);
216-
}
217-
if (std::regex_match(tensor_name, pattern_r_cache)) {
218-
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
219-
}
220-
if (std::regex_match(tensor_name, pattern_s_cache)) {
221-
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
205+
206+
// both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different:
207+
// - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern)
208+
// - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern)
209+
if (ud->model->arch == LLM_ARCH_QWEN3NEXT) {
210+
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
211+
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
212+
return {key_dim, key_dim, value_dim};
213+
}
214+
} else {
215+
const int64_t head_ratio = n_v_heads / n_k_heads;
216+
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
217+
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
218+
return std::vector<int64_t>(2 + head_ratio, key_dim);
219+
}
220+
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
221+
return std::vector<int64_t>(head_ratio, key_dim);
222+
}
223+
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
224+
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
225+
return std::vector<int64_t>(head_ratio, n_k_heads);
226+
}
227+
if (std::regex_match(tensor_name, pattern_r_cache)) {
228+
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
229+
}
230+
if (std::regex_match(tensor_name, pattern_s_cache)) {
231+
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
232+
}
222233
}
234+
235+
// the FFN is the same for Qwen 3 Next and Qwen 3.5:
223236
if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) {
224237
const int64_t n_ff_exp = hparams.n_ff_exp;
225238
GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp);
@@ -249,13 +262,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
249262
const int64_t head_dim = hparams.ssm_d_state;
250263
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
251264
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
252-
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
265+
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
253266
return std::vector<int64_t>(segments.size(), granularity_qkv);
254267
}
255-
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
256-
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
268+
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
269+
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
257270
return std::vector<int64_t>(segments.size(), granularity_qkv / head_dim);
258271
}
272+
if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
273+
return std::vector<int64_t>(segments.size(), 2 * (granularity_qkv / head_dim));
274+
}
259275
if (std::regex_match(tensor_name, pattern_r_cache)) {
260276
return std::vector<int64_t>(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1));
261277
}
@@ -300,7 +316,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
300316

301317
// FFN
302318
if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
303-
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
319+
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
304320
GGML_ASSERT(segments.size() <= 2);
305321
return std::vector<int64_t>(segments.size(), blck_size);
306322
}

src/models/qwen3next.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
354354
cb(last_conv_states, "last_conv_states", il);
355355

356356
ggml_tensor * state_update_target =
357-
ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
357+
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
358358
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
359359
cb(state_update_target, "state_update_target", il);
360360

@@ -445,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
445445
// Update the recurrent states
446446
ggml_build_forward_expand(gf,
447447
ggml_cpy(ctx0, new_state,
448-
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
448+
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
449449
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
450450

451451
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]

0 commit comments

Comments (0)