
Commit ad6e8a1

Sébastien Loisel authored and committed
Add repartition() and simplify alignment operations
Implement repartition(x, p) for VectorMPI, MatrixMPI, and SparseMatrixMPI to redistribute data to a new row partition. Plans are cached by (source_structural_hash, target_partition_hash, element_type), with eager hash computation during plan creation.

Refactor existing operations to use repartition:

- VectorMPI +/- now uses repartition instead of get_vector_align_plan
- dot(x, y) uses repartition for partition alignment
- Broadcast uses repartition in _prepare_broadcast_arg
- SparseMatrixMPI +/- uses repartition instead of get_addition_plan

Remove deprecated code:

- _vector_align_plan_cache, get_vector_align_plan, _align_vector
- _addition_plan_cache, get_addition_plan
- Addition-related fields from the MatrixPlan struct

This reduces code duplication and provides a unified API for data redistribution across all distributed types.
1 parent 63908f9 commit ad6e8a1
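
For orientation, here is a minimal usage sketch of the new API, following the `repartition` docstring added in src/dense.jl below. It is a sketch rather than code from this commit, and it assumes a 4-rank MPI launch (e.g. `mpiexec -n 4 julia demo.jl`) plus the usual `MPI.Init()` boilerplate.

```julia
using MPI, LinearAlgebraMPI
MPI.Init()

A = MatrixMPI(randn(6, 4))   # 6x4 distributed matrix, uniform row partition
p = [1, 2, 4, 5, 7]          # target row partition: ranks 0..3 own 1, 2, 1, 2 rows
B = repartition(A, p)        # same data, rows redistributed so that B.row_partition == p
C = repartition(B, p)        # fast path: the partition already matches, so B is returned unchanged
```

The same `repartition(x, p)` call is also available for VectorMPI and SparseMatrixMPI, and the aligned `+`, `-`, `dot`, and broadcast paths listed above now go through it internally.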

6 files changed: 1,275 additions & 245 deletions


src/LinearAlgebraMPI.jl

Lines changed: 5 additions & 8 deletions
@@ -8,7 +8,7 @@ import SparseArrays: nnz, issparse, dropzeros, spdiagm, blockdiag
 import LinearAlgebra
 import LinearAlgebra: tr, diag, triu, tril, Transpose, Adjoint, norm, opnorm, mul!, ldlt, BLAS, issymmetric, UniformScaling, dot
 
-export SparseMatrixMPI, MatrixMPI, VectorMPI, clear_plan_cache!, uniform_partition
+export SparseMatrixMPI, MatrixMPI, VectorMPI, clear_plan_cache!, uniform_partition, repartition
 export SparseMatrixCSR # Type alias for Transpose{SparseMatrixCSC} (CSR storage format)
 export # Multithreaded sparse matrix multiplication
 export VectorMPI_local, MatrixMPI_local, SparseMatrixMPI_local # Local constructors
@@ -111,15 +111,15 @@ const _plan_cache = Dict{Tuple{Blake3Hash,Blake3Hash,DataType},Any}()
 # Cache for memoized VectorPlans (for A * x)
 const _vector_plan_cache = Dict{Tuple{Blake3Hash,Blake3Hash,DataType},Any}()
 
-# Cache for memoized Vector Alignment Plans (for u +/- v with different partitions)
-const _vector_align_plan_cache = Dict{Tuple{Blake3Hash,Blake3Hash,DataType},Any}()
-
 # Cache for memoized DenseMatrixVectorPlans (for MatrixMPI * VectorMPI)
 const _dense_vector_plan_cache = Dict{Tuple{Blake3Hash,Blake3Hash,DataType},Any}()
 
 # Cache for memoized DenseTransposePlans (for transpose(MatrixMPI))
 const _dense_transpose_plan_cache = Dict{Tuple{Blake3Hash,DataType},Any}()
 
+# Cache for memoized RepartitionPlans (for repartition)
+const _repartition_plan_cache = Dict{Tuple{Blake3Hash,Blake3Hash,DataType},Any}()
+
 """
     clear_plan_cache!()
 
@@ -128,15 +128,12 @@ Clear all memoized plan caches.
 function clear_plan_cache!()
     empty!(_plan_cache)
     empty!(_vector_plan_cache)
-    empty!(_vector_align_plan_cache)
     empty!(_dense_vector_plan_cache)
     empty!(_dense_transpose_plan_cache)
+    empty!(_repartition_plan_cache)
     if isdefined(@__MODULE__, :_dense_transpose_vector_plan_cache)
         empty!(_dense_transpose_vector_plan_cache)
     end
-    if isdefined(@__MODULE__, :_addition_plan_cache)
-        empty!(_addition_plan_cache)
-    end
 end
 
 """

src/dense.jl

Lines changed: 254 additions & 0 deletions
@@ -1372,3 +1372,257 @@ function Base.mapslices(f, A::MatrixMPI{T}; dims) where T
        error("dims must be 1 or 2")
    end
end

# ============================================================================
# DenseRepartitionPlan: Repartition a MatrixMPI to a new row partition
# ============================================================================

"""
    DenseRepartitionPlan{T}

Communication plan for repartitioning a MatrixMPI to a new row partition.

# Fields
- `send_rank_ids::Vector{Int}`: Ranks we send rows to (0-indexed)
- `send_row_ranges::Vector{UnitRange{Int}}`: For each rank, range of local rows to send
- `send_bufs::Vector{Matrix{T}}`: Pre-allocated send buffers
- `send_reqs::Vector{MPI.Request}`: Pre-allocated send request handles
- `recv_rank_ids::Vector{Int}`: Ranks we receive rows from (0-indexed)
- `recv_row_counts::Vector{Int}`: Number of rows to receive from each rank
- `recv_bufs::Vector{Matrix{T}}`: Pre-allocated receive buffers
- `recv_reqs::Vector{MPI.Request}`: Pre-allocated receive request handles
- `recv_dst_ranges::Vector{UnitRange{Int}}`: Destination row ranges in result for each recv
- `local_src_range::UnitRange{Int}`: Source row range for local copy
- `local_dst_range::UnitRange{Int}`: Destination row range for local copy
- `result_row_partition::Vector{Int}`: Target row partition (copy of p)
- `result_col_partition::Vector{Int}`: Column partition (unchanged from source)
- `result_structural_hash::Blake3Hash`: Hash of result matrix
- `result_local_nrows::Int`: Number of rows this rank owns after repartition
- `ncols::Int`: Number of columns in the matrix
"""
mutable struct DenseRepartitionPlan{T}
    send_rank_ids::Vector{Int}
    send_row_ranges::Vector{UnitRange{Int}}
    send_bufs::Vector{Matrix{T}}
    send_reqs::Vector{MPI.Request}
    recv_rank_ids::Vector{Int}
    recv_row_counts::Vector{Int}
    recv_bufs::Vector{Matrix{T}}
    recv_reqs::Vector{MPI.Request}
    recv_dst_ranges::Vector{UnitRange{Int}}
    local_src_range::UnitRange{Int}
    local_dst_range::UnitRange{Int}
    result_row_partition::Vector{Int}
    result_col_partition::Vector{Int}
    result_structural_hash::Blake3Hash
    result_local_nrows::Int
    ncols::Int
end

"""
    DenseRepartitionPlan(A::MatrixMPI{T}, p::Vector{Int}) where T

Create a communication plan to repartition `A` to have row partition `p`.
The col_partition remains unchanged.

Building the plan:
1. Determines which rows to send to each rank based on partition overlap
2. Determines which rows to receive from each rank
3. Pre-allocates all buffers for allocation-free execution
4. Computes the result structural hash eagerly
"""
function DenseRepartitionPlan(A::MatrixMPI{T}, p::Vector{Int}) where T
    comm = MPI.COMM_WORLD
    rank = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)

    # Source partition info
    src_start = A.row_partition[rank+1]
    src_end = A.row_partition[rank+2] - 1
    local_nrows = max(0, src_end - src_start + 1)

    # Target partition info
    dst_start = p[rank+1]
    dst_end = p[rank+2] - 1
    result_local_nrows = max(0, dst_end - dst_start + 1)

    ncols = A.col_partition[end] - 1

    # Step 1: Determine which rows we send to each rank
    send_row_ranges_map = Dict{Int, UnitRange{Int}}()
    for r in 0:(nranks-1)
        r_start = p[r+1]
        r_end = p[r+2] - 1
        if r_end < r_start
            continue # rank r has no rows in target partition
        end
        # Intersection of our rows with rank r's target
        overlap_start = max(src_start, r_start)
        overlap_end = min(src_end, r_end)
        if overlap_start <= overlap_end
            # Convert to local row indices in A.A
            local_start = overlap_start - src_start + 1
            local_end = overlap_end - src_start + 1
            send_row_ranges_map[r] = local_start:local_end
        end
    end

    # Step 2: Exchange counts via Alltoall
    send_counts = Int32[haskey(send_row_ranges_map, r) ? length(send_row_ranges_map[r]) : 0 for r in 0:(nranks-1)]
    recv_counts_raw = MPI.Alltoall(MPI.UBuffer(send_counts, 1), comm)

    # Step 3: Build send/recv structures
    send_rank_ids = Int[]
    send_row_ranges = UnitRange{Int}[]
    recv_rank_ids = Int[]
    recv_row_counts = Int[]
    recv_dst_ranges = UnitRange{Int}[]

    local_src_range = 1:0 # empty range
    local_dst_range = 1:0 # empty range

    # Handle local copy separately
    if haskey(send_row_ranges_map, rank)
        local_src_range = send_row_ranges_map[rank]
        # Compute destination range: where do these rows go in the result?
        global_start = src_start + local_src_range.start - 1
        local_dst_start = global_start - dst_start + 1
        local_dst_end = local_dst_start + length(local_src_range) - 1
        local_dst_range = local_dst_start:local_dst_end
    end

    # Build send arrays (excluding local)
    for r in 0:(nranks-1)
        if haskey(send_row_ranges_map, r) && r != rank
            push!(send_rank_ids, r)
            push!(send_row_ranges, send_row_ranges_map[r])
        end
    end

    # Build recv arrays (excluding local)
    for r in 0:(nranks-1)
        if recv_counts_raw[r+1] > 0 && r != rank
            push!(recv_rank_ids, r)
            push!(recv_row_counts, recv_counts_raw[r+1])

            # Rows from rank r: their source range is [A.row_partition[r+1], A.row_partition[r+2]-1]
            # intersected with our target range [dst_start, dst_end]
            r_src_start = A.row_partition[r+1]
            r_src_end = A.row_partition[r+2] - 1
            overlap_start = max(r_src_start, dst_start)
            overlap_end = min(r_src_end, dst_end)
            # Destination range in our result
            dst_range_start = overlap_start - dst_start + 1
            dst_range_end = overlap_end - dst_start + 1
            push!(recv_dst_ranges, dst_range_start:dst_range_end)
        end
    end

    # Pre-allocate buffers
    send_bufs = [Matrix{T}(undef, length(r), ncols) for r in send_row_ranges]
    recv_bufs = [Matrix{T}(undef, c, ncols) for c in recv_row_counts]
    send_reqs = Vector{MPI.Request}(undef, length(send_rank_ids))
    recv_reqs = Vector{MPI.Request}(undef, length(recv_rank_ids))

    # Compute result structural hash eagerly
    result_local_size = (result_local_nrows, ncols)
    result_structural_hash = compute_dense_structural_hash(p, A.col_partition, result_local_size, comm)

    return DenseRepartitionPlan{T}(
        send_rank_ids, send_row_ranges, send_bufs, send_reqs,
        recv_rank_ids, recv_row_counts, recv_bufs, recv_reqs, recv_dst_ranges,
        local_src_range, local_dst_range,
        copy(p), copy(A.col_partition), result_structural_hash,
        result_local_nrows, ncols
    )
end

"""
    execute_plan!(plan::DenseRepartitionPlan{T}, A::MatrixMPI{T}) where T

Execute a dense repartition plan to redistribute rows from `A` to a new partition.
Returns a new MatrixMPI with the target row partition.
"""
function execute_plan!(plan::DenseRepartitionPlan{T}, A::MatrixMPI{T}) where T
    comm = MPI.COMM_WORLD

    # Allocate result
    result_A = Matrix{T}(undef, plan.result_local_nrows, plan.ncols)

    # Step 1: Local copy
    if !isempty(plan.local_src_range) && !isempty(plan.local_dst_range)
        result_A[plan.local_dst_range, :] = A.A[plan.local_src_range, :]
    end

    # Step 2: Fill send buffers and send
    @inbounds for i in eachindex(plan.send_rank_ids)
        r = plan.send_rank_ids[i]
        row_range = plan.send_row_ranges[i]
        buf = plan.send_bufs[i]
        buf .= @view A.A[row_range, :]
        plan.send_reqs[i] = MPI.Isend(vec(buf), comm; dest=r, tag=93)
    end

    # Step 3: Post receives
    @inbounds for i in eachindex(plan.recv_rank_ids)
        plan.recv_reqs[i] = MPI.Irecv!(vec(plan.recv_bufs[i]), comm; source=plan.recv_rank_ids[i], tag=93)
    end

    MPI.Waitall(plan.recv_reqs)

    # Step 4: Copy received rows into result
    @inbounds for i in eachindex(plan.recv_rank_ids)
        dst_range = plan.recv_dst_ranges[i]
        buf = plan.recv_bufs[i]
        result_A[dst_range, :] = buf
    end

    MPI.Waitall(plan.send_reqs)

    return MatrixMPI{T}(plan.result_structural_hash, plan.result_row_partition, plan.result_col_partition, result_A)
end

"""
    get_repartition_plan(A::MatrixMPI{T}, p::Vector{Int}) where T

Get a memoized DenseRepartitionPlan for repartitioning `A` to row partition `p`.
The plan is cached based on the structural hash of A and the target partition hash.
"""
function get_repartition_plan(A::MatrixMPI{T}, p::Vector{Int}) where T
    target_hash = compute_partition_hash(p)
    key = (_ensure_hash(A), target_hash, T)
    if haskey(_repartition_plan_cache, key)
        return _repartition_plan_cache[key]::DenseRepartitionPlan{T}
    end
    plan = DenseRepartitionPlan(A, p)
    _repartition_plan_cache[key] = plan
    return plan
end

"""
    repartition(A::MatrixMPI{T}, p::Vector{Int}) where T

Redistribute a MatrixMPI to a new row partition `p`.
The col_partition remains unchanged.

The partition `p` must be a valid partition vector of length `nranks + 1` with
`p[1] == 1` and `p[end] == size(A, 1) + 1`.

Returns a new MatrixMPI with the same data but `row_partition == p`.

# Example
```julia
A = MatrixMPI(randn(6, 4))          # uniform partition
new_partition = [1, 2, 4, 5, 7]     # 1, 2, 1, 2 rows per rank
A_repart = repartition(A, new_partition)
```
"""
function repartition(A::MatrixMPI{T}, p::Vector{Int}) where T
    # Fast path: partition unchanged
    if A.row_partition == p
        return A
    end

    plan = get_repartition_plan(A, p)
    return execute_plan!(plan, A)
end
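
For completeness, a small sketch of the plan caching described in the commit message, under the same 4-rank assumption as the sketch near the top; `get_repartition_plan` is internal to the module, so it is qualified with the module name here.

```julia
using MPI, LinearAlgebraMPI
MPI.Init()

A = MatrixMPI(randn(6, 4))
p = [1, 2, 4, 5, 7]

# The cache key is (source structural hash, target partition hash, element type),
# so the second call returns the plan object that the first call built and cached.
plan1 = LinearAlgebraMPI.get_repartition_plan(A, p)
plan2 = LinearAlgebraMPI.get_repartition_plan(A, p)
@assert plan1 === plan2

clear_plan_cache!()                                   # also empties _repartition_plan_cache
plan3 = LinearAlgebraMPI.get_repartition_plan(A, p)   # a fresh plan is built and cached
```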
