
Commit 85e7c35

Sebastien Loisel committed
Constrain backslash dispatch by solver type in backend
Fix type dispatch so CUDA solves use cuDSS instead of MUMPS:
- Generic \, lu, ldlt now require HPCBackend{D,C,SolverMUMPS}
- CUDA extension \, lu, ldlt now require CuDSSBackend (SolverCuDSS)
- Add CuDSSBackend{C} type alias for cuDSS-specific backends
- Right division operators also constrained to MUMPS backends
- Remove unused comm_barrier function
1 parent 4ca3c32 commit 85e7c35

4 files changed: 74 additions and 41 deletions
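To see the new dispatch rules in action, here is a hypothetical usage sketch in Julia; the HPCSparseMatrix and HPCVector constructors shown (building from a SparseMatrixCSC plus a backend) are assumed for illustration and are not part of this diff:

    using LinearAlgebra, SparseArrays
    using HPCLinearAlgebra

    # MUMPS path: the CPU backend carries SolverMUMPS in its type parameters,
    # so `\` dispatches to the generic methods in src/HPCLinearAlgebra.jl.
    be = BACKEND_CPU_SERIAL                 # HPCBackend{DeviceCPU, CommSerial, SolverMUMPS}
    K  = sprandn(100, 100, 0.05) + 100I     # well-conditioned sparse test matrix
    A  = HPCSparseMatrix(K, be)             # assumed constructor
    b  = HPCVector(randn(100), be)          # assumed constructor
    x  = A \ b                              # hits the MUMPS backslash (lu, solve, finalize!)

    # cuDSS path: with CUDA loaded, a backend whose solver parameter is
    # SolverCuDSS matches only the methods in the CUDA extension:
    # using CUDA
    # be_gpu = backend_cuda_serial()        # HPCBackend{DeviceCUDA, C, SolverCuDSS}
    # x_gpu  = HPCSparseMatrix(K, be_gpu) \ HPCVector(randn(100), be_gpu)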


ext/HPCLinearAlgebraCUDAExt.jl

Lines changed: 32 additions & 15 deletions
@@ -28,7 +28,7 @@ using LinearAlgebra
 using HPCLinearAlgebra: HPCBackend, DeviceCPU, DeviceCUDA, DeviceMetal,
     CommSerial, CommMPI, AbstractComm, AbstractDevice,
     SolverMUMPS, AbstractSolverCuDSS,
-    comm_rank, comm_size, comm_barrier
+    comm_rank, comm_size

 # Type aliases for convenience
 const CuBackend{C,S} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCUDA, C, S}
@@ -47,6 +47,9 @@ cuDSS sparse direct solver for CUDA GPUs.
 """
 struct SolverCuDSS <: HPCLinearAlgebra.AbstractSolverCuDSS end

+# Type alias for cuDSS-specific backends (constrains solver type to SolverCuDSS)
+const CuDSSBackend{C} = HPCLinearAlgebra.HPCBackend{HPCLinearAlgebra.DeviceCUDA, C, SolverCuDSS}
+
 # ============================================================================
 # Pre-constructed Backend Constants
 # ============================================================================
@@ -549,29 +552,33 @@ end
 # ============================================================================

 """
-    lu(A::HPCSparseMatrix{T,Ti,<:CuBackend})
+    lu(A::HPCSparseMatrix{T,Ti,<:CuDSSBackend})

 Compute LU factorization of a GPU sparse matrix using cuDSS.
 Returns a CuDSSFactorizationMPI that can be used with `F \\ b`.

 If a previous factorization with the same sparsity structure exists,
 the cached analysis (permutation + elimination tree) is reused,
 skipping the expensive reordering phase.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
-function LinearAlgebra.lu(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,<:CuBackend}) where {T,Ti}
+function LinearAlgebra.lu(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,<:CuDSSBackend}) where {T,Ti}
     return _create_cudss_factorization(A, false)
 end

 """
-    ldlt(A::HPCSparseMatrix{T,Ti,<:CuBackend})
+    ldlt(A::HPCSparseMatrix{T,Ti,<:CuDSSBackend})

 Compute LDLT factorization of a symmetric positive definite GPU sparse matrix.
 Returns a CuDSSFactorizationMPI that can be used with `F \\ b`.

 If a previous factorization with the same sparsity structure exists,
 the cached analysis (permutation + elimination tree) is reused.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
-function LinearAlgebra.ldlt(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,<:CuBackend}) where {T,Ti}
+function LinearAlgebra.ldlt(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,<:CuDSSBackend}) where {T,Ti}
     return _create_cudss_factorization(A, true)
 end

@@ -693,12 +700,14 @@ _get_mpi_comm_for_nccl(c::HPCLinearAlgebra.CommMPI) = c.comm
 _get_mpi_comm_for_nccl(::HPCLinearAlgebra.CommSerial) = error("cuDSS MGMN mode requires MPI communication (CommMPI), not CommSerial")

 """
-    solve(F::CuDSSFactorizationMPI{T,B}, b::HPCVector{T,<:CuBackend}) where {T,B}
+    solve(F::CuDSSFactorizationMPI{T,B}, b::HPCVector{T,<:CuDSSBackend}) where {T,B}

 Solve the linear system using the cuDSS factorization.
 This is solve-only - no refactorization is performed.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
-function HPCLinearAlgebra.solve(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgebra.HPCVector{T,<:CuBackend}) where {T,B}
+function HPCLinearAlgebra.solve(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgebra.HPCVector{T,<:CuDSSBackend}) where {T,B}
     comm = F.backend.comm

     # Copy b directly to RHS buffer (GPU to GPU)
@@ -712,11 +721,13 @@ function HPCLinearAlgebra.solve(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgeb
 end

 """
-    \\(F::CuDSSFactorizationMPI{T,B}, b::HPCVector{T,<:CuBackend}) where {T,B}
+    \\(F::CuDSSFactorizationMPI{T,B}, b::HPCVector{T,<:CuDSSBackend}) where {T,B}

 Solve the linear system using backslash notation (solve-only, no refactorization).
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
-function Base.:\(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgebra.HPCVector{T,<:CuBackend}) where {T,B}
+function Base.:\(F::CuDSSFactorizationMPI{T,B}, b::HPCLinearAlgebra.HPCVector{T,<:CuDSSBackend}) where {T,B}
     return HPCLinearAlgebra.solve(F, b)
 end

@@ -765,14 +776,16 @@ end
 # 3. The cudss matrix wrapper points to our values buffer - we update it in place

 """
-    _refactorize_and_solve!(F::CuDSSFactorizationMPI{T,B}, A::HPCSparseMatrix{T,Ti,B}, b::HPCVector{T,B}) where {T,Ti,B}
+    _refactorize_and_solve!(F::CuDSSFactorizationMPI{T,B}, A::HPCSparseMatrix{T,Ti,B}, b::HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}

 Update the values in a cached factorization, refactorize (skip analysis), and solve.
 Returns the solution vector.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
 function _refactorize_and_solve!(F::CuDSSFactorizationMPI{T,B},
                                  A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B},
-                                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuBackend}
+                                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}
     comm = F.backend.comm

     # Update values in the GPU buffer (the cudss matrix wrapper points to this)
@@ -794,17 +807,19 @@ function _refactorize_and_solve!(F::CuDSSFactorizationMPI{T,B},
 end

 """
-    \\(A::HPCSparseMatrix{T,Ti,B}, b::HPCVector{T,B}) where {T,Ti,B<:CuBackend}
+    \\(A::HPCSparseMatrix{T,Ti,B}, b::HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}

 Solve A*x = b using cuDSS with analysis caching.

 First call for a given sparsity pattern: full analysis + factorization.
 Subsequent calls with same pattern: refactorize only (skip expensive analysis).

 The cuDSS data object is cached globally and reused - never destroyed.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
 function Base.:\(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B},
-                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuBackend}
+                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}
     structural_hash = HPCLinearAlgebra._ensure_hash(A)
     cache_key = (structural_hash, false, T) # false = not symmetric (LU)

@@ -826,12 +841,14 @@ function Base.:\(A::HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B},
 end

 """
-    \\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,B}}, b::HPCVector{T,B}) where {T,Ti,B<:CuBackend}
+    \\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,B}}, b::HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}

 Solve A*x = b for a symmetric matrix using LDLT with analysis caching.
+
+Note: This method is specific to cuDSS backends (SolverCuDSS).
 """
 function Base.:\(A::Symmetric{T,<:HPCLinearAlgebra.HPCSparseMatrix{T,Ti,B}},
-                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuBackend}
+                 b::HPCLinearAlgebra.HPCVector{T,B}) where {T,Ti,B<:CuDSSBackend}
     A_inner = parent(A)
     structural_hash = HPCLinearAlgebra._ensure_hash(A_inner)
     cache_key = (structural_hash, true, T) # true = symmetric (LDLT)
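A sketch of the analysis-caching behavior these docstrings describe; CUDA hardware is required, the constructors are assumed, and whether a serial (non-MPI) CUDA backend is usable depends on the cuDSS mode, so treat this as illustrative only:

    using CUDA, LinearAlgebra, SparseArrays
    using HPCLinearAlgebra

    be = backend_cuda_mpi()                 # solver type parameter is SolverCuDSS
    K  = sprandn(1000, 1000, 0.01) + 1000I
    A  = HPCSparseMatrix(K, be)             # assumed constructor
    b  = HPCVector(randn(1000), be)         # assumed constructor

    x1 = A \ b                        # first solve for this pattern: full analysis + factorization
    K.nzval .*= 2.0                   # new values, same sparsity pattern
    x2 = HPCSparseMatrix(K, be) \ b   # cached analysis reused; refactorization only

    F = lu(A)                         # explicit cuDSS factorization (CuDSSFactorizationMPI)
    y = F \ b                         # solve-only, no refactorization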

src/HPCLinearAlgebra.jl

Lines changed: 32 additions & 14 deletions
@@ -33,7 +33,7 @@ export HPCBackendCPU, HPCBackendMetal, HPCBackendCUDA
 export backend_cpu_serial, backend_cpu_mpi, backend_metal_mpi, backend_cuda_serial, backend_cuda_mpi
 export BACKEND_CPU_SERIAL, BACKEND_CPU_MPI # Pre-constructed CPU backend constants
 # CUDA backends: use backend_cuda_serial() and backend_cuda_mpi() after loading CUDA
-export comm_rank, comm_size, comm_barrier
+export comm_rank, comm_size
 export array_type, matrix_type
 export backends_compatible, assert_backends_compatible

@@ -588,38 +588,50 @@ end
 # ============================================================================

 """
-    Base.:\\(A::HPCSparseMatrix{T}, b::HPCVector{T}) where T
+    Base.:\\(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}

-Solve A*x = b using LU factorization.
+Solve A*x = b using LU factorization with MUMPS.
 For symmetric matrices, use `Symmetric(A) \\ b` to use the faster LDLT factorization.
 For repeated solves, compute the factorization once with `lu(A)` or `ldlt(A)`.
+
+Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
+specialized backslash methods defined in the CUDA extension.
 """
-function Base.:\(A::HPCSparseMatrix{T}, b::HPCVector{T}) where T
+function Base.:\(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}},
+                 b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     F = LinearAlgebra.lu(A)
     x = F \ b
     finalize!(F)
     return x
 end

 """
-    Base.:\\(A::Symmetric{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
+    Base.:\\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}

-Solve A*x = b for a symmetric matrix using LDLT (no symmetry check needed).
+Solve A*x = b for a symmetric matrix using LDLT with MUMPS (no symmetry check needed).
 Use `Symmetric(A)` to wrap a known-symmetric matrix and skip the expensive symmetry check.
+
+Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
+specialized backslash methods defined in the CUDA extension.
 """
-function Base.:\(A::Symmetric{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
+function Base.:\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}},
+                 b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     F = LinearAlgebra.ldlt(parent(A))
     x = F \ b
     finalize!(F)
     return x
 end

 """
-    Base.:\\(At::Transpose{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
+    Base.:\\(At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
+
+Solve transpose(A)*x = b using LU factorization with MUMPS.

-Solve transpose(A)*x = b using LU factorization.
+Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
+specialized backslash methods defined in the CUDA extension.
 """
-function Base.:\(At::Transpose{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
+function Base.:\(At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}},
+                 b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     A_t = HPCSparseMatrix(At)
     F = LinearAlgebra.lu(A_t)
     x = F \ b
@@ -635,23 +647,29 @@ end
 # For row vectors: transpose(v) / A solves x * A = transpose(v)

 """
-    Base.:/(vt::Transpose{T,HPCVector{T}}, A::HPCSparseMatrix{T}) where T
+    Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}}, A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}

 Solve x * A = transpose(v), returning x as a transposed HPCVector.
 Equivalent to transpose(transpose(A) \\ v).
+
+Note: This method is specific to MUMPS backends.
 """
-function Base.:/(vt::Transpose{T,HPCVector{T}}, A::HPCSparseMatrix{T}) where T
+function Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}},
+                 A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     v = vt.parent
     x = transpose(A) \ v
     return transpose(x)
 end

 """
-    Base.:/(vt::Transpose{T,HPCVector{T}}, At::Transpose{T,<:HPCSparseMatrix{T}}) where T
+    Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}}, At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}) where {T,Ti,D,C}

 Solve x * transpose(A) = transpose(v), returning x as a transposed HPCVector.
+
+Note: This method is specific to MUMPS backends.
 """
-function Base.:/(vt::Transpose{T,HPCVector{T}}, At::Transpose{T,<:HPCSparseMatrix{T}}) where T
+function Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}},
+                 At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}) where {T,Ti,D,C}
     v = vt.parent
     A = At.parent
     x = A \ v
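A minimal sketch of the row-vector right division constrained above, on a MUMPS backend (constructors assumed as before):

    using LinearAlgebra, SparseArrays
    using HPCLinearAlgebra

    be = BACKEND_CPU_SERIAL                                # a SolverMUMPS backend
    A  = HPCSparseMatrix(sprandn(50, 50, 0.1) + 50I, be)   # assumed constructor
    v  = HPCVector(randn(50), be)                          # assumed constructor

    xt = transpose(v) / A                    # solves x * A = transpose(v)
    xt_check = transpose(transpose(A) \ v)   # equivalent form, per the docstring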

src/backends.jl

Lines changed: 0 additions & 8 deletions
@@ -297,14 +297,6 @@ function comm_waitall(::CommMPI, requests)
     end
 end

-"""
-    comm_barrier(comm::AbstractComm)
-
-Synchronization barrier. For CommSerial, this is a no-op.
-"""
-comm_barrier(::CommSerial) = nothing
-comm_barrier(c::CommMPI) = MPI.Barrier(c.comm)
-
 # ============================================================================
 # HPCBackend Factory Functions
 # ============================================================================
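comm_barrier had no callers, hence its removal. Code that still needs a barrier can call MPI.jl directly, which is all the deleted CommMPI method did; a minimal sketch:

    using MPI

    MPI.Init()
    MPI.Barrier(MPI.COMM_WORLD)   # the call comm_barrier(c::CommMPI) forwarded to c.comm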

src/mumps_factorization.jl

Lines changed: 10 additions & 4 deletions
@@ -469,23 +469,29 @@ end
 # ============================================================================

 """
-    LinearAlgebra.lu(A::HPCSparseMatrix{T,Ti,B}) where {T,Ti,B}
+    LinearAlgebra.lu(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}

 Compute LU factorization of a distributed sparse matrix using MUMPS.
 Returns a `MUMPSFactorization` for use with `\\` or `solve`.
+
+Note: This method is specific to MUMPS backends. GPU backends (cuDSS) define their own
+lu method in the CUDA extension.
 """
-function LinearAlgebra.lu(A::HPCSparseMatrix{T,Ti,B}) where {T,Ti,B}
+function LinearAlgebra.lu(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     return _create_mumps_factorization(A, false)
 end

 """
-    LinearAlgebra.ldlt(A::HPCSparseMatrix{T,Ti,B}) where {T,Ti,B}
+    LinearAlgebra.ldlt(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}

 Compute LDLT factorization of a distributed symmetric sparse matrix using MUMPS.
 The matrix must be symmetric; only the lower triangular part is used.
 Returns a `MUMPSFactorization` for use with `\\` or `solve`.
+
+Note: This method is specific to MUMPS backends. GPU backends (cuDSS) define their own
+ldlt method in the CUDA extension.
 """
-function LinearAlgebra.ldlt(A::HPCSparseMatrix{T,Ti,B}) where {T,Ti,B}
+function LinearAlgebra.ldlt(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
     return _create_mumps_factorization(A, true)
 end
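A sketch of reusing an explicit MUMPS factorization across several right-hand sides, per the docstrings above (constructors assumed as before):

    using LinearAlgebra, SparseArrays
    using HPCLinearAlgebra

    be = BACKEND_CPU_SERIAL
    A  = HPCSparseMatrix(sprandn(200, 200, 0.02) + 200I, be)   # assumed constructor
    b1 = HPCVector(randn(200), be)                             # assumed constructor
    b2 = HPCVector(randn(200), be)

    F  = lu(A)          # MUMPSFactorization
    x1 = F \ b1
    x2 = F \ b2         # reuse the factorization for a second right-hand side
    finalize!(F)        # release MUMPS resources, as the generic `\` method does

    # Symmetric case: ldlt(A) reads only the lower triangle.
    # Fs = ldlt(A); xs = Fs \ b1; finalize!(Fs)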
