@@ -320,13 +320,37 @@ function Base.cat(vs::VectorMPI{T}...; dims) where T
320320 end
321321end
322322
323+ """
324+ _vcat_target_partition(output_partition::Vector{Int}, offset::Int, vec_len::Int) -> Vector{Int}
325+
326+ Compute the target partition for repartitioning a vector for use in vcat.
327+
328+ The vector starts at position (offset + 1) in the output. Each rank needs elements
329+ from this vector that fall within its output range.
330+ """
331+ function _vcat_target_partition (output_partition:: Vector{Int} , offset:: Int , vec_len:: Int )
332+ nranks = length (output_partition) - 1
333+ target = Vector {Int} (undef, nranks + 1 )
334+
335+ for r in 0 : (nranks- 1 )
336+ # Rank r owns output indices [output_partition[r+1], output_partition[r+2]-1]
337+ # From this vector (at offset), it needs indices starting at:
338+ # max(1, output_partition[r+1] - offset)
339+ target[r+ 1 ] = clamp (output_partition[r+ 1 ] - offset, 1 , vec_len + 1 )
340+ end
341+ target[nranks+ 1 ] = vec_len + 1
342+
343+ return target
344+ end
345+
323346"""
324347 _vcat_vectors(vs::VectorMPI{T}...) where T
325348
326349Vertically concatenate VectorMPI vectors.
327350
328- This is a distributed implementation that only gathers the vector elements each rank
329- needs for its local output, rather than gathering all data to all ranks.
351+ Uses `repartition` to redistribute each input vector's elements to the ranks that
352+ need them for the output. This provides plan caching and a fast path when partitions
353+ already align (no communication needed).
330354"""
331355function _vcat_vectors (vs:: VectorMPI{T} ...) where T
332356 length (vs) == 1 && return copy (vs[1 ])
@@ -341,14 +365,7 @@ function _vcat_vectors(vs::VectorMPI{T}...) where T
341365 offsets = [0 ; cumsum (lengths[1 : end - 1 ])]
342366
343367 # Step 2: Compute output partition
344- elements_per_rank = div (total_length, nranks)
345- elem_remainder = mod (total_length, nranks)
346- output_partition = Vector {Int} (undef, nranks + 1 )
347- output_partition[1 ] = 1
348- for r in 1 : nranks
349- extra = r <= elem_remainder ? 1 : 0
350- output_partition[r+ 1 ] = output_partition[r] + elements_per_rank + extra
351- end
368+ output_partition = uniform_partition (total_length, nranks)
352369
353370 my_out_start = output_partition[rank+ 1 ]
354371 my_out_end = output_partition[rank+ 2 ] - 1
@@ -357,33 +374,31 @@ function _vcat_vectors(vs::VectorMPI{T}...) where T
357374 # Step 3: Allocate local output vector
358375 local_v = Vector {T} (undef, local_len)
359376
360- # Step 4: For each input vector, gather elements (all ranks must participate)
377+ # Step 4: For each input vector, repartition and copy elements
361378 for (vec_idx, v) in enumerate (vs)
362- vec_start = offsets[vec_idx] + 1
363- vec_end = offsets[vec_idx] + lengths[vec_idx]
379+ vec_len = lengths[vec_idx]
380+ offset = offsets[vec_idx]
381+ vec_start = offset + 1
382+ vec_end = offset + vec_len
364383
365- # Determine indices we need from this vector
366- # NOTE: ALL ranks must call _gather_specific_elements for EVERY vector to
367- # participate in MPI collectives. Pass empty array if no overlap.
384+ # Compute target partition for this vector and repartition
385+ target = _vcat_target_partition (output_partition, offset, vec_len)
386+ v_repart = repartition (v, target)
387+ my_v_start = v_repart. partition[rank+ 1 ]
388+
389+ # Check if this vector contributes to my output range
368390 has_overlap = ! (vec_end < my_out_start || vec_start > my_out_end)
369391
370392 if has_overlap
371- first_in_vec = max (1 , my_out_start - offsets[vec_idx])
372- last_in_vec = min (lengths[vec_idx], my_out_end - offsets[vec_idx])
373- indices_needed = collect (first_in_vec: last_in_vec)
374- else
375- indices_needed = Int[]
376- end
393+ # Copy elements from repartitioned vector to output
394+ first_in_vec = max (1 , my_out_start - offset)
395+ last_in_vec = min (vec_len, my_out_end - offset)
377396
378- # Gather these elements (all ranks must call this!)
379- gathered = _gather_specific_elements (v, indices_needed)
380-
381- # Place into local output (only if we have overlap)
382- if has_overlap
383- for (i, idx_in_vec) in enumerate (indices_needed)
384- global_out_idx = offsets[vec_idx] + idx_in_vec
397+ for idx_in_vec in first_in_vec: last_in_vec
398+ global_out_idx = offset + idx_in_vec
385399 local_out_idx = global_out_idx - my_out_start + 1
386- local_v[local_out_idx] = gathered[i]
400+ local_v_idx = idx_in_vec - my_v_start + 1
401+ local_v[local_out_idx] = v_repart. v[local_v_idx]
387402 end
388403 end
389404 end
0 commit comments