
Commit 7c07307

Sébastien Loisel authored and committed
Add sum for MatrixMPI with MPI.Allreduce
This was the root cause of the MPI desync: the barrier function was calling sum() on a MatrixMPI, but there was no MPI-aware implementation, so each rank got a different value.
1 parent 131be3e · commit 7c07307

1 file changed

Lines changed: 30 additions & 0 deletions

src/dense.jl

@@ -1360,6 +1360,36 @@ Base.size(A::MatrixMPI, d::Integer) = size(A)[d]
 Base.eltype(::MatrixMPI{T}) where T = T
 Base.eltype(::Type{MatrixMPI{T}}) where T = T
 
+"""
+    Base.sum(A::MatrixMPI{T}; dims=nothing) where T
+
+Compute the sum of all elements in the distributed matrix.
+If dims is specified, sum along that dimension.
+"""
+function Base.sum(A::MatrixMPI{T}; dims=nothing) where T
+    comm = MPI.COMM_WORLD
+
+    if dims === nothing
+        # Sum all elements
+        local_sum = sum(A.A; init=zero(T))
+        return MPI.Allreduce(local_sum, MPI.SUM, comm)
+    elseif dims == 1
+        # Sum along rows (each rank has some rows, result is a row vector)
+        # Each rank computes column sums of its local rows
+        local_colsums = sum(A.A, dims=1)
+        # Reduce across all ranks
+        global_colsums = MPI.Allreduce(local_colsums, MPI.SUM, comm)
+        return global_colsums
+    elseif dims == 2
+        # Sum along columns (result is a distributed column vector)
+        local_rowsums = sum(A.A, dims=2)
+        # Result inherits row partition from A
+        return MatrixMPI{T,typeof(A.A)}(A.structural_hash, A.row_partition, A.col_partition, local_rowsums)
+    else
+        error("dims must be nothing, 1, or 2")
+    end
+end
+
 """
     LinearAlgebra.norm(A::MatrixMPI{T}, p::Real=2) where T

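For context on why Allreduce fixes the desync: each rank only holds its own block of the matrix, so a purely local sum differs from rank to rank, while MPI.Allreduce combines the local partial sums and returns the identical global value on every rank. Below is a minimal standalone sketch of that pattern using plain MPI.jl; it is not part of this commit, and the file name and launch command are only illustrative.

# allreduce_sum.jl; run with, e.g., `mpiexec -n 4 julia allreduce_sum.jl`
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)

# Each rank owns a different block of data, so the local sum is rank-dependent.
# This is the desync described in the commit message.
local_block = fill(Float64(rank + 1), 3)   # rank 0 -> [1,1,1], rank 1 -> [2,2,2], ...
local_sum = sum(local_block)

# Allreduce adds up the local values and hands the same global total back to
# every rank, mirroring the dims === nothing branch above.
global_sum = MPI.Allreduce(local_sum, MPI.SUM, comm)

println("rank $rank: local_sum = $local_sum, global_sum = $global_sum")

MPI.Finalize()

With 4 ranks, local_sum is 3.0, 6.0, 9.0, and 12.0 respectively, while every rank prints global_sum = 30.0. The same allocating Allreduce form also accepts arrays, which is what the dims == 1 branch relies on to reduce the per-rank column sums.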