Commit 91ca03a

Sebastien Loisel and claude committed

Refactor HPCBackend with type parameters and simplify sparse operations

- HPCBackend now parameterized as HPCBackend{T,Ti,Device,Comm,Solver}
- Replace custom Ai/Bi/Ci symbolic SpGEMM with builtin sparse products
- Simplify sparse.jl by removing MatrixPlan triplet arrays (~13GB memory savings)
- Clean up mumps_factorization.jl with streamlined refactorization
- Update all tests for new backend API

Memory at L=8 reduced from ~19GB to ~4GB by eliminating symbolic triplet storage.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 85e7c35 commit 91ca03a

27 files changed: 833 additions & 1075 deletions
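
The "builtin sparse products" point can be made concrete with a short sketch (placeholder matrices, not package code; only the Ai/Bi/Ci and MatrixPlan names come from the commit message). Julia's SparseArrays runs both the symbolic pass (the pattern of C) and the numeric pass (its values) inside a single call and frees the intermediates afterwards, which is what eliminates the persistent triplet storage:

    using SparseArrays

    A = sprand(10_000, 10_000, 1e-3)   # placeholder operands
    B = sprand(10_000, 10_000, 1e-3)

    # The builtin sparse-sparse product performs its own symbolic and
    # numeric phases internally; nothing like the old MatrixPlan
    # triplet arrays (Ai/Bi/Ci) survives the call.
    C = A * B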

ext/HPCLinearAlgebraCUDAExt.jl

Lines changed: 46 additions & 32 deletions
@@ -28,11 +28,12 @@ using LinearAlgebra
 using HPCLinearAlgebra: HPCBackend, DeviceCPU, DeviceCUDA, DeviceMetal,
                         CommSerial, CommMPI, AbstractComm, AbstractDevice,
                         SolverMUMPS, AbstractSolverCuDSS,
-                        comm_rank, comm_size
+                        comm_rank, comm_size,
+                        eltype_backend, indextype_backend

-# Type aliases for convenience
-const CuBackend{C,S} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCUDA, C, S}
-const CPUBackend{C,S} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCPU, C, S}
+# Type aliases for convenience (with T and Ti type parameters)
+const CuBackend{T,Ti,C,S} = HPCLinearAlgebra.HPCBackend{T, Ti, HPCLinearAlgebra.DeviceCUDA, C, S}
+const CPUBackend{T,Ti,C,S} = HPCLinearAlgebra.HPCBackend{T, Ti, HPCLinearAlgebra.DeviceCPU, C, S}

 # ============================================================================
 # cuDSS Solver Types
@@ -48,65 +49,75 @@ cuDSS sparse direct solver for CUDA GPUs.
 struct SolverCuDSS <: HPCLinearAlgebra.AbstractSolverCuDSS end

 # Type alias for cuDSS-specific backends (constrains solver type to SolverCuDSS)
-const CuDSSBackend{C} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCUDA, C, SolverCuDSS}
+const CuDSSBackend{T,Ti,C} = HPCLinearAlgebra.HPCBackend{T, Ti, HPCLinearAlgebra.DeviceCUDA, C, SolverCuDSS}

 # ============================================================================
-# Pre-constructed Backend Constants
+# Pre-constructed Backend Constants (Deprecated)
 # ============================================================================
 #
-# These are defined before the factory functions so they can be referenced.
+# These use default types (Float64, Int) for backward compatibility.
+# New code should use the factory functions with explicit type parameters.

 """
     BACKEND_CUDA_SERIAL

 Pre-constructed CUDA backend with serial communication and cuDSS solver.
-Use this for single-GPU computations without MPI.
+Uses Float64 element type and Int index type.

-Access via `backend_cuda_serial()` after loading CUDA.
+!!! warning "Deprecated"
+    Use `backend_cuda_serial(T, Ti)` instead for explicit type control.
 """
-const BACKEND_CUDA_SERIAL = HPCLinearAlgebra.HPCBackend(DeviceCUDA(), CommSerial(), SolverCuDSS())
+const BACKEND_CUDA_SERIAL = HPCLinearAlgebra.HPCBackend{Float64,Int,DeviceCUDA,CommSerial,SolverCuDSS}(
+    DeviceCUDA(), CommSerial(), SolverCuDSS())

 """
     BACKEND_CUDA_MPI

 Pre-constructed CUDA backend with MPI communication (using COMM_WORLD) and cuDSS solver.
-Uses NCCL for inter-GPU communication in multi-GPU mode.
+Uses Float64 element type and Int index type.

-Note: While this constant is created at module load time, actual MPI/NCCL operations
-will only work after MPI.Init() has been called.
-
-Access via `backend_cuda_mpi()` after loading CUDA.
+!!! warning "Deprecated"
+    Use `backend_cuda_mpi(T, Ti)` instead for explicit type control.
 """
-const BACKEND_CUDA_MPI = HPCLinearAlgebra.HPCBackend(DeviceCUDA(), CommMPI(MPI.COMM_WORLD), SolverCuDSS())
+const BACKEND_CUDA_MPI = HPCLinearAlgebra.HPCBackend{Float64,Int,DeviceCUDA,CommMPI,SolverCuDSS}(
+    DeviceCUDA(), CommMPI(MPI.COMM_WORLD), SolverCuDSS())

 # ============================================================================
 # Backend Factory Functions
 # ============================================================================

 """
-    backend_cuda_serial() -> HPCBackend
+    backend_cuda_serial(::Type{T}=Float64, ::Type{Ti}=Int) where {T,Ti} -> HPCBackend
+
+Create a CUDA backend with serial communication and cuDSS solver.

-Return the pre-constructed CUDA backend with serial communication and cuDSS solver.
+# Arguments
+- `T`: Element type for array values (default: Float64)
+- `Ti`: Index type for sparse matrix indices (default: Int)
 """
-function HPCLinearAlgebra.backend_cuda_serial()
-    return BACKEND_CUDA_SERIAL
+function HPCLinearAlgebra.backend_cuda_serial(::Type{T}=Float64, ::Type{Ti}=Int) where {T,Ti<:Integer}
+    return HPCLinearAlgebra.HPCBackend{T,Ti,DeviceCUDA,CommSerial,SolverCuDSS}(
+        DeviceCUDA(), CommSerial(), SolverCuDSS())
 end

 """
-    backend_cuda_mpi() -> HPCBackend
-    backend_cuda_mpi(comm::MPI.Comm) -> HPCBackend
+    backend_cuda_mpi(::Type{T}=Float64, ::Type{Ti}=Int; comm=MPI.COMM_WORLD) where {T,Ti} -> HPCBackend

-Return a CUDA GPU backend with MPI communication and cuDSS solver (MGMN mode).
+Create a CUDA GPU backend with MPI communication and cuDSS solver (MGMN mode).

-The zero-argument form returns the pre-constructed backend (using COMM_WORLD).
-The one-argument form creates a new backend with the specified communicator.
+# Arguments
+- `T`: Element type for array values (default: Float64)
+- `Ti`: Index type for sparse matrix indices (default: Int)
+- `comm`: MPI communicator (default: MPI.COMM_WORLD)
 """
-function HPCLinearAlgebra.backend_cuda_mpi()
-    return BACKEND_CUDA_MPI
+function HPCLinearAlgebra.backend_cuda_mpi(::Type{T}=Float64, ::Type{Ti}=Int; comm::MPI.Comm=MPI.COMM_WORLD) where {T,Ti<:Integer}
+    return HPCLinearAlgebra.HPCBackend{T,Ti,DeviceCUDA,CommMPI,SolverCuDSS}(
+        DeviceCUDA(), CommMPI(comm), SolverCuDSS())
 end

-function HPCLinearAlgebra.backend_cuda_mpi(comm)
-    return HPCLinearAlgebra.HPCBackend(DeviceCUDA(), CommMPI(comm), SolverCuDSS())
+# Legacy overload for backward compatibility (comm as positional argument)
+function HPCLinearAlgebra.backend_cuda_mpi(comm::MPI.Comm)
+    return HPCLinearAlgebra.backend_cuda_mpi(Float64, Int; comm=comm)
 end


@@ -166,9 +177,12 @@ end
 Convert a HPCVector to CUDA GPU device.
 Used by MUMPS factorization for GPU reconstruction after solve.
 """
-function HPCLinearAlgebra._convert_vector_to_device(v::HPCLinearAlgebra.HPCVector, device::HPCLinearAlgebra.DeviceCUDA)
-    # Create CUDA backend preserving comm and solver
-    cuda_backend = HPCLinearAlgebra.HPCBackend(device, v.backend.comm, v.backend.solver)
+function HPCLinearAlgebra._convert_vector_to_device(v::HPCLinearAlgebra.HPCVector{T,B}, device::HPCLinearAlgebra.DeviceCUDA) where {T, B}
+    # Create CUDA backend preserving T, Ti, comm, and solver from source backend
+    Ti = indextype_backend(B)
+    C = typeof(v.backend.comm)
+    S = typeof(v.backend.solver)
+    cuda_backend = HPCLinearAlgebra.HPCBackend{T,Ti,DeviceCUDA,C,S}(device, v.backend.comm, v.backend.solver)
     return HPCLinearAlgebra.to_backend(v, cuda_backend)
 end

ext/HPCLinearAlgebraMetalExt.jl

Lines changed: 26 additions & 10 deletions
@@ -9,24 +9,37 @@ module HPCLinearAlgebraMetalExt
 using HPCLinearAlgebra
 using Metal
 using Adapt
+using MPI

 # Import HPCBackend types for type-based dispatch
 using HPCLinearAlgebra: HPCBackend, DeviceCPU, DeviceMetal, DeviceCUDA,
                         CommSerial, CommMPI, AbstractComm, AbstractDevice,
-                        SolverMUMPS
+                        SolverMUMPS,
+                        eltype_backend, indextype_backend

-# Backend type aliases for Metal
-const MtlBackend{C,S} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceMetal, C, S}
-const CPUBackend{C,S} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCPU, C, S}
+# Backend type aliases for Metal (with T and Ti type parameters)
+const MtlBackend{T,Ti,C,S} = HPCLinearAlgebra.HPCBackend{T, Ti, HPCLinearAlgebra.DeviceMetal, C, S}
+const CPUBackend{T,Ti,C,S} = HPCLinearAlgebra.HPCBackend{T, Ti, HPCLinearAlgebra.DeviceCPU, C, S}

 """
-    backend_metal_mpi(comm::MPI.Comm) -> HPCBackend
+    backend_metal_mpi(::Type{T}=Float64, ::Type{Ti}=Int; comm=MPI.COMM_WORLD) where {T,Ti} -> HPCBackend

 Create a Metal GPU backend with MPI communication and MUMPS solver.
 Metal doesn't have a native sparse direct solver, so MUMPS is used (data staged via CPU).
+
+# Arguments
+- `T`: Element type for array values (default: Float64)
+- `Ti`: Index type for sparse matrix indices (default: Int)
+- `comm`: MPI communicator (default: MPI.COMM_WORLD)
 """
-function HPCLinearAlgebra.backend_metal_mpi(comm)
-    return HPCLinearAlgebra.HPCBackend(DeviceMetal(), CommMPI(comm), SolverMUMPS())
+function HPCLinearAlgebra.backend_metal_mpi(::Type{T}=Float64, ::Type{Ti}=Int; comm=MPI.COMM_WORLD) where {T,Ti<:Integer}
+    return HPCLinearAlgebra.HPCBackend{T,Ti,DeviceMetal,CommMPI,SolverMUMPS}(
+        DeviceMetal(), CommMPI(comm), SolverMUMPS())
+end
+
+# Legacy overload for backward compatibility (comm as positional argument)
+function HPCLinearAlgebra.backend_metal_mpi(comm::MPI.Comm)
+    return HPCLinearAlgebra.backend_metal_mpi(Float64, Int; comm=comm)
 end

 # ============================================================================
@@ -61,9 +74,12 @@ end
 Convert a CPU HPCVector to GPU (Metal) backend.
 Used by MUMPS factorization for GPU reconstruction after solve.
 """
-function HPCLinearAlgebra._convert_vector_to_device(v::HPCLinearAlgebra.HPCVector, device::HPCLinearAlgebra.DeviceMetal)
-    # Create Metal backend preserving comm and solver
-    metal_backend = HPCLinearAlgebra.HPCBackend(device, v.backend.comm, v.backend.solver)
+function HPCLinearAlgebra._convert_vector_to_device(v::HPCLinearAlgebra.HPCVector{T,B}, device::HPCLinearAlgebra.DeviceMetal) where {T, B}
+    # Create Metal backend preserving T, Ti, comm, and solver from source backend
+    Ti = indextype_backend(B)
+    C = typeof(v.backend.comm)
+    S = typeof(v.backend.solver)
+    metal_backend = HPCLinearAlgebra.HPCBackend{T,Ti,DeviceMetal,C,S}(device, v.backend.comm, v.backend.solver)
     return HPCLinearAlgebra.to_backend(v, metal_backend)
 end
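
Likewise for the Metal extension, a sketch under the same assumptions (Metal available, MPI initialized, names exported):

    using HPCLinearAlgebra, Metal, MPI
    MPI.Init()

    b_f32    = backend_metal_mpi(Float32, Int32)    # explicit T/Ti, COMM_WORLD default
    b_comm   = backend_metal_mpi(Float64, Int; comm=MPI.COMM_WORLD)
    b_legacy = backend_metal_mpi(MPI.COMM_WORLD)    # deprecated positional form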
