@@ -27,7 +27,7 @@ cat(A, B, C, D; dims=(2,2)) # 2×2 block matrix [A B; C D]
2727This is a distributed implementation that only gathers the rows each rank needs
2828for its local output, rather than gathering all data to all ranks.
2929"""
30- function Base. cat (As:: SparseMatrixMPI{T} ...; dims) where T
30+ function Base. cat (As:: SparseMatrixMPI{T,Ti,AV } ...; dims) where {T,Ti,AV}
3131 isempty (As) && error (" cat requires at least one matrix" )
3232 length (As) == 1 && return copy (As[1 ])
3333
@@ -142,7 +142,19 @@ function Base.cat(As::SparseMatrixMPI{T}...; dims) where T
142142 SparseMatrixCSC (total_cols, local_nrows, ones (Int, local_nrows + 1 ), Int[], T[]) :
143143 sparse (local_J, local_I, local_V, total_cols, local_nrows)
144144
145- return SparseMatrixMPI_local (transpose (AT_local); comm= comm)
145+ result = SparseMatrixMPI_local (transpose (AT_local); comm= comm)
146+
147+ # Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
148+ if AV != = Vector{T}
149+ nzval_target = copyto! (similar (As[1 ]. nzval, length (result. nzval)), result. nzval)
150+ rowptr_target = _to_target_backend (result. rowptr, AV)
151+ colval_target = _to_target_backend (result. colval, AV)
152+ return SparseMatrixMPI {T,Ti,AV} (
153+ result. structural_hash, result. row_partition, result. col_partition, result. col_indices,
154+ result. rowptr, result. colval, nzval_target, result. nrows_local, result. ncols_compressed,
155+ nothing , result. cached_symmetric, rowptr_target, colval_target)
156+ end
157+ return result
146158end
147159
148160# ============================================================================
@@ -452,7 +464,7 @@ Returns a SparseMatrixMPI.
452464This is a distributed implementation that only gathers the rows each rank needs
453465for its local output, rather than gathering all data to all ranks.
454466"""
455- function blockdiag (As:: SparseMatrixMPI{T} ...) where T
467+ function blockdiag (As:: SparseMatrixMPI{T,Ti,AV } ...) where {T,Ti,AV}
456468 isempty (As) && error (" blockdiag requires at least one matrix" )
457469 length (As) == 1 && return copy (As[1 ])
458470
@@ -526,5 +538,17 @@ function blockdiag(As::SparseMatrixMPI{T}...) where T
526538 SparseMatrixCSC (total_cols, local_nrows, ones (Int, local_nrows + 1 ), Int[], T[]) :
527539 sparse (local_J, local_I, local_V, total_cols, local_nrows)
528540
529- return SparseMatrixMPI_local (transpose (AT_local); comm= comm)
541+ result = SparseMatrixMPI_local (transpose (AT_local); comm= comm)
542+
543+ # Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
544+ if AV != = Vector{T}
545+ nzval_target = copyto! (similar (As[1 ]. nzval, length (result. nzval)), result. nzval)
546+ rowptr_target = _to_target_backend (result. rowptr, AV)
547+ colval_target = _to_target_backend (result. colval, AV)
548+ return SparseMatrixMPI {T,Ti,AV} (
549+ result. structural_hash, result. row_partition, result. col_partition, result. col_indices,
550+ result. rowptr, result. colval, nzval_target, result. nrows_local, result. ncols_compressed,
551+ nothing , result. cached_symmetric, rowptr_target, colval_target)
552+ end
553+ return result
530554end
0 commit comments