
Commit dbe94fd

Sébastien Loisel authored and committed
Fix GPU array type preservation in cat/blockdiag
When the input sparse matrices have GPU-backed nzval (e.g., MtlVector), cat and blockdiag now correctly return GPU-backed results instead of always returning CPU arrays. The fix captures the AV type parameter from the inputs and converts the result back to GPU after the MPI communication (which requires CPU staging buffers).
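The staging pattern relies on Base's similar/copyto! dispatch: similar on a GPU-backed vector allocates on the same device, and copyto! from a plain Vector performs the host-to-device transfer. A minimal standalone sketch of that pattern, assuming Metal.jl is available (the variable names are illustrative, not from this commit):

using Metal  # provides MtlVector, the GPU-backed array type mentioned above

cpu_result = Float32[1.0, 2.0, 3.0]        # values assembled on the CPU (MPI staging)
gpu_input  = MtlVector(zeros(Float32, 3))  # stand-in for an input's GPU-backed nzval

# similar(gpu_input, n) allocates an MtlVector of length n on the device;
# copyto! then copies the CPU data into it, preserving the input's backend.
gpu_result = copyto!(similar(gpu_input, length(cpu_result)), cpu_result)
@assert gpu_result isa MtlVector{Float32}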
1 parent 82ac0f3 · commit dbe94fd

1 file changed

Lines changed: 28 additions & 4 deletions

src/blocks.jl
@@ -27,7 +27,7 @@ cat(A, B, C, D; dims=(2,2)) # 2×2 block matrix [A B; C D]
 This is a distributed implementation that only gathers the rows each rank needs
 for its local output, rather than gathering all data to all ranks.
 """
-function Base.cat(As::SparseMatrixMPI{T}...; dims) where T
+function Base.cat(As::SparseMatrixMPI{T,Ti,AV}...; dims) where {T,Ti,AV}
     isempty(As) && error("cat requires at least one matrix")
     length(As) == 1 && return copy(As[1])

@@ -142,7 +142,19 @@ function Base.cat(As::SparseMatrixMPI{T}...; dims) where T
         SparseMatrixCSC(total_cols, local_nrows, ones(Int, local_nrows + 1), Int[], T[]) :
         sparse(local_J, local_I, local_V, total_cols, local_nrows)
 
-    return SparseMatrixMPI_local(transpose(AT_local); comm=comm)
+    result = SparseMatrixMPI_local(transpose(AT_local); comm=comm)
+
+    # Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
+    if AV !== Vector{T}
+        nzval_target = copyto!(similar(As[1].nzval, length(result.nzval)), result.nzval)
+        rowptr_target = _to_target_backend(result.rowptr, AV)
+        colval_target = _to_target_backend(result.colval, AV)
+        return SparseMatrixMPI{T,Ti,AV}(
+            result.structural_hash, result.row_partition, result.col_partition, result.col_indices,
+            result.rowptr, result.colval, nzval_target, result.nrows_local, result.ncols_compressed,
+            nothing, result.cached_symmetric, rowptr_target, colval_target)
+    end
+    return result
 end
 
 # ============================================================================
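The _to_target_backend helper called above is defined elsewhere in the package; its body is not part of this diff. A hedged sketch of one plausible implementation, assuming it moves a CPU integer vector onto the backend of the AV storage type:

# Hypothetical sketch only; the real helper lives outside this diff.
# rowptr/colval are CPU integer vectors after MPI staging, while AV is the
# nzval storage type (e.g. MtlVector{T}), so the helper must allocate an
# integer vector on AV's backend and copy the data host-to-device.
function _to_target_backend(v::Vector{Ti}, ::Type{AV}) where {Ti,AV<:AbstractVector}
    # An empty instance of AV pins the backend; similar with eltype Ti then
    # allocates on that backend, and copyto! performs the transfer.
    return copyto!(similar(AV(undef, 0), Ti, length(v)), v)
end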
@@ -452,7 +464,7 @@ Returns a SparseMatrixMPI.
 This is a distributed implementation that only gathers the rows each rank needs
 for its local output, rather than gathering all data to all ranks.
 """
-function blockdiag(As::SparseMatrixMPI{T}...) where T
+function blockdiag(As::SparseMatrixMPI{T,Ti,AV}...) where {T,Ti,AV}
     isempty(As) && error("blockdiag requires at least one matrix")
     length(As) == 1 && return copy(As[1])

@@ -526,5 +538,17 @@ function blockdiag(As::SparseMatrixMPI{T}...) where T
         SparseMatrixCSC(total_cols, local_nrows, ones(Int, local_nrows + 1), Int[], T[]) :
         sparse(local_J, local_I, local_V, total_cols, local_nrows)
 
-    return SparseMatrixMPI_local(transpose(AT_local); comm=comm)
+    result = SparseMatrixMPI_local(transpose(AT_local); comm=comm)
+
+    # Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
+    if AV !== Vector{T}
+        nzval_target = copyto!(similar(As[1].nzval, length(result.nzval)), result.nzval)
+        rowptr_target = _to_target_backend(result.rowptr, AV)
+        colval_target = _to_target_backend(result.colval, AV)
+        return SparseMatrixMPI{T,Ti,AV}(
+            result.structural_hash, result.row_partition, result.col_partition, result.col_indices,
+            result.rowptr, result.colval, nzval_target, result.nrows_local, result.ncols_compressed,
+            nothing, result.cached_symmetric, rowptr_target, colval_target)
+    end
+    return result
 end
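Taken together, the two hunks establish the invariant that the output's storage backend matches the inputs'. A hypothetical check of that invariant (A and B are assumed to be SparseMatrixMPI values whose nzval is an MtlVector; their construction is not shown in this commit):

# Hypothetical usage; assumes A and B are GPU-backed SparseMatrixMPI values.
C = cat(A, B; dims=1)    # vertical concatenation, as in Base.cat
D = blockdiag(A, B)
@assert typeof(C.nzval) == typeof(A.nzval)   # backend preserved, not Vector{T}
@assert typeof(D.nzval) == typeof(A.nzval)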
