Skip to content

Commit 6effcec

Browse files
TP: round up granularity to 128 (ggml-org#24180)
* TP: round up granularity to 128 * remove assert
1 parent 86591c7 commit 6effcec

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

src/llama-model.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -553,10 +553,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
553553
};
554554

555555
auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> {
556+
// for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used
556557
if (hparams.is_recr(il)) {
557558
// linear attention
558-
const int64_t head_dim = hparams.ssm_d_state;
559-
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
559+
const int64_t head_dim = hparams.ssm_d_state;
560+
const int64_t blck_size_perf = std::lcm(blck_size, 128);
561+
const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim);
560562
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
561563
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
562564
return std::vector<int64_t>(segments.size(), granularity_qkv);
@@ -578,17 +580,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
578580
// regular attention
579581
const uint32_t n_gqa = hparams.n_gqa(il);
580582
const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il);
583+
584+
// to handle head sizes like 80, only increase granularity while it doesn't cause underutilization
585+
int64_t blck_size_perf = blck_size;
586+
while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) {
587+
blck_size_perf *= 2;
588+
}
589+
581590
if (std::regex_match(tensor_name, pattern_attn_sinks)) {
582591
GGML_ASSERT(segments.size() == 1);
583-
return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa};
592+
return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa};
584593
}
585594

586-
const int64_t granularity_q = std::lcm(n_embd_q, blck_size);
595+
const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf);
587596
if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) {
588597
GGML_ASSERT(segments.size() == 1);
589598
// some models have Q gate tensors, for those cases the granularity needs to be doubled:
590599
if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) {
591-
return {std::lcm(2*n_embd_q, blck_size)};
600+
return {std::lcm(2*n_embd_q, blck_size_perf)};
592601
}
593602
return {granularity_q};
594603
}
@@ -613,8 +622,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
613622
// FFN
614623
if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
615624
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
625+
const int64_t blck_size_perf = std::lcm(blck_size, 128);
616626
GGML_ASSERT(segments.size() == 1);
617-
return {blck_size};
627+
return {blck_size_perf};
618628
}
619629

620630
// everything else
@@ -627,7 +637,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
627637
tensor_config tc = get_tensor_config();
628638
split_state.axis = tc.axis;
629639
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
630-
const int64_t ne_full = tensor->ne[split_state.axis];
631640
const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type);
632641
const float * tensor_split = ud->model->tensor_split();
633642
std::vector<float> tensor_split_scan;
@@ -644,7 +653,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
644653
const int64_t ne_s = segments[is].first;
645654
const uint32_t nr_s = segments[is].second;
646655
const int64_t g_s = granularity[is];
647-
GGML_ASSERT(ne_full % g_s == 0);
648656
int64_t low = 0;
649657
size_t j = 0;
650658
for (; j < ud->n_devices - 1; j++) {

0 commit comments

Comments
 (0)