Skip to content

Commit 4ca3c32

Browse files
author
Sebastien Loisel
committed
Unify CPU/GPU code paths, remove synchronize calls, reset version
- Remove isa-based CPU/GPU branching in blocks.jl, dense.jl, mumps_factorization.jl; use unified helpers (_convert_array, _ensure_cpu, to_backend) instead
- Remove CUDA.synchronize() and Metal.synchronize() from kernel dispatches and cuDSS operations (implicit sync on reads)
- Remove comm_barrier() calls from cuDSS collective operations
- Add _array_to_device stub to main module for GPU extensions
- Update codecov.yml with correct extension filenames
- Reset version to 0.1.0 for registry re-registration
1 parent 9110b09 commit 4ca3c32

8 files changed

Lines changed: 37 additions & 139 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
authors = ["S. Loisel"]
22
name = "HPCLinearAlgebra"
33
uuid = "537374f1-5608-4525-82fb-641dce542540"
4-
version = "0.1.9"
4+
version = "0.1.0"
55

66
[compat]
77
Adapt = "4"

codecov.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ coverage:
1414

1515
ignore:
1616
- "ext/"
17-
- "ext/LinearAlgebraMPICUDAExt.jl"
18-
- "ext/LinearAlgebraMPIMetalExt.jl"
17+
- "ext/HPCLinearAlgebraCUDAExt.jl"
18+
- "ext/HPCLinearAlgebraMetalExt.jl"

ext/HPCLinearAlgebraCUDAExt.jl

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -667,19 +667,13 @@ function _create_cudss_factorization(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B}
667667
rhs = rhs_ref[]
668668
_cudss_matrix_set_distribution_row1d(rhs, Int64(first_row), Int64(last_row))
669669

670-
comm_barrier(comm)
671-
672670
# Run analysis phase
673671
# Note: Analysis caching is disabled in MGMN mode because cudssDataGet for
674672
# PERM_REORDER_ROW/COL returns INVALID_VALUE. Each factorization does full analysis.
675673
_cudss_execute(handle, CUDSS_PHASE_ANALYSIS, config, data, matrix, solution, rhs)
676-
CUDA.synchronize()
677-
comm_barrier(comm)
678674

679675
# Run numeric factorization
680676
_cudss_execute(handle, CUDSS_PHASE_FACTORIZATION, config, data, matrix, solution, rhs)
681-
CUDA.synchronize()
682-
comm_barrier(comm)
683677

684678
# Create factorization object
685679
F = CuDSSFactorizationMPI{T,B}(
@@ -710,10 +704,8 @@ function HPCLinearAlgebra.solve(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgeb
710704
# Copy b directly to RHS buffer (GPU to GPU)
711705
copyto!(F.b_gpu, b.v)
712706

713-
# Execute solve phase only
707+
# Execute solve phase only (collective operation)
714708
_cudss_execute(F.handle, CUDSS_PHASE_SOLVE, F.config, F.data, F.matrix, F.solution, F.rhs)
715-
CUDA.synchronize()
716-
comm_barrier(comm)
717709

718710
# Return GPU vector (copy from internal buffer) with backend
719711
return HPCLinearAlgebra.HPCVector{T,B}(b.structural_hash, b.partition, copy(F.x_gpu), F.backend)
@@ -736,10 +728,6 @@ Only destroys: data (L/U factors), matrix wrappers.
736728
Does NOT destroy: handle, config (global, cached per-process).
737729
"""
738730
function HPCLinearAlgebra.finalize!(F::CuDSSFactorizationMPI)
739-
comm = F.backend.comm
740-
CUDA.synchronize()
741-
comm_barrier(comm)
742-
743731
# Destroy data object (holds L/U factors - collective operation in MGMN mode)
744732
if F.data != C_NULL
745733
_cudss_data_destroy(F.handle, F.data)
@@ -760,9 +748,6 @@ function HPCLinearAlgebra.finalize!(F::CuDSSFactorizationMPI)
760748
F.rhs = C_NULL
761749
end
762750

763-
CUDA.synchronize()
764-
comm_barrier(comm)
765-
766751
return nothing
767752
end
768753

@@ -798,18 +783,11 @@ function _refactorize_and_solve!(F::CuDSSFactorizationMPI{T,B},
798783
# Copy RHS to buffer
799784
copyto!(F.b_gpu, b.v)
800785

801-
CUDA.synchronize()
802-
comm_barrier(comm)
803-
804786
# Refactorize (skip analysis - the symbolic factorization is already done)
805787
_cudss_execute(F.handle, CUDSS_PHASE_FACTORIZATION, F.config, F.data, F.matrix, F.solution, F.rhs)
806-
CUDA.synchronize()
807-
comm_barrier(comm)
808788

809789
# Solve
810790
_cudss_execute(F.handle, CUDSS_PHASE_SOLVE, F.config, F.data, F.matrix, F.solution, F.rhs)
811-
CUDA.synchronize()
812-
comm_barrier(comm)
813791

814792
# Return GPU vector (copy from internal buffer) with backend
815793
return HPCLinearAlgebra.HPCVector{T, B}(b.structural_hash, b.partition, copy(F.x_gpu), b.backend)
@@ -839,12 +817,9 @@ function Base.:\(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B},
839817
F = _create_cudss_factorization(A, false)
840818
_cudss_backslash_cache[cache_key] = F
841819

842-
# Solve
843-
comm = A.backend.comm
820+
# Solve (collective operation)
844821
copyto!(F.b_gpu, b.v)
845822
_cudss_execute(F.handle, CUDSS_PHASE_SOLVE, F.config, F.data, F.matrix, F.solution, F.rhs)
846-
CUDA.synchronize()
847-
comm_barrier(comm)
848823

849824
return HPCLinearAlgebra.HPCVector{T, B}(b.structural_hash, b.partition, copy(F.x_gpu), b.backend)
850825
end
@@ -870,12 +845,9 @@ function Base.:\(A::Symmetric{T,<:HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B}},
870845
F = _create_cudss_factorization(A_inner, true)
871846
_cudss_backslash_cache[cache_key] = F
872847

873-
# Solve
874-
comm = A_inner.backend.comm
848+
# Solve (collective operation)
875849
copyto!(F.b_gpu, b.v)
876850
_cudss_execute(F.handle, CUDSS_PHASE_SOLVE, F.config, F.data, F.matrix, F.solution, F.rhs)
877-
CUDA.synchronize()
878-
comm_barrier(comm)
879851

880852
return HPCLinearAlgebra.HPCVector{T, B}(b.structural_hash, b.partition, copy(F.x_gpu), b.backend)
881853
end
@@ -928,7 +900,6 @@ function _cuda_map_rows_kernel_dispatch(f, output::CuMatrix{T}, arg1::CuMatrix{T
928900
threads = min(n, config.threads)
929901
blocks = cld(n, threads)
930902
kernel(f, output, arg1, Val(ncols1), Val(out_cols); threads=threads, blocks=blocks)
931-
CUDA.synchronize()
932903
end
933904

934905
function _cuda_map_rows_kernel_dispatch(f, output::CuMatrix{T}, arg1::CuMatrix{T}, arg2::CuMatrix{T}) where T
@@ -942,7 +913,6 @@ function _cuda_map_rows_kernel_dispatch(f, output::CuMatrix{T}, arg1::CuMatrix{T
942913
threads = min(n, config.threads)
943914
blocks = cld(n, threads)
944915
kernel(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols); threads=threads, blocks=blocks)
945-
CUDA.synchronize()
946916
end
947917

948918
# CUDA kernels

ext/HPCLinearAlgebraMetalExt.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T})
153153
threads = min(n, 256)
154154
groups = cld(n, threads)
155155
kernel(f, output, arg1, Val(ncols1), Val(out_cols); threads=threads, groups=groups)
156-
Metal.synchronize()
157156
end
158157

159158
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}) where T
@@ -166,7 +165,6 @@ function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T},
166165
threads = min(n, 256)
167166
groups = cld(n, threads)
168167
kernel(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols); threads=threads, groups=groups)
169-
Metal.synchronize()
170168
end
171169

172170
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}) where T
@@ -180,7 +178,6 @@ function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T},
180178
threads = min(n, 256)
181179
groups = cld(n, threads)
182180
kernel(f, output, arg1, arg2, arg3, Val(ncols1), Val(ncols2), Val(ncols3), Val(out_cols); threads=threads, groups=groups)
183-
Metal.synchronize()
184181
end
185182

186183
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}) where T
@@ -195,7 +192,6 @@ function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T},
195192
threads = min(n, 256)
196193
groups = cld(n, threads)
197194
kernel(f, output, arg1, arg2, arg3, arg4, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(out_cols); threads=threads, groups=groups)
198-
Metal.synchronize()
199195
end
200196

201197
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}, arg5::MtlMatrix{T}) where T
@@ -211,7 +207,6 @@ function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T},
211207
threads = min(n, 256)
212208
groups = cld(n, threads)
213209
kernel(f, output, arg1, arg2, arg3, arg4, arg5, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(ncols5), Val(out_cols); threads=threads, groups=groups)
214-
Metal.synchronize()
215210
end
216211

217212
# ============================================================================

src/HPCLinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ const _identity_addition_plan_cache = Dict{Tuple{Blake3Hash,DataType,DataType},A
170170
# Forward declarations - implementations are after include("backends.jl")
171171
function _convert_array end
172172
function to_backend end
173+
function _array_to_device end # GPU extensions provide CPU→GPU conversion
173174

174175
"""
175176
clear_plan_cache!()

src/blocks.jl

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,8 @@ function Base.cat(As::HPCSparseMatrix{T,Ti,Bk}...; dims) where {T,Ti,Bk<:HPCBack
145145

146146
result = HPCSparseMatrix_local(transpose(AT_local), backend)
147147

148-
# Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
149-
device = backend.device
150-
if !(device isa DeviceCPU)
151-
nzval_target = copyto!(similar(As[1].nzval, length(result.nzval)), result.nzval)
152-
rowptr_target = _to_target_device(result.rowptr, device)
153-
colval_target = _to_target_device(result.colval, device)
154-
return HPCSparseMatrix{T,Ti,Bk}(
155-
result.structural_hash, result.row_partition, result.col_partition, result.col_indices,
156-
result.rowptr, result.colval, nzval_target, result.nrows_local, result.ncols_compressed,
157-
nothing, result.cached_symmetric, rowptr_target, colval_target, backend)
158-
end
159-
return result
148+
# Convert to target backend (no-op for CPU, copies for GPU)
149+
return to_backend(result, backend)
160150
end
161151

162152
# ============================================================================
@@ -292,12 +282,8 @@ function Base.cat(As::HPCMatrix{T,B}...; dims) where {T, B<:HPCBackend}
292282
# Step 4: Create HPCMatrix from local data
293283
result = HPCMatrix_local(local_matrix, backend)
294284

295-
# Convert to GPU if inputs were GPU (check if backend device is not CPU)
296-
if !(backend.device isa DeviceCPU)
297-
local_matrix_gpu = copyto!(similar(As[1].A, local_nrows, total_cols), local_matrix)
298-
return HPCMatrix{T,B}(result.structural_hash, result.row_partition, result.col_partition, local_matrix_gpu, backend)
299-
end
300-
return result
285+
# Convert to target backend (no-op for CPU, copies for GPU)
286+
return to_backend(result, backend)
301287
end
302288

303289
Base.hcat(As::HPCMatrix...) = cat(As...; dims=2)
@@ -483,7 +469,6 @@ function blockdiag(As::HPCSparseMatrix{T,Ti,Bk}...) where {T,Ti,Bk<:HPCBackend}
483469

484470
backend = As[1].backend
485471
comm = backend.comm
486-
device = backend.device
487472
rank = comm_rank(comm)
488473
nranks = comm_size(comm)
489474

@@ -555,15 +540,6 @@ function blockdiag(As::HPCSparseMatrix{T,Ti,Bk}...) where {T,Ti,Bk<:HPCBackend}
555540

556541
result = HPCSparseMatrix_local(transpose(AT_local), backend)
557542

558-
# Convert to GPU if inputs were GPU (GPU→CPU for MPI, then CPU→GPU for result)
559-
if !(device isa DeviceCPU)
560-
nzval_target = copyto!(similar(As[1].nzval, length(result.nzval)), result.nzval)
561-
rowptr_target = _to_target_device(result.rowptr, device)
562-
colval_target = _to_target_device(result.colval, device)
563-
return HPCSparseMatrix{T,Ti,Bk}(
564-
result.structural_hash, result.row_partition, result.col_partition, result.col_indices,
565-
result.rowptr, result.colval, nzval_target, result.nrows_local, result.ncols_compressed,
566-
nothing, result.cached_symmetric, rowptr_target, colval_target, backend)
567-
end
568-
return result
543+
# Convert to target backend (no-op for CPU, copies for GPU)
544+
return to_backend(result, backend)
569545
end

src/dense.jl

Lines changed: 22 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -618,16 +618,8 @@ function LinearAlgebra.mul!(y::HPCVector{T,B}, A::HPCMatrix{T,B}, x::HPCVector{T
618618
plan = get_dense_vector_plan(A, x)
619619
execute_plan!(plan, x)
620620

621-
# Check if CPU or GPU based on device type
622-
if A.backend.device isa DeviceCPU
623-
# CPU path
624-
y_local_cpu = Vector{T}(undef, length(y.v))
625-
LinearAlgebra.mul!(y_local_cpu, A.A, plan.gathered_cpu)
626-
copyto!(y.v, y_local_cpu)
627-
else
628-
# GPU path
629-
LinearAlgebra.mul!(y.v, A.A, plan.gathered)
630-
end
621+
# Unified CPU/GPU path: plan.gathered has correct type after execute_plan!
622+
LinearAlgebra.mul!(y.v, A.A, plan.gathered)
631623
return y
632624
end
633625

@@ -653,16 +645,9 @@ function Base.:*(A::HPCMatrix{T,B}, x::HPCVector{T,B}) where {T,B<:HPCBackend}
653645
# Execute the plan to gather vector elements
654646
execute_plan!(plan, x)
655647

656-
# Check if CPU or GPU based on device type
657-
if A.backend.device isa DeviceCPU
658-
# CPU path
659-
y_v = Vector{T}(undef, local_rows)
660-
LinearAlgebra.mul!(y_v, A.A, plan.gathered_cpu)
661-
else
662-
# GPU path
663-
y_v = similar(A.A, local_rows)
664-
LinearAlgebra.mul!(y_v, A.A, plan.gathered)
665-
end
648+
# Unified CPU/GPU path: similar() preserves array type, plan.gathered has correct type
649+
y_v = similar(A.A, local_rows)
650+
LinearAlgebra.mul!(y_v, A.A, plan.gathered)
666651

667652
return HPCVector{T,B}(
668653
plan.result_partition_hash,
@@ -936,13 +921,9 @@ function execute_plan!(plan::DenseTransposePlan{T}, A::HPCMatrix{T,B}) where {T,
936921
plan.row_partition, plan.col_partition, size(result_AT), comm)
937922
end
938923

939-
# Convert result to match input array type (CPU or GPU)
940-
if !(A.backend.device isa DeviceCPU)
941-
# Input was GPU - convert result back to GPU
942-
result_AT_gpu = copyto!(similar(A.A, size(result_AT)), result_AT)
943-
return HPCMatrix{T,B}(plan.structural_hash, plan.row_partition, plan.col_partition, result_AT_gpu, A.backend)
944-
end
945-
return HPCMatrix{T,B}(plan.structural_hash, plan.row_partition, plan.col_partition, result_AT, A.backend)
924+
# Unified CPU/GPU path: _convert_array is no-op for CPU, copies for GPU
925+
result_A = _convert_array(result_AT, A.backend.device)
926+
return HPCMatrix{T,B}(plan.structural_hash, plan.row_partition, plan.col_partition, result_A, A.backend)
946927
end
947928

948929
"""
@@ -1249,19 +1230,14 @@ function Base.:*(At::Transpose{T,HPCMatrix{T,B}}, x::HPCVector{T,B}) where {T,B<
12491230
my_row_start = A.row_partition[rank+1]
12501231
my_row_end = A.row_partition[rank+2] - 1
12511232

1252-
if A.backend.device isa DeviceCPU
1253-
# CPU path
1254-
local_gathered = @view plan.gathered_cpu[my_row_start:my_row_end]
1255-
partial_result = transpose(A.A) * local_gathered
1256-
else
1257-
# GPU path - use GPU gathered directly
1258-
local_gathered = @view plan.gathered[my_row_start:my_row_end]
1259-
# For Metal, views may not work directly - copy to contiguous array
1260-
local_gathered_gpu = similar(A.A, length(local_gathered))
1261-
copyto!(local_gathered_gpu, Array(local_gathered))
1262-
partial_result_gpu = transpose(A.A) * local_gathered_gpu
1263-
partial_result = Array(partial_result_gpu) # Need CPU for Allreduce
1264-
end
1233+
# Unified CPU/GPU path:
1234+
# 1. Get slice and copy to contiguous array (fixes GPU view issues)
1235+
# 2. Compute on backend
1236+
# 3. Ensure CPU for Allreduce (no-op for CPU, copy for GPU)
1237+
local_gathered_slice = plan.gathered[my_row_start:my_row_end]
1238+
local_gathered_contiguous = copy(local_gathered_slice)
1239+
partial_result_backend = transpose(A.A) * local_gathered_contiguous
1240+
partial_result = _ensure_cpu(partial_result_backend)
12651241

12661242
# Allreduce to sum contributions from all ranks
12671243
full_result = comm_allreduce(comm, partial_result, +)
@@ -1271,8 +1247,8 @@ function Base.:*(At::Transpose{T,HPCMatrix{T,B}}, x::HPCVector{T,B}) where {T,B<
12711247
my_col_end = A.col_partition[rank+2] - 1
12721248
local_result_cpu = full_result[my_col_start:my_col_end]
12731249

1274-
# Copy to GPU if needed
1275-
local_result = (A.backend.device isa DeviceCPU) ? local_result_cpu : copyto!(similar(x.v, length(local_result_cpu)), local_result_cpu)
1250+
# Unified: _convert_array is no-op for CPU, copies for GPU
1251+
local_result = _convert_array(local_result_cpu, A.backend.device)
12761252

12771253
# Create result vector (partition is immutable, no need to copy)
12781254
y = HPCVector{T,B}(
@@ -1325,11 +1301,8 @@ function Base.:*(At::TransposedHPCMatrix{T,B}, Bmat::HPCMatrix{T,B}) where {T,B}
13251301
result_partition = columns[1].partition
13261302
local_m = result_partition[rank+2] - result_partition[rank+1]
13271303

1328-
# Build local matrix from column results (columns[k].v may be GPU array)
1329-
local_result = Matrix{T}(undef, local_m, n)
1330-
for k in 1:n
1331-
local_result[:, k] = Array(columns[k].v) # Ensure CPU for HPCMatrix_local
1332-
end
1304+
# Build local matrix from column results (preserves GPU array type)
1305+
local_result = reduce(hcat, [columns[k].v for k in 1:n])
13331306

13341307
return HPCMatrix_local(local_result, A.backend)
13351308
end
@@ -1530,9 +1503,9 @@ function Base.mapslices(f, A::HPCMatrix{T,B}; dims) where {T,B}
15301503
results = Vector{Any}(undef, n)
15311504
for j in 1:n
15321505
# Gather full column j from all ranks
1533-
# Convert to CPU for MPI communication (no-op for CPU arrays)
1506+
# Unified: _ensure_cpu is no-op for CPU, Array() for GPU
15341507
local_col = A.A[:, j]
1535-
local_col_cpu = local_col isa Vector ? local_col : Vector(local_col)
1508+
local_col_cpu = _ensure_cpu(local_col)
15361509
counts = Int32[A.row_partition[r+1] - A.row_partition[r] for r in 1:nranks]
15371510
full_col = Vector{T}(undef, m_global)
15381511
comm_allgatherv!(comm, local_col_cpu, MPI.VBuffer(full_col, counts))

src/mumps_factorization.jl

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -587,29 +587,12 @@ _get_mpi_comm(c::CommMPI) = c.comm
587587
_get_mpi_comm(::CommSerial) = error("Gatherv/Scatterv not supported for CommSerial in MUMPS solve")
588588

589589
# Helper to copy values into a HPCVector (handles GPU arrays)
590+
# Unified: _convert_array handles CPU→GPU conversion, copyto! handles the copy
590591
function _copy_to_vector!(x::HPCVector{T,B}, values::Vector) where {T,B}
591-
if x.backend.device isa DeviceCPU
592-
x.v .= values
593-
else
594-
# GPU array - need to copy through appropriate method
595-
copyto!(x.v, _convert_to_device_array(values, x.backend.device))
596-
end
592+
copyto!(x.v, _convert_array(values, x.backend.device))
597593
return x
598594
end
599595

600-
# Convert a CPU vector to a target device
601-
function _convert_to_device_array(v::Vector{T}, device::AbstractDevice) where T
602-
if device isa DeviceCPU
603-
return v
604-
else
605-
# For GPU devices, use extension-defined function
606-
return _array_to_device(v, device)
607-
end
608-
end
609-
610-
# Fallback for CPU device
611-
_array_to_device(v::Vector{T}, ::DeviceCPU) where T = v
612-
613596
"""
614597
Base.:\\(F::MUMPSFactorization, b::HPCVector)
615598

0 commit comments

Comments
 (0)