@@ -553,10 +553,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
553553 };
554554
555555 auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t , uint32_t >> & segments) -> std::vector<int64_t > {
556+ // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used
556557 if (hparams.is_recr (il)) {
557558 // linear attention
558- const int64_t head_dim = hparams.ssm_d_state ;
559- const int64_t granularity_qkv = std::lcm (blck_size, head_dim);
559+ const int64_t head_dim = hparams.ssm_d_state ;
560+ const int64_t blck_size_perf = std::lcm (blck_size, 128 );
561+ const int64_t granularity_qkv = std::lcm (blck_size_perf, head_dim);
560562 if (std::regex_match (tensor_name, pattern_qkv_weight) || std::regex_match (tensor_name, pattern_attn_gate_weight) ||
561563 std::regex_match (tensor_name, pattern_ssm_conv1d) || std::regex_match (tensor_name, pattern_ssm_out_weight)) {
562564 return std::vector<int64_t >(segments.size (), granularity_qkv);
@@ -578,17 +580,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
578580 // regular attention
579581 const uint32_t n_gqa = hparams.n_gqa (il);
580582 const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k (il);
583+
584+ // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization
585+ int64_t blck_size_perf = blck_size;
586+ while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) {
587+ blck_size_perf *= 2 ;
588+ }
589+
581590 if (std::regex_match (tensor_name, pattern_attn_sinks)) {
582591 GGML_ASSERT (segments.size () == 1 );
583- return {std::lcm (n_embd_q, blck_size )/n_embd_q * n_gqa};
592+ return {std::lcm (n_embd_q, blck_size_perf )/n_embd_q * n_gqa};
584593 }
585594
586- const int64_t granularity_q = std::lcm (n_embd_q, blck_size );
595+ const int64_t granularity_q = std::lcm (n_embd_q, blck_size_perf );
587596 if (std::regex_match (tensor_name, pattern_q_weight) || std::regex_match (tensor_name, pattern_q_bias)) {
588597 GGML_ASSERT (segments.size () == 1 );
589598 // some models have Q gate tensors, for those cases the granularity needs to be doubled:
590599 if (ud->model ->arch == LLM_ARCH_QWEN3NEXT || ud->model ->arch == LLM_ARCH_QWEN35 || ud->model ->arch == LLM_ARCH_QWEN35MOE) {
591- return {std::lcm (2 *n_embd_q, blck_size )};
600+ return {std::lcm (2 *n_embd_q, blck_size_perf )};
592601 }
593602 return {granularity_q};
594603 }
@@ -613,8 +622,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
613622 // FFN
614623 if (std::regex_match (tensor_name, pattern_ffn_up_gate_weight) || std::regex_match (tensor_name, pattern_ffn_up_gate_bias) ||
615624 std::regex_match (tensor_name, pattern_ffn_gate_up_weight) || std::regex_match (tensor_name, pattern_ffn_down_weight)) {
625+ const int64_t blck_size_perf = std::lcm (blck_size, 128 );
616626 GGML_ASSERT (segments.size () == 1 );
617- return {blck_size };
627+ return {blck_size_perf };
618628 }
619629
620630 // everything else
@@ -627,7 +637,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
627637 tensor_config tc = get_tensor_config ();
628638 split_state.axis = tc.axis ;
629639 if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
630- const int64_t ne_full = tensor->ne [split_state.axis ];
631640 const int64_t blck_size = ggml_blck_size (tc.tensor_axis_0 ->type );
632641 const float * tensor_split = ud->model ->tensor_split ();
633642 std::vector<float > tensor_split_scan;
@@ -644,7 +653,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
644653 const int64_t ne_s = segments[is].first ;
645654 const uint32_t nr_s = segments[is].second ;
646655 const int64_t g_s = granularity[is];
647- GGML_ASSERT (ne_full % g_s == 0 );
648656 int64_t low = 0 ;
649657 size_t j = 0 ;
650658 for (; j < ud->n_devices - 1 ; j++) {
0 commit comments