Skip to content

Commit 739393b

Browse files
TP: fix delayed AllReduce + zero-sized slices (#22489)
1 parent fc2b005 commit 739393b

1 file changed

Lines changed: 18 additions & 1 deletion

File tree

ggml/src/ggml-backend-meta.cpp

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
1826 1826
continue;
1827 1827
}
1828 1828

1829-
i = get_i_delayed(i);
1829+
const int i_delayed = get_i_delayed(i);
1830+
1831+
// If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
1832+
// A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
1833+
// its compute flag disabled and thus gets its data zeroed out.
1834+
// If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
1835+
if (i_delayed > i) {
1836+
for (size_t j = 0; j < n_backends; j++) {
1837+
auto & bcj = backend_ctx->backend_configs[j];
1838+
if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
1839+
for (int ii = i + 1; ii <= i_delayed; ii++) {
1840+
bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
1841+
}
1842+
}
1843+
}
1844+
}
1845+
1846+
i = i_delayed;
1830 1847

1831 1848
for (size_t j = 0; j < n_backends; j++) {
1832 1849
auto & bcj = backend_ctx->backend_configs[j];

0 commit comments

Comments
 (0)