@@ -1372,3 +1372,257 @@ function Base.mapslices(f, A::MatrixMPI{T}; dims) where T
        error("dims must be 1 or 2")
    end
end

# ============================================================================
# DenseRepartitionPlan: Repartition a MatrixMPI to a new row partition
# ============================================================================

"""
    DenseRepartitionPlan{T}

Communication plan for repartitioning a MatrixMPI to a new row partition.

# Fields
- `send_rank_ids::Vector{Int}`: Ranks we send rows to (0-indexed)
- `send_row_ranges::Vector{UnitRange{Int}}`: For each send rank, range of local rows to send
- `send_bufs::Vector{Matrix{T}}`: Pre-allocated send buffers
- `send_reqs::Vector{MPI.Request}`: Pre-allocated send request handles
- `recv_rank_ids::Vector{Int}`: Ranks we receive rows from (0-indexed)
- `recv_row_counts::Vector{Int}`: Number of rows to receive from each rank
- `recv_bufs::Vector{Matrix{T}}`: Pre-allocated receive buffers
- `recv_reqs::Vector{MPI.Request}`: Pre-allocated receive request handles
- `recv_dst_ranges::Vector{UnitRange{Int}}`: Destination row ranges in the result for each receive
- `local_src_range::UnitRange{Int}`: Source row range for the local copy
- `local_dst_range::UnitRange{Int}`: Destination row range for the local copy
- `result_row_partition::Vector{Int}`: Target row partition (copy of `p`)
- `result_col_partition::Vector{Int}`: Column partition (unchanged from the source)
- `result_structural_hash::Blake3Hash`: Structural hash of the result matrix
- `result_local_nrows::Int`: Number of rows this rank owns after repartitioning
- `ncols::Int`: Number of columns in the matrix
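
For example, in a hypothetical 2-rank run with source `row_partition == [1, 4, 7]`
and target `p == [1, 2, 7]`: rank 0 keeps global row 1 in place
(`local_src_range == 1:1`, `local_dst_range == 1:1`) and sends its local rows `2:3`
to rank 1; rank 1 copies its own three rows into result rows `3:5` and receives two
rows from rank 0 into `recv_dst_ranges == [1:2]`.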
"""
mutable struct DenseRepartitionPlan{T}
    send_rank_ids::Vector{Int}
    send_row_ranges::Vector{UnitRange{Int}}
    send_bufs::Vector{Matrix{T}}
    send_reqs::Vector{MPI.Request}
    recv_rank_ids::Vector{Int}
    recv_row_counts::Vector{Int}
    recv_bufs::Vector{Matrix{T}}
    recv_reqs::Vector{MPI.Request}
    recv_dst_ranges::Vector{UnitRange{Int}}
    local_src_range::UnitRange{Int}
    local_dst_range::UnitRange{Int}
    result_row_partition::Vector{Int}
    result_col_partition::Vector{Int}
    result_structural_hash::Blake3Hash
    result_local_nrows::Int
    ncols::Int
end

"""
    DenseRepartitionPlan(A::MatrixMPI{T}, p::Vector{Int}) where T

Create a communication plan to repartition `A` to have row partition `p`.
The col_partition remains unchanged.

The constructor:
1. Determines which rows to send to each rank based on partition overlap
2. Determines which rows to receive from each rank
3. Pre-allocates all buffers for allocation-free execution
4. Computes the result structural hash eagerly
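
# Example

A minimal usage sketch, assuming a 4-rank MPI run:

```julia
A = MatrixMPI(randn(6, 4))         # uniform row partition over 4 ranks
p = [1, 2, 4, 5, 7]                # 1, 2, 1, 2 rows per rank
plan = DenseRepartitionPlan(A, p)  # exchanges only row counts; no matrix data moves yet
```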
"""
function DenseRepartitionPlan(A::MatrixMPI{T}, p::Vector{Int}) where T
    comm = MPI.COMM_WORLD
    rank = MPI.Comm_rank(comm)
    nranks = MPI.Comm_size(comm)

    # Source partition info
    src_start = A.row_partition[rank+1]
    src_end = A.row_partition[rank+2] - 1
    local_nrows = max(0, src_end - src_start + 1)

    # Target partition info
    dst_start = p[rank+1]
    dst_end = p[rank+2] - 1
    result_local_nrows = max(0, dst_end - dst_start + 1)

    ncols = A.col_partition[end] - 1

    # Step 1: Determine which rows we send to each rank
    send_row_ranges_map = Dict{Int, UnitRange{Int}}()
    for r in 0:(nranks-1)
        r_start = p[r+1]
        r_end = p[r+2] - 1
        if r_end < r_start
            continue  # rank r has no rows in the target partition
        end
        # Intersection of our rows with rank r's target rows
        overlap_start = max(src_start, r_start)
        overlap_end = min(src_end, r_end)
        if overlap_start <= overlap_end
            # Convert to local row indices in A.A
            local_start = overlap_start - src_start + 1
            local_end = overlap_end - src_start + 1
            send_row_ranges_map[r] = local_start:local_end
        end
    end

    # Step 2: Exchange counts via Alltoall
    send_counts = Int32[haskey(send_row_ranges_map, r) ? length(send_row_ranges_map[r]) : 0 for r in 0:(nranks-1)]
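    # recv_counts_raw[r+1] is the number of rows rank r will send to this rank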
    recv_counts_raw = MPI.Alltoall(MPI.UBuffer(send_counts, 1), comm)

    # Step 3: Build send/recv structures
    send_rank_ids = Int[]
    send_row_ranges = UnitRange{Int}[]
    recv_rank_ids = Int[]
    recv_row_counts = Int[]
    recv_dst_ranges = UnitRange{Int}[]

    local_src_range = 1:0  # empty range
    local_dst_range = 1:0  # empty range

    # Handle the local copy separately
    if haskey(send_row_ranges_map, rank)
        local_src_range = send_row_ranges_map[rank]
        # Compute the destination range: where do these rows go in the result?
        global_start = src_start + local_src_range.start - 1
        local_dst_start = global_start - dst_start + 1
        local_dst_end = local_dst_start + length(local_src_range) - 1
        local_dst_range = local_dst_start:local_dst_end
    end

    # Build send arrays (excluding local)
    for r in 0:(nranks-1)
        if haskey(send_row_ranges_map, r) && r != rank
            push!(send_rank_ids, r)
            push!(send_row_ranges, send_row_ranges_map[r])
        end
    end

    # Build recv arrays (excluding local)
    for r in 0:(nranks-1)
        if recv_counts_raw[r+1] > 0 && r != rank
            push!(recv_rank_ids, r)
            push!(recv_row_counts, recv_counts_raw[r+1])

            # Rows from rank r: their source range is [A.row_partition[r+1], A.row_partition[r+2]-1]
            # intersected with our target range [dst_start, dst_end]
            r_src_start = A.row_partition[r+1]
            r_src_end = A.row_partition[r+2] - 1
            overlap_start = max(r_src_start, dst_start)
            overlap_end = min(r_src_end, dst_end)
            # Destination range in our result
            dst_range_start = overlap_start - dst_start + 1
            dst_range_end = overlap_end - dst_start + 1
            push!(recv_dst_ranges, dst_range_start:dst_range_end)
        end
    end

    # Pre-allocate buffers
    send_bufs = [Matrix{T}(undef, length(r), ncols) for r in send_row_ranges]
    recv_bufs = [Matrix{T}(undef, c, ncols) for c in recv_row_counts]
    send_reqs = Vector{MPI.Request}(undef, length(send_rank_ids))
    recv_reqs = Vector{MPI.Request}(undef, length(recv_rank_ids))

    # Compute the result structural hash eagerly
    result_local_size = (result_local_nrows, ncols)
    result_structural_hash = compute_dense_structural_hash(p, A.col_partition, result_local_size, comm)

    return DenseRepartitionPlan{T}(
        send_rank_ids, send_row_ranges, send_bufs, send_reqs,
        recv_rank_ids, recv_row_counts, recv_bufs, recv_reqs, recv_dst_ranges,
        local_src_range, local_dst_range,
        copy(p), copy(A.col_partition), result_structural_hash,
        result_local_nrows, ncols
    )
end

"""
    execute_plan!(plan::DenseRepartitionPlan{T}, A::MatrixMPI{T}) where T

Execute a dense repartition plan to redistribute the rows of `A` to a new partition.
Returns a new MatrixMPI with the target row partition.
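
# Example

A minimal sketch, reusing `A`, `p`, and `plan` from the `DenseRepartitionPlan` example:

```julia
B = execute_plan!(plan, A)
B.row_partition == p   # true
```

The plan depends only on the partitions and the element type, so it can be reused for
other matrices that share `A`'s row and column partitions; the memoized
`get_repartition_plan` relies on this.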
"""
function execute_plan!(plan::DenseRepartitionPlan{T}, A::MatrixMPI{T}) where T
    comm = MPI.COMM_WORLD

    # Allocate the result
    result_A = Matrix{T}(undef, plan.result_local_nrows, plan.ncols)

    # Step 1: Local copy
    if !isempty(plan.local_src_range) && !isempty(plan.local_dst_range)
        result_A[plan.local_dst_range, :] = A.A[plan.local_src_range, :]
    end

    # Step 2: Fill send buffers and post sends
    @inbounds for i in eachindex(plan.send_rank_ids)
        r = plan.send_rank_ids[i]
        row_range = plan.send_row_ranges[i]
        buf = plan.send_bufs[i]
        buf .= @view A.A[row_range, :]
        plan.send_reqs[i] = MPI.Isend(vec(buf), comm; dest=r, tag=93)
    end

    # Step 3: Post receives
    @inbounds for i in eachindex(plan.recv_rank_ids)
        plan.recv_reqs[i] = MPI.Irecv!(vec(plan.recv_bufs[i]), comm; source=plan.recv_rank_ids[i], tag=93)
    end

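    # Wait until all expected rows have arrived before unpacking them into the result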
    MPI.Waitall(plan.recv_reqs)

    # Step 4: Copy received rows into the result
    @inbounds for i in eachindex(plan.recv_rank_ids)
        dst_range = plan.recv_dst_ranges[i]
        buf = plan.recv_bufs[i]
        result_A[dst_range, :] = buf
    end

    MPI.Waitall(plan.send_reqs)

    return MatrixMPI{T}(plan.result_structural_hash, plan.result_row_partition, plan.result_col_partition, result_A)
end

"""
    get_repartition_plan(A::MatrixMPI{T}, p::Vector{Int}) where T

Get a memoized `DenseRepartitionPlan` for repartitioning `A` to row partition `p`.
The plan is cached based on the structural hash of `A` and the target partition hash.
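
# Example

A minimal sketch of the memoization, assuming `A` and `p` as in the examples above:

```julia
plan1 = get_repartition_plan(A, p)
plan2 = get_repartition_plan(A, p)
plan1 === plan2   # true: the second call returns the cached plan
```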
"""
function get_repartition_plan(A::MatrixMPI{T}, p::Vector{Int}) where T
    target_hash = compute_partition_hash(p)
    key = (_ensure_hash(A), target_hash, T)
    if haskey(_repartition_plan_cache, key)
        return _repartition_plan_cache[key]::DenseRepartitionPlan{T}
    end
    plan = DenseRepartitionPlan(A, p)
    _repartition_plan_cache[key] = plan
    return plan
end

"""
    repartition(A::MatrixMPI{T}, p::Vector{Int}) where T

Redistribute a MatrixMPI to a new row partition `p`.
The col_partition remains unchanged.

The partition `p` must be a valid partition vector of length `nranks + 1` with
`p[1] == 1` and `p[end] == size(A, 1) + 1`.

Returns a new MatrixMPI with the same data but `row_partition == p`.

# Example
```julia
A = MatrixMPI(randn(6, 4))      # uniform partition
new_partition = [1, 2, 4, 5, 7] # 1, 2, 1, 2 rows per rank (4-rank run)
A_repart = repartition(A, new_partition)
```
"""
function repartition(A::MatrixMPI{T}, p::Vector{Int}) where T
    # Fast path: partition unchanged
    if A.row_partition == p
        return A
    end

    plan = get_repartition_plan(A, p)
    return execute_plan!(plan, A)
end