Skip to content

Commit aa73c07

Browse files
Sébastien LoiselSébastien Loisel
authored andcommitted
Refactor spdiagm and vcat to use repartition, fix diag locality
- Refactor spdiagm to use repartition instead of _gather_specific_elements for plan caching and fast path when partitions match - Refactor vcat for VectorMPI to use repartition similarly - Fix diag(A, k) to be purely local with no MPI communication - Remove _gather_specific_elements (no longer needed) - Fix spdiagm(kv...) to return square matrices matching Julia's default
1 parent 954aeec commit aa73c07

3 files changed

Lines changed: 139 additions & 305 deletions

File tree

src/blocks.jl

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,37 @@ function Base.cat(vs::VectorMPI{T}...; dims) where T
320320
end
321321
end
322322

323+
"""
324+
_vcat_target_partition(output_partition::Vector{Int}, offset::Int, vec_len::Int) -> Vector{Int}
325+
326+
Compute the target partition for repartitioning a vector for use in vcat.
327+
328+
The vector starts at position (offset + 1) in the output. Each rank needs elements
329+
from this vector that fall within its output range.
330+
"""
331+
function _vcat_target_partition(output_partition::Vector{Int}, offset::Int, vec_len::Int)
332+
nranks = length(output_partition) - 1
333+
target = Vector{Int}(undef, nranks + 1)
334+
335+
for r in 0:(nranks-1)
336+
# Rank r owns output indices [output_partition[r+1], output_partition[r+2]-1]
337+
# From this vector (at offset), it needs indices starting at:
338+
# max(1, output_partition[r+1] - offset)
339+
target[r+1] = clamp(output_partition[r+1] - offset, 1, vec_len + 1)
340+
end
341+
target[nranks+1] = vec_len + 1
342+
343+
return target
344+
end
345+
323346
"""
324347
_vcat_vectors(vs::VectorMPI{T}...) where T
325348
326349
Vertically concatenate VectorMPI vectors.
327350
328-
This is a distributed implementation that only gathers the vector elements each rank
329-
needs for its local output, rather than gathering all data to all ranks.
351+
Uses `repartition` to redistribute each input vector's elements to the ranks that
352+
need them for the output. This provides plan caching and a fast path when partitions
353+
already align (no communication needed).
330354
"""
331355
function _vcat_vectors(vs::VectorMPI{T}...) where T
332356
length(vs) == 1 && return copy(vs[1])
@@ -341,14 +365,7 @@ function _vcat_vectors(vs::VectorMPI{T}...) where T
341365
offsets = [0; cumsum(lengths[1:end-1])]
342366

343367
# Step 2: Compute output partition
344-
elements_per_rank = div(total_length, nranks)
345-
elem_remainder = mod(total_length, nranks)
346-
output_partition = Vector{Int}(undef, nranks + 1)
347-
output_partition[1] = 1
348-
for r in 1:nranks
349-
extra = r <= elem_remainder ? 1 : 0
350-
output_partition[r+1] = output_partition[r] + elements_per_rank + extra
351-
end
368+
output_partition = uniform_partition(total_length, nranks)
352369

353370
my_out_start = output_partition[rank+1]
354371
my_out_end = output_partition[rank+2] - 1
@@ -357,33 +374,31 @@ function _vcat_vectors(vs::VectorMPI{T}...) where T
357374
# Step 3: Allocate local output vector
358375
local_v = Vector{T}(undef, local_len)
359376

360-
# Step 4: For each input vector, gather elements (all ranks must participate)
377+
# Step 4: For each input vector, repartition and copy elements
361378
for (vec_idx, v) in enumerate(vs)
362-
vec_start = offsets[vec_idx] + 1
363-
vec_end = offsets[vec_idx] + lengths[vec_idx]
379+
vec_len = lengths[vec_idx]
380+
offset = offsets[vec_idx]
381+
vec_start = offset + 1
382+
vec_end = offset + vec_len
364383

365-
# Determine indices we need from this vector
366-
# NOTE: ALL ranks must call _gather_specific_elements for EVERY vector to
367-
# participate in MPI collectives. Pass empty array if no overlap.
384+
# Compute target partition for this vector and repartition
385+
target = _vcat_target_partition(output_partition, offset, vec_len)
386+
v_repart = repartition(v, target)
387+
my_v_start = v_repart.partition[rank+1]
388+
389+
# Check if this vector contributes to my output range
368390
has_overlap = !(vec_end < my_out_start || vec_start > my_out_end)
369391

370392
if has_overlap
371-
first_in_vec = max(1, my_out_start - offsets[vec_idx])
372-
last_in_vec = min(lengths[vec_idx], my_out_end - offsets[vec_idx])
373-
indices_needed = collect(first_in_vec:last_in_vec)
374-
else
375-
indices_needed = Int[]
376-
end
393+
# Copy elements from repartitioned vector to output
394+
first_in_vec = max(1, my_out_start - offset)
395+
last_in_vec = min(vec_len, my_out_end - offset)
377396

378-
# Gather these elements (all ranks must call this!)
379-
gathered = _gather_specific_elements(v, indices_needed)
380-
381-
# Place into local output (only if we have overlap)
382-
if has_overlap
383-
for (i, idx_in_vec) in enumerate(indices_needed)
384-
global_out_idx = offsets[vec_idx] + idx_in_vec
397+
for idx_in_vec in first_in_vec:last_in_vec
398+
global_out_idx = offset + idx_in_vec
385399
local_out_idx = global_out_idx - my_out_start + 1
386-
local_v[local_out_idx] = gathered[i]
400+
local_v_idx = idx_in_vec - my_v_start + 1
401+
local_v[local_out_idx] = v_repart.v[local_v_idx]
387402
end
388403
end
389404
end

0 commit comments

Comments
 (0)