@@ -1456,6 +1456,8 @@ struct ggml_backend_meta_context {
14561456 int max_nnodes = 0 ;
14571457 size_t max_tmp_size = 0 ;
14581458 size_t max_subgraphs = 0 ;
1459+ size_t n_subgraphs = 0 ;
1460+ uint64_t uid = 0 ;
14591461
14601462 void * comm_ctx = nullptr ;
14611463 ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr ;
@@ -1616,6 +1618,9 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
16161618 const size_t n_backends = ggml_backend_meta_n_backends (backend);
16171619 ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context ;
16181620
1621+ // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
1622+ const bool needs_rebuild = (cgraph->uid == 0 ) || (cgraph->uid != backend_ctx->uid );
1623+
16191624 bool max_nnodes_raised = false ;
16201625 if (cgraph->n_nodes > backend_ctx->max_nnodes ) {
16211626 for (size_t j = 0 ; j < n_backends; j++) {
@@ -1625,173 +1630,181 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
16251630 }
16261631 backend_ctx->max_nnodes = cgraph->n_nodes ;
16271632 max_nnodes_raised = true ;
1633+ assert (needs_rebuild);
16281634 }
1629- for (size_t j = 0 ; j < n_backends; j++) {
1630- auto & bcj = backend_ctx->backend_configs [j];
1631-
1632- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
1633- ggml_tensor * node = cgraph->nodes [i];
1634- if (node->view_src != nullptr && node->view_src ->op == GGML_OP_NONE && ggml_backend_buffer_is_host (node->view_src ->buffer )) {
1635- // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
1636- // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
1637- bcj.nodes [i] = node;
1638- continue ;
1635+
1636+ if (needs_rebuild) {
1637+ size_t n_subgraphs = 0 ;
1638+ size_t max_tmp_size = 0 ;
1639+
1640+ for (size_t j = 0 ; j < n_backends; j++) {
1641+ auto & bcj = backend_ctx->backend_configs [j];
1642+
1643+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
1644+ ggml_tensor * node = cgraph->nodes [i];
1645+ if (node->view_src != nullptr && node->view_src ->op == GGML_OP_NONE && ggml_backend_buffer_is_host (node->view_src ->buffer )) {
1646+ // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
1647+ // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
1648+ bcj.nodes [i] = node;
1649+ continue ;
1650+ }
1651+ bcj.nodes [i] = ggml_backend_meta_buffer_simple_tensor (node, j);
1652+ GGML_ASSERT (bcj.nodes [i]);
16391653 }
1640- bcj.nodes [i] = ggml_backend_meta_buffer_simple_tensor (node, j);
1641- GGML_ASSERT (bcj.nodes [i]);
16421654 }
1643- }
16441655
1645- size_t n_subgraphs = 0 ;
1646- size_t max_tmp_size = 0 ;
1647- {
1648- // For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
1649- auto get_i_delayed = [&](const int i) -> int {
1650- int id = i; // i_delayed
1651- int idr = i; // i_delayed return, last safe return value
1652-
1653- ggml_tensor * node = cgraph->nodes [id];
1654- int32_t n_used = ggml_node_get_use_count (cgraph, id);
1655- if (id + 1 >= cgraph->n_nodes ) {
1656- return idr;
1657- }
1658- {
1659- ggml_tensor * next = cgraph->nodes [id+1 ];
1660- if (next->op == GGML_OP_ADD_ID && next->src [0 ] == node &&
1661- ggml_backend_meta_get_split_state (next->src [1 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
1662- ggml_backend_meta_get_split_state (next->src [2 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED ) {
1663- node = next;
1656+ {
1657+ // For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
1658+ auto get_i_delayed = [&](const int i) -> int {
1659+ int id = i; // i_delayed
1660+ int idr = i; // i_delayed return, last safe return value
1661+
1662+ ggml_tensor * node = cgraph->nodes [id];
1663+ int32_t n_used = ggml_node_get_use_count (cgraph, id);
1664+ if (id + 1 >= cgraph->n_nodes ) {
1665+ return idr;
1666+ }
1667+ {
1668+ ggml_tensor * next = cgraph->nodes [id+1 ];
1669+ if (next->op == GGML_OP_ADD_ID && next->src [0 ] == node &&
1670+ ggml_backend_meta_get_split_state (next->src [1 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
1671+ ggml_backend_meta_get_split_state (next->src [2 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED ) {
1672+ node = next;
1673+ id++;
1674+ idr = id;
1675+ n_used = ggml_node_get_use_count (cgraph, id);
1676+ }
1677+ }
1678+ if (id + 1 >= cgraph->n_nodes ) {
1679+ return idr;
1680+ }
1681+ {
1682+ ggml_tensor * next = cgraph->nodes [id+1 ];
1683+ if (next->op == GGML_OP_MUL && next->src [0 ] == node &&
1684+ ggml_backend_meta_get_split_state (next->src [1 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED ) {
1685+ node = next;
1686+ id++;
1687+ idr = id;
1688+ n_used = ggml_node_get_use_count (cgraph, id);
1689+ }
1690+ }
1691+
1692+ if (n_used != node->ne [1 ] || id + 2 *n_used-1 >= cgraph->n_nodes ) {
1693+ return idr;
1694+ }
1695+ for (int32_t k = 0 ; k < n_used; k++) {
1696+ ggml_tensor * next = cgraph->nodes [id+1 ];
1697+ if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb [1 ] ||
1698+ next->ne [0 ] != node->ne [0 ] || next->ne [1 ] != node->ne [2 ] || next->nb [1 ] != node->nb [2 ] ||
1699+ ggml_node_get_use_count (cgraph, id+1 ) != 1 ) {
1700+ return idr;
1701+ }
16641702 id++;
1665- idr = id;
1666- n_used = ggml_node_get_use_count (cgraph, id);
16671703 }
1668- }
1669- if (id + 1 >= cgraph->n_nodes ) {
1670- return idr;
1671- }
1672- {
1673- ggml_tensor * next = cgraph->nodes [id+1 ];
1674- if (next->op == GGML_OP_MUL && next->src [0 ] == node &&
1675- ggml_backend_meta_get_split_state (next->src [1 ], false ).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED ) {
1676- node = next;
1704+ {
1705+ ggml_tensor * next = cgraph->nodes [id+1 ];
1706+ if (next->op != GGML_OP_ADD || next->src [0 ] != cgraph->nodes [id - (n_used-1 )] ||
1707+ next->src [1 ] != cgraph->nodes [id - (n_used-2 )] || ggml_node_get_use_count (cgraph, id+1 ) != 1 ) {
1708+ return idr;
1709+ }
16771710 id++;
1678- idr = id;
1679- n_used = ggml_node_get_use_count (cgraph, id);
16801711 }
1681- }
1682-
1683- if (n_used != node->ne [1 ] || id + 2 *n_used-1 >= cgraph->n_nodes ) {
1712+ for (int32_t k = 0 ; k < n_used - 2 ; k++) {
1713+ ggml_tensor * next = cgraph->nodes [id+1 ];
1714+ if (next->op != GGML_OP_ADD || next->src [0 ] != cgraph->nodes [id] ||
1715+ next->src [1 ] != cgraph->nodes [id - (n_used-2 )] || ggml_node_get_use_count (cgraph, id+1 ) != 1 ) {
1716+ return idr;
1717+ }
1718+ id++;
1719+ }
1720+ idr = id;
16841721 return idr;
1685- }
1686- for ( int32_t k = 0 ; k < n_used; k++) {
1687- ggml_tensor * next = cgraph-> nodes [id+ 1 ] ;
1688- if (next-> op != GGML_OP_VIEW || next-> view_src != node || next-> view_offs != k*node-> nb [ 1 ] ||
1689- next-> ne [ 0 ] != node-> ne [ 0 ] || next-> ne [ 1 ] != node-> ne [ 2 ] || next-> nb [ 1 ] != node-> nb [ 2 ] ||
1690- ggml_node_get_use_count (cgraph, id+ 1 ) != 1 ) {
1691- return idr ;
1722+ };
1723+
1724+ int i_start = 0 ;
1725+ for ( int i = 0 ; i < cgraph-> n_nodes ; i++) {
1726+ ggml_tensor * node = cgraph-> nodes [i];
1727+ if (node-> view_src != nullptr && node-> view_src -> op == GGML_OP_NONE && ggml_backend_buffer_is_host (node-> view_src -> buffer ) ) {
1728+ continue ;
16921729 }
1693- id++;
1694- }
1695- {
1696- ggml_tensor * next = cgraph->nodes [id+1 ];
1697- if (next->op != GGML_OP_ADD || next->src [0 ] != cgraph->nodes [id - (n_used-1 )] ||
1698- next->src [1 ] != cgraph->nodes [id - (n_used-2 )] || ggml_node_get_use_count (cgraph, id+1 ) != 1 ) {
1699- return idr;
1730+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state (node, /* assume_sync =*/ false );
1731+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL ) {
1732+ max_tmp_size = std::max (max_tmp_size, ggml_nbytes (node));
17001733 }
1701- id++;
1702- }
1703- for (int32_t k = 0 ; k < n_used - 2 ; k++) {
1704- ggml_tensor * next = cgraph->nodes [id+1 ];
1705- if (next->op != GGML_OP_ADD || next->src [0 ] != cgraph->nodes [id] ||
1706- next->src [1 ] != cgraph->nodes [id - (n_used-2 )] || ggml_node_get_use_count (cgraph, id+1 ) != 1 ) {
1707- return idr;
1734+ const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL ;
1735+ if (!new_subgraph) {
1736+ continue ;
17081737 }
1709- id++;
1710- }
1711- idr = id;
1712- return idr;
1713- };
1714-
1715- int i_start = 0 ;
1716- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
1717- ggml_tensor * node = cgraph->nodes [i];
1718- if (node->view_src != nullptr && node->view_src ->op == GGML_OP_NONE && ggml_backend_buffer_is_host (node->view_src ->buffer )) {
1719- continue ;
1720- }
1721- const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state (node, /* assume_sync =*/ false );
1722- if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL ) {
1723- max_tmp_size = std::max (max_tmp_size, ggml_nbytes (node));
1724- }
1725- const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL ;
1726- if (!new_subgraph) {
1727- continue ;
1738+
1739+ i = get_i_delayed (i);
1740+
1741+ for (size_t j = 0 ; j < n_backends; j++) {
1742+ auto & bcj = backend_ctx->backend_configs [j];
1743+ bcj.cgraphs [n_subgraphs].offset = i_start;
1744+ }
1745+ n_subgraphs++;
1746+ i_start = i + 1 ;
17281747 }
1748+ GGML_ASSERT (i_start == cgraph->n_nodes );
1749+ }
17291750
1730- i = get_i_delayed (i);
1751+ backend_ctx->uid = cgraph->uid ;
1752+ backend_ctx->n_subgraphs = n_subgraphs;
17311753
1754+ if (max_tmp_size > backend_ctx->max_tmp_size ) {
17321755 for (size_t j = 0 ; j < n_backends; j++) {
17331756 auto & bcj = backend_ctx->backend_configs [j];
1734- bcj.cgraphs [n_subgraphs].offset = i_start;
1757+ bcj.buf .reset (ggml_backend_alloc_buffer (bcj.backend , max_tmp_size));
1758+ }
1759+ backend_ctx->max_tmp_size = max_tmp_size;
1760+ }
1761+
1762+ if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs ) {
1763+ backend_ctx->max_subgraphs = std::max (backend_ctx->max_subgraphs , n_subgraphs);
1764+ const size_t n_reduce_steps = backend_ctx->n_reduce_steps ();
1765+ const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step
1766+ const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step
1767+ const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs *ggml_graph_overhead_custom (backend_ctx->max_nnodes , cgraph->grads );
1768+ const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs *ggml_graph_overhead_custom (1 , cgraph->grads );
1769+ const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs *ggml_tensor_overhead ();
1770+ ggml_init_params params = {
1771+ /* .mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
1772+ /* .mem_buffer =*/ nullptr ,
1773+ /* .no_alloc =*/ true ,
1774+ };
1775+ backend_ctx->ctx .reset (ggml_init (params));
1776+ for (size_t j = 0 ; j < n_backends; j++) {
1777+ auto & bcj = backend_ctx->backend_configs [j];
1778+ for (size_t i = 0 ; i < n_subgraphs; i++) {
1779+ bcj.cgraphs [i].cgraph_main = ggml_new_graph_custom (backend_ctx->ctx .get (), cgraph->n_nodes , /* grads =*/ false );
1780+ }
1781+ }
1782+ backend_ctx->cgraphs_aux .resize (n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs );
1783+ for (size_t k = 0 ; k < backend_ctx->cgraphs_aux .size (); k++) {
1784+ backend_ctx->cgraphs_aux [k] = ggml_new_graph_custom (backend_ctx->ctx .get (), 1 , cgraph->grads );
1785+ }
1786+ backend_ctx->nodes_aux .resize (n_backends*n_nodes_per_device*backend_ctx->max_subgraphs );
1787+ for (size_t k = 0 ; k < backend_ctx->nodes_aux .size (); k++) {
1788+ backend_ctx->nodes_aux [k] = ggml_new_tensor_1d (backend_ctx->ctx .get (), GGML_TYPE_F32 , 1 );
17351789 }
1736- n_subgraphs++;
1737- i_start = i + 1 ;
17381790 }
1739- GGML_ASSERT (i_start == cgraph->n_nodes );
1740- }
17411791
1742- if (max_tmp_size > backend_ctx->max_tmp_size ) {
17431792 for (size_t j = 0 ; j < n_backends; j++) {
17441793 auto & bcj = backend_ctx->backend_configs [j];
1745- bcj.buf .reset (ggml_backend_alloc_buffer (bcj.backend , max_tmp_size));
1746- }
1747- backend_ctx->max_tmp_size = max_tmp_size;
1748- }
1749-
1750-
1751- if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs ) {
1752- backend_ctx->max_subgraphs = std::max (backend_ctx->max_subgraphs , n_subgraphs);
1753- const size_t n_reduce_steps = backend_ctx->n_reduce_steps ();
1754- const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step
1755- const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step
1756- const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs *ggml_graph_overhead_custom (backend_ctx->max_nnodes , cgraph->grads );
1757- const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs *ggml_graph_overhead_custom (1 , cgraph->grads );
1758- const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs *ggml_tensor_overhead ();
1759- ggml_init_params params = {
1760- /* .mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
1761- /* .mem_buffer =*/ nullptr ,
1762- /* .no_alloc =*/ true ,
1763- };
1764- backend_ctx->ctx .reset (ggml_init (params));
1765- for (size_t j = 0 ; j < n_backends; j++) {
1766- auto & bcj = backend_ctx->backend_configs [j];
1767- for (size_t i = 0 ; i < n_subgraphs; i++) {
1768- bcj.cgraphs [i].cgraph_main = ggml_new_graph_custom (backend_ctx->ctx .get (), cgraph->n_nodes , /* grads =*/ false );
1769- }
1770- }
1771- backend_ctx->cgraphs_aux .resize (n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs );
1772- for (size_t k = 0 ; k < backend_ctx->cgraphs_aux .size (); k++) {
1773- backend_ctx->cgraphs_aux [k] = ggml_new_graph_custom (backend_ctx->ctx .get (), 1 , cgraph->grads );
1774- }
1775- backend_ctx->nodes_aux .resize (n_backends*n_nodes_per_device*backend_ctx->max_subgraphs );
1776- for (size_t k = 0 ; k < backend_ctx->nodes_aux .size (); k++) {
1777- backend_ctx->nodes_aux [k] = ggml_new_tensor_1d (backend_ctx->ctx .get (), GGML_TYPE_F32 , 1 );
1778- }
1779- }
1780-
1781- for (size_t j = 0 ; j < n_backends; j++) {
1782- auto & bcj = backend_ctx->backend_configs [j];
1783- for (size_t i_graph = 0 ; i_graph < n_subgraphs; i_graph++) {
1784- ggml_cgraph * cgraph_ij = bcj.cgraphs [i_graph].cgraph_main ;
1785- const size_t i_node_start = bcj.cgraphs [i_graph].offset ;
1786- const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs [i_graph + 1 ].offset : cgraph->n_nodes ;
1787- cgraph_ij->n_nodes = i_node_stop - i_node_start;
1788- ggml_hash_set_reset (&cgraph_ij->visited_hash_set );
1789- for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
1790- ggml_tensor * node_ij = bcj.nodes [i_node];
1791- cgraph_ij->nodes [i_node - i_node_start] = node_ij;
1792- const size_t hash_pos_orig = ggml_hash_find (&cgraph->visited_hash_set , cgraph->nodes [i_node]);
1793- const size_t hash_pos_ij = ggml_hash_insert (&cgraph_ij->visited_hash_set , node_ij);
1794- cgraph_ij->use_counts [hash_pos_ij] = cgraph->use_counts [hash_pos_orig];
1794+ for (size_t i_graph = 0 ; i_graph < n_subgraphs; i_graph++) {
1795+ ggml_cgraph * cgraph_ij = bcj.cgraphs [i_graph].cgraph_main ;
1796+ const size_t i_node_start = bcj.cgraphs [i_graph].offset ;
1797+ const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs [i_graph + 1 ].offset : cgraph->n_nodes ;
1798+ cgraph_ij->n_nodes = i_node_stop - i_node_start;
1799+ ggml_hash_set_reset (&cgraph_ij->visited_hash_set );
1800+ for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
1801+ ggml_tensor * node_ij = bcj.nodes [i_node];
1802+ cgraph_ij->nodes [i_node - i_node_start] = node_ij;
1803+ const size_t hash_pos_orig = ggml_hash_find (&cgraph->visited_hash_set , cgraph->nodes [i_node]);
1804+ const size_t hash_pos_ij = ggml_hash_insert (&cgraph_ij->visited_hash_set , node_ij);
1805+ cgraph_ij->use_counts [hash_pos_ij] = cgraph->use_counts [hash_pos_orig];
1806+ }
1807+ cgraph_ij->uid = ggml_graph_next_uid ();
17951808 }
17961809 }
17971810 }
@@ -1898,7 +1911,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
18981911 };
18991912
19001913
1901- for (size_t i = 0 ; i < n_subgraphs; i++) {
1914+ for (size_t i = 0 ; i < backend_ctx-> n_subgraphs ; i++) {
19021915 for (size_t j = 0 ; j < n_backends; j++) {
19031916 auto & bcj = backend_ctx->backend_configs [j];
19041917 const ggml_status status = ggml_backend_graph_compute_async (bcj.backend , bcj.cgraphs [i].cgraph_main );
@@ -1907,7 +1920,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
19071920 }
19081921 }
19091922
1910- if (n_backends > 1 && i < n_subgraphs - 1 ) {
1923+ if (n_backends > 1 && i < backend_ctx-> n_subgraphs - 1 ) {
19111924 bool backend_allreduce_success = false ;
19121925 if (backend_ctx->comm_ctx ) {
19131926 std::vector<ggml_tensor *> nodes;
0 commit comments