@@ -28,11 +28,12 @@ using LinearAlgebra
2828using HPCLinearAlgebra: HPCBackend, DeviceCPU, DeviceCUDA, DeviceMetal,
2929 CommSerial, CommMPI, AbstractComm, AbstractDevice,
3030 SolverMUMPS, AbstractSolverCuDSS,
31- comm_rank, comm_size
31+ comm_rank, comm_size,
32+ eltype_backend, indextype_backend
3233
33- # Type aliases for convenience
34- const CuBackend{C,S} = HPCLinearAlgebra. HPCBackend{HPCLinearAlgebra. DeviceCUDA, C, S}
35- const CPUBackend{C,S} = HPCLinearAlgebra. HPCBackend{HPCLinearAlgebra. DeviceCPU, C, S}
34+ # Type aliases for convenience (with T and Ti type parameters)
35+ const CuBackend{T,Ti, C,S} = HPCLinearAlgebra. HPCBackend{T, Ti, HPCLinearAlgebra. DeviceCUDA, C, S}
36+ const CPUBackend{T,Ti, C,S} = HPCLinearAlgebra. HPCBackend{T, Ti, HPCLinearAlgebra. DeviceCPU, C, S}
3637
3738# ============================================================================
3839# cuDSS Solver Types
@@ -48,65 +49,75 @@ cuDSS sparse direct solver for CUDA GPUs.
4849struct SolverCuDSS <: HPCLinearAlgebra.AbstractSolverCuDSS end
4950
5051# Type alias for cuDSS-specific backends (constrains solver type to SolverCuDSS)
51- const CuDSSBackend{C} = HPCLinearAlgebra. HPCBackend{HPCLinearAlgebra. DeviceCUDA, C, SolverCuDSS}
52+ const CuDSSBackend{T,Ti, C} = HPCLinearAlgebra. HPCBackend{T, Ti, HPCLinearAlgebra. DeviceCUDA, C, SolverCuDSS}
5253
5354# ============================================================================
54- # Pre-constructed Backend Constants
55+ # Pre-constructed Backend Constants (Deprecated)
5556# ============================================================================
5657#
57- # These are defined before the factory functions so they can be referenced.
58+ # These use default types (Float64, Int) for backward compatibility.
59+ # New code should use the factory functions with explicit type parameters.
5860
5961"""
6062 BACKEND_CUDA_SERIAL
6163
6264Pre-constructed CUDA backend with serial communication and cuDSS solver.
63- Use this for single-GPU computations without MPI .
65+ Uses Float64 element type and Int index type .
6466
65- Access via `backend_cuda_serial()` after loading CUDA.
67+ !!! warning "Deprecated"
68+ Use `backend_cuda_serial(T, Ti)` instead for explicit type control.
6669"""
67- const BACKEND_CUDA_SERIAL = HPCLinearAlgebra. HPCBackend (DeviceCUDA (), CommSerial (), SolverCuDSS ())
70+ const BACKEND_CUDA_SERIAL = HPCLinearAlgebra. HPCBackend {Float64,Int,DeviceCUDA,CommSerial,SolverCuDSS} (
71+ DeviceCUDA (), CommSerial (), SolverCuDSS ())
6872
6973"""
7074 BACKEND_CUDA_MPI
7175
7276Pre-constructed CUDA backend with MPI communication (using COMM_WORLD) and cuDSS solver.
73- Uses NCCL for inter-GPU communication in multi-GPU mode .
77+ Uses Float64 element type and Int index type .
7478
75- Note: While this constant is created at module load time, actual MPI/NCCL operations
76- will only work after MPI.Init() has been called.
77-
78- Access via `backend_cuda_mpi()` after loading CUDA.
79+ !!! warning "Deprecated"
80+ Use `backend_cuda_mpi(T, Ti)` instead for explicit type control.
7981"""
80- const BACKEND_CUDA_MPI = HPCLinearAlgebra. HPCBackend (DeviceCUDA (), CommMPI (MPI. COMM_WORLD), SolverCuDSS ())
82+ const BACKEND_CUDA_MPI = HPCLinearAlgebra. HPCBackend {Float64,Int,DeviceCUDA,CommMPI,SolverCuDSS} (
83+ DeviceCUDA (), CommMPI (MPI. COMM_WORLD), SolverCuDSS ())
8184
8285# ============================================================================
8386# Backend Factory Functions
8487# ============================================================================
8588
8689"""
87- backend_cuda_serial() -> HPCBackend
90+ backend_cuda_serial(::Type{T}=Float64, ::Type{Ti}=Int) where {T,Ti<:Integer} -> HPCBackend
91+
92+ Create a CUDA backend with serial communication and cuDSS solver.
8893
89- Return the pre-constructed CUDA backend with serial communication and cuDSS solver.
94+ # Arguments
95+ - `T`: Element type for array values (default: Float64)
96+ - `Ti`: Index type for sparse matrix indices; must be a subtype of `Integer` (default: Int)
9097"""
91- function HPCLinearAlgebra. backend_cuda_serial ()
92- return BACKEND_CUDA_SERIAL
98+ function HPCLinearAlgebra. backend_cuda_serial (:: Type{T} = Float64, :: Type{Ti} = Int) where {T,Ti<: Integer }
99+ return HPCLinearAlgebra. HPCBackend {T,Ti,DeviceCUDA,CommSerial,SolverCuDSS} (
100+ DeviceCUDA (), CommSerial (), SolverCuDSS ())
93101end
94102
95103"""
96- backend_cuda_mpi() -> HPCBackend
97- backend_cuda_mpi(comm::MPI.Comm) -> HPCBackend
104+ backend_cuda_mpi(::Type{T}=Float64, ::Type{Ti}=Int; comm::MPI.Comm=MPI.COMM_WORLD) where {T,Ti<:Integer} -> HPCBackend
98105
99- Return a CUDA GPU backend with MPI communication and cuDSS solver (MGMN mode).
106+ Create a CUDA GPU backend with MPI communication and cuDSS solver (MGMN mode).
100107
101- The zero-argument form returns the pre-constructed backend (using COMM_WORLD).
102- The one-argument form creates a new backend with the specified communicator.
108+ # Arguments
109+ - `T`: Element type for array values (default: Float64)
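110+ - `Ti`: Index type for sparse matrix indices; must be a subtype of `Integer` (default: Int)
111+ - `comm`: MPI communicator, passed as a keyword (default: MPI.COMM_WORLD). The legacy positional form `backend_cuda_mpi(comm)` is still accepted and uses Float64 and Int.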
103112"""
104- function HPCLinearAlgebra. backend_cuda_mpi ()
105- return BACKEND_CUDA_MPI
113+ function HPCLinearAlgebra. backend_cuda_mpi (:: Type{T} = Float64, :: Type{Ti} = Int; comm:: MPI.Comm = MPI. COMM_WORLD) where {T,Ti<: Integer }
114+ return HPCLinearAlgebra. HPCBackend {T,Ti,DeviceCUDA,CommMPI,SolverCuDSS} (
115+ DeviceCUDA (), CommMPI (comm), SolverCuDSS ())
106116end
107117
108- function HPCLinearAlgebra. backend_cuda_mpi (comm)
109- return HPCLinearAlgebra. HPCBackend (DeviceCUDA (), CommMPI (comm), SolverCuDSS ())
118+ # Legacy overload for backward compatibility (comm as positional argument)
119+ function HPCLinearAlgebra. backend_cuda_mpi (comm:: MPI.Comm )
120+ return HPCLinearAlgebra. backend_cuda_mpi (Float64, Int; comm= comm)
110121end
111122
112123
166177Convert a HPCVector to CUDA GPU device.
167178Used by MUMPS factorization for GPU reconstruction after solve.
168179"""
169- function HPCLinearAlgebra. _convert_vector_to_device (v:: HPCLinearAlgebra.HPCVector , device:: HPCLinearAlgebra.DeviceCUDA )
170- # Create CUDA backend preserving comm and solver
171- cuda_backend = HPCLinearAlgebra. HPCBackend (device, v. backend. comm, v. backend. solver)
180+ function HPCLinearAlgebra. _convert_vector_to_device (v:: HPCLinearAlgebra.HPCVector{T,B} , device:: HPCLinearAlgebra.DeviceCUDA ) where {T, B}
181+ # Create CUDA backend preserving T, Ti, comm, and solver from source backend
182+ Ti = indextype_backend (B)
183+ C = typeof (v. backend. comm)
184+ S = typeof (v. backend. solver)
185+ cuda_backend = HPCLinearAlgebra. HPCBackend {T,Ti,DeviceCUDA,C,S} (device, v. backend. comm, v. backend. solver)
172186 return HPCLinearAlgebra. to_backend (v, cuda_backend)
173187end
174188
0 commit comments