@@ -1133,7 +1133,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
11331133 if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta (t_ij->view_src ->buffer )) {
11341134 t_ij->view_src = ggml_backend_meta_buffer_simple_tensor (tensor->view_src , j);
11351135 if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1136- GGML_ASSERT (ne[split_dim] != 0 && tensor->ne [split_dim] != 0 );
1136+ GGML_ASSERT (tensor->ne [split_dim] != 0 );
11371137 const int split_dim_view_src = ggml_backend_meta_get_split_state (tensor->view_src , /* assume_sync =*/ true ).axis ;
11381138 GGML_ASSERT (split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);
11391139
@@ -1170,6 +1170,28 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
11701170
11711171 simple_tensors.push_back (t_ij);
11721172 }
1173+
1174+ // If one of the sources has a zero-sized slice, disable the computation:
1175+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
1176+ if (tensor->src [i] == nullptr || !ggml_backend_buffer_is_meta (tensor->src [i]->buffer )) {
1177+ continue ;
1178+ }
1179+
1180+ const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state (tensor->src [i], /* assume_sync =*/ true );
1181+ if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) {
1182+ continue ;
1183+ }
1184+ for (size_t j = 0 ; j < n_simple_bufs; j++) {
1185+ int64_t ne_sum = 0 ;
1186+ for (size_t s = 0 ; s < split_state_src.n_segments ; s++) {
1187+ ne_sum += split_state_src.ne [s*n_simple_bufs + j];
1188+ }
1189+ if (ne_sum == 0 ) {
1190+ simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
1191+ }
1192+ }
1193+ }
1194+
11731195 buf_ctx->simple_tensors [tensor] = simple_tensors;
11741196
11751197 return GGML_STATUS_SUCCESS;
@@ -1442,17 +1464,20 @@ struct ggml_backend_meta_context {
14421464 struct backend_config {
14431465 ggml_backend_t backend;
14441466
1445- std::vector<cgraph_config> cgraphs;
1446- std::vector<ggml_tensor *> nodes;
1447- ggml_backend_buffer_ptr buf ;
1467+ std::vector<cgraph_config> cgraphs;
1468+ std::vector<ggml_tensor *> nodes;
1469+ std::vector< ggml_backend_buffer_ptr> bufs ;
14481470
1449- backend_config (ggml_backend_t backend) : backend(backend) {}
1471+ backend_config (ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) {
1472+ bufs.resize (n_reduce_steps);
1473+ }
14501474 };
14511475 std::string name;
14521476 std::vector<backend_config> backend_configs;
14531477 ggml_context_ptr ctx;
14541478 std::vector<ggml_cgraph *> cgraphs_aux;
14551479 std::vector<ggml_tensor *> nodes_aux;
1480+ size_t n_reduce_steps;
14561481 int max_nnodes = 0 ;
14571482 size_t max_tmp_size = 0 ;
14581483 size_t max_subgraphs = 0 ;
@@ -1464,6 +1489,7 @@ struct ggml_backend_meta_context {
14641489
14651490 ggml_backend_meta_context (ggml_backend_dev_t meta_dev, const char * params) {
14661491 const size_t n_devs = ggml_backend_meta_dev_n_devs (meta_dev);
1492+ n_reduce_steps = std::ceil (std::log2 (n_devs));
14671493 name = " Meta(" ;
14681494 std::vector<ggml_backend_t > simple_backends;
14691495 backend_configs.reserve (n_devs);
@@ -1475,7 +1501,7 @@ struct ggml_backend_meta_context {
14751501 }
14761502 name += ggml_backend_dev_name (simple_dev);
14771503 simple_backends.push_back (ggml_backend_dev_init (simple_dev, params));
1478- backend_configs.emplace_back (simple_backends.back ());
1504+ backend_configs.emplace_back (simple_backends.back (), n_reduce_steps );
14791505 }
14801506 name += " )" ;
14811507
@@ -1505,10 +1531,6 @@ struct ggml_backend_meta_context {
15051531 ggml_backend_free (bc.backend );
15061532 }
15071533 }
1508-
1509- size_t n_reduce_steps () const {
1510- return std::ceil (std::log2 (backend_configs.size ()));
1511- }
15121534};
15131535
15141536static const char * ggml_backend_meta_get_name (ggml_backend_t backend) {
@@ -1754,16 +1776,17 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
17541776 if (max_tmp_size > backend_ctx->max_tmp_size ) {
17551777 for (size_t j = 0 ; j < n_backends; j++) {
17561778 auto & bcj = backend_ctx->backend_configs [j];
1757- bcj.buf .reset (ggml_backend_alloc_buffer (bcj.backend , max_tmp_size));
1779+ for (size_t i = 0 ; i < backend_ctx->n_reduce_steps ; i++) {
1780+ bcj.bufs [i].reset (ggml_backend_alloc_buffer (bcj.backend , max_tmp_size));
1781+ }
17581782 }
17591783 backend_ctx->max_tmp_size = max_tmp_size;
17601784 }
17611785
17621786 if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs ) {
17631787 backend_ctx->max_subgraphs = std::max (backend_ctx->max_subgraphs , n_subgraphs);
1764- const size_t n_reduce_steps = backend_ctx->n_reduce_steps ();
1765- const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step
1766- const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step
1788+ const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps ; // tmp + ADD (+zeroing) graph per step and device
1789+ const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps ; // ADD ( + zeroing) graph per step and device
17671790 const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs *ggml_graph_overhead_custom (backend_ctx->max_nnodes , cgraph->grads );
17681791 const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs *ggml_graph_overhead_custom (1 , cgraph->grads );
17691792 const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs *ggml_tensor_overhead ();
@@ -1812,11 +1835,6 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
18121835 size_t iga = 0 ; // i graph aux
18131836 size_t ina = 0 ; // i node aux
18141837
1815- // FIXME usage_counts
1816- auto get_cgraph_aux = [&]() -> ggml_cgraph * {
1817- ggml_cgraph * ret = backend_ctx->cgraphs_aux [iga++];
1818- return ret;
1819- };
18201838 auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
18211839 ggml_tensor * ret = backend_ctx->nodes_aux [ina++];
18221840 memset (ret, 0 , sizeof (ggml_tensor));
@@ -1828,75 +1846,110 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
18281846 }
18291847 return ret;
18301848 };
1849+ auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) {
1850+ auto & bcj = backend_ctx->backend_configs [j];
1851+ ggml_backend_buffer_ptr & buf_ptr = bcj.bufs [i_buf];
1852+ if (!buf_ptr || ggml_backend_buffer_get_size (buf_ptr.get ()) < backend_ctx->max_tmp_size ) {
1853+ buf_ptr.reset (ggml_backend_alloc_buffer (bcj.backend , backend_ctx->max_tmp_size ));
1854+ }
1855+ tensor->buffer = buf_ptr.get ();
1856+ tensor->data = ggml_backend_buffer_get_base (buf_ptr.get ());
1857+ };
1858+ // FIXME usage_counts
1859+ auto get_cgraph_aux = [&]() -> ggml_cgraph * {
1860+ ggml_cgraph * ret = backend_ctx->cgraphs_aux [iga++];
1861+ return ret;
1862+ };
18311863
18321864 // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
18331865 auto allreduce_fallback = [&](size_t i) -> ggml_status {
18341866 std::vector<ggml_cgraph *> step_cgraphs (n_backends, nullptr );
18351867
1836- for (size_t offset_j = 1 ; offset_j < n_backends; offset_j *= 2 ) {
1868+ // Zero out nodes that were disabled due to having a zero-sized slice:
1869+ for (size_t j = 0 ; j < n_backends; j++) {
1870+ auto & bcj = backend_ctx->backend_configs [j];
1871+ ggml_tensor * node = bcj.cgraphs [i].cgraph_main ->nodes [bcj.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1872+ if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
1873+ continue ;
1874+ }
1875+ ggml_tensor * node_zero = get_node_aux (node);
1876+ node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN
1877+ node_zero->src [0 ] = node;
1878+ ggml_set_op_params_f32 (node_zero, 0 , 0 .0f );
1879+ node_zero->data = node->data ;
1880+ node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE;
1881+
1882+ step_cgraphs[j] = get_cgraph_aux ();
1883+ step_cgraphs[j]->nodes [0 ] = node_zero;
1884+ step_cgraphs[j]->n_nodes = 1 ;
1885+ const ggml_status status = ggml_backend_graph_compute_async (bcj.backend , step_cgraphs[j]);
1886+ if (status != GGML_STATUS_SUCCESS) {
1887+ return status;
1888+ }
1889+ }
1890+ std::fill (step_cgraphs.begin (), step_cgraphs.end (), nullptr );
1891+
1892+ auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) {
1893+ assert (step_cgraphs[j_dst] == nullptr );
1894+ auto & bcj_src = backend_ctx->backend_configs [j_src];
1895+ auto & bcj_dst = backend_ctx->backend_configs [j_dst];
1896+
1897+ ggml_tensor * node_src = bcj_src.cgraphs [i].cgraph_main ->nodes [bcj_src.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1898+ ggml_tensor * node_dst = bcj_dst.cgraphs [i].cgraph_main ->nodes [bcj_dst.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1899+ GGML_ASSERT (ggml_is_contiguous (node_src));
1900+ GGML_ASSERT (ggml_is_contiguous (node_dst));
1901+
1902+ ggml_tensor * node_tmp = get_node_aux (node_dst);
1903+ set_tmp_data (node_tmp, j_dst, i_buf);
1904+
1905+ ggml_backend_tensor_copy_async (bcj_src.backend , bcj_dst.backend , node_src, node_tmp);
1906+
1907+ ggml_tensor * node_red = get_node_aux (node_dst);
1908+ node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src ;
1909+ node_red->view_offs = node_dst->view_offs ;
1910+ node_red->op = GGML_OP_ADD;
1911+ node_red->src [0 ] = node_dst;
1912+ node_red->src [1 ] = node_tmp;
1913+ node_red->flags |= GGML_TENSOR_FLAG_COMPUTE;
1914+ ggml_backend_view_init (node_red);
1915+
1916+ ggml_cgraph * cgraph_aux = get_cgraph_aux ();
1917+ cgraph_aux->nodes [0 ] = node_red;
1918+ cgraph_aux->n_nodes = 1 ;
1919+ step_cgraphs[j_dst] = cgraph_aux;
1920+ };
1921+
1922+ size_t offset_j = n_backends/2 ;
1923+ while ((offset_j & (offset_j - 1 )) != 0 ) {
1924+ offset_j--;
1925+ }
1926+ const size_t offset_j_max = offset_j;
1927+ size_t i_buf = 0 ;
1928+
1929+ // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction:
1930+ for (size_t j_src = 2 *offset_j_max; j_src < n_backends; j_src++) {
1931+ const size_t j_dst = j_src - 2 *offset_j_max;
1932+ push_data (j_src, j_dst, i_buf);
1933+ const ggml_status status = ggml_backend_graph_compute_async (backend_ctx->backend_configs [j_dst].backend , step_cgraphs[j_dst]);
1934+ if (status != GGML_STATUS_SUCCESS) {
1935+ return status;
1936+ }
1937+ i_buf = 1 ;
1938+ }
1939+
1940+ // Butterfly reduction:
1941+ for (; offset_j >= 1 ; offset_j /= 2 ) {
18371942 std::fill (step_cgraphs.begin (), step_cgraphs.end (), nullptr );
18381943
1839- for (size_t j = 0 ; j < n_backends ; j++) {
1944+ for (size_t j = 0 ; j < 2 *offset_j_max ; j++) {
18401945 const size_t j_other = j ^ offset_j;
1841- if (j_other > j ) {
1946+ if (j_other >= n_backends ) {
18421947 continue ;
18431948 }
1844-
1845- auto & bcj1 = backend_ctx->backend_configs [j];
1846- auto & bcj2 = backend_ctx->backend_configs [j_other];
1847-
1848- ggml_tensor * node1 = bcj1.cgraphs [i].cgraph_main ->nodes [bcj1.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1849- ggml_tensor * node2 = bcj2.cgraphs [i].cgraph_main ->nodes [bcj2.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1850- GGML_ASSERT (ggml_is_contiguous (node1));
1851- GGML_ASSERT (ggml_is_contiguous (node2));
1852-
1853- // Tmp tensors to receive P2P copies
1854- ggml_tensor * node_tmp_1 = get_node_aux (node1);
1855- node_tmp_1->buffer = bcj1.buf .get ();
1856- node_tmp_1->data = ggml_backend_buffer_get_base (bcj1.buf .get ());
1857-
1858- ggml_tensor * node_tmp_2 = get_node_aux (node2);
1859- node_tmp_2->buffer = bcj2.buf .get ();
1860- node_tmp_2->data = ggml_backend_buffer_get_base (bcj2.buf .get ());
1861-
1862- // 2 P2P copies: exchange full buffers
1863- ggml_backend_tensor_copy_async (bcj1.backend , bcj2.backend , node1, node_tmp_2);
1864- ggml_backend_tensor_copy_async (bcj2.backend , bcj1.backend , node2, node_tmp_1);
1865-
1866- // Local ADD: node1 += tmp1 (in-place via view)
1867- ggml_tensor * node_red_1 = get_node_aux (node1);
1868- node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src ;
1869- node_red_1->view_offs = node1->view_offs ;
1870- node_red_1->op = GGML_OP_ADD;
1871- node_red_1->src [0 ] = node1;
1872- node_red_1->src [1 ] = node_tmp_1;
1873- node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE;
1874- ggml_backend_view_init (node_red_1);
1875-
1876- // Local ADD: node2 += tmp2 (in-place via view)
1877- ggml_tensor * node_red_2 = get_node_aux (node2);
1878- node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src ;
1879- node_red_2->view_offs = node2->view_offs ;
1880- node_red_2->op = GGML_OP_ADD;
1881- node_red_2->src [0 ] = node2;
1882- node_red_2->src [1 ] = node_tmp_2;
1883- node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE;
1884- ggml_backend_view_init (node_red_2);
1885-
1886- // Build 1-node cgraphs for the ADD ops
1887- ggml_cgraph * cgraph_aux_1 = get_cgraph_aux ();
1888- cgraph_aux_1->nodes [0 ] = node_red_1;
1889- cgraph_aux_1->n_nodes = 1 ;
1890- step_cgraphs[j] = cgraph_aux_1;
1891-
1892- ggml_cgraph * cgraph_aux_2 = get_cgraph_aux ();
1893- cgraph_aux_2->nodes [0 ] = node_red_2;
1894- cgraph_aux_2->n_nodes = 1 ;
1895- step_cgraphs[j_other] = cgraph_aux_2;
1949+ push_data (j, j_other, i_buf);
18961950 }
18971951
1898- // Execute local ADDs for this step
1899- for (size_t j = 0 ; j < n_backends; j++) {
1952+ for (size_t j = 0 ; j < 2 *offset_j_max; j++) {
19001953 if (step_cgraphs[j] == nullptr ) {
19011954 continue ;
19021955 }
@@ -1906,7 +1959,20 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
19061959 return status;
19071960 }
19081961 }
1962+ i_buf++;
19091963 }
1964+ assert (i_buf == backend_ctx->n_reduce_steps );
1965+
1966+ // If n_backends is not a power of 2, copy back the reduced tensors to the excess:
1967+ for (size_t j = 2 *offset_j_max; j < n_backends; j++) {
1968+ auto & bcj_src = backend_ctx->backend_configs [j - 2 *offset_j_max];
1969+ auto & bcj_dst = backend_ctx->backend_configs [j];
1970+
1971+ ggml_tensor * node_src = bcj_src.cgraphs [i].cgraph_main ->nodes [bcj_src.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1972+ ggml_tensor * node_dst = bcj_dst.cgraphs [i].cgraph_main ->nodes [bcj_dst.cgraphs [i].cgraph_main ->n_nodes - 1 ];
1973+ ggml_backend_tensor_copy_async (bcj_src.backend , bcj_dst.backend , node_src, node_dst);
1974+ }
1975+
19101976 return GGML_STATUS_SUCCESS;
19111977 };
19121978
0 commit comments