You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Constrain backslash dispatch by solver type in backend
Fix type dispatch so CUDA solves use cuDSS instead of MUMPS:
- Generic \, lu, ldlt now require HPCBackend{D,C,SolverMUMPS}
- CUDA extension \, lu, ldlt now require CuDSSBackend (SolverCuDSS)
- Add CuDSSBackend{C} type alias for cuDSS-specific backends
- Right division operators also constrained to MUMPS backends
- Remove unused comm_barrier function
Base.:\\(A::HPCSparseMatrix{T}, b::HPCVector{T}) where T
591
+
Base.:\\(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
592
592
593
-
Solve A*x = b using LU factorization.
593
+
Solve A*x = b using LU factorization with MUMPS.
594
594
For symmetric matrices, use `Symmetric(A) \\ b` to use the faster LDLT factorization.
595
595
For repeated solves, compute the factorization once with `lu(A)` or `ldlt(A)`.
596
+
597
+
Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
598
+
specialized backslash methods defined in the CUDA extension.
596
599
"""
597
-
function Base.:\(A::HPCSparseMatrix{T}, b::HPCVector{T}) where T
600
+
function Base.:\(A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}},
601
+
b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
598
602
F = LinearAlgebra.lu(A)
599
603
x = F \ b
600
604
finalize!(F)
601
605
return x
602
606
end
603
607
604
608
"""
605
-
Base.:\\(A::Symmetric{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
609
+
Base.:\\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
606
610
607
-
Solve A*x = b for a symmetric matrix using LDLT (no symmetry check needed).
611
+
Solve A*x = b for a symmetric matrix using LDLT with MUMPS (no symmetry check needed).
608
612
Use `Symmetric(A)` to wrap a known-symmetric matrix and skip the expensive symmetry check.
613
+
614
+
Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
615
+
specialized backslash methods defined in the CUDA extension.
609
616
"""
610
-
function Base.:\(A::Symmetric{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
617
+
function Base.:\(A::Symmetric{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}},
618
+
b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
611
619
F = LinearAlgebra.ldlt(parent(A))
612
620
x = F \ b
613
621
finalize!(F)
614
622
return x
615
623
end
616
624
617
625
"""
618
-
Base.:\\(At::Transpose{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
626
+
Base.:\\(At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}, b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
627
+
628
+
Solve transpose(A)*x = b using LU factorization with MUMPS.
619
629
620
-
Solve transpose(A)*x = b using LU factorization.
630
+
Note: This method is specific to MUMPS backends. GPU backends (cuDSS) have their own
631
+
specialized backslash methods defined in the CUDA extension.
621
632
"""
622
-
function Base.:\(At::Transpose{T,<:HPCSparseMatrix{T}}, b::HPCVector{T}) where T
633
+
function Base.:\(At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}},
634
+
b::HPCVector{T,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
623
635
A_t =HPCSparseMatrix(At)
624
636
F = LinearAlgebra.lu(A_t)
625
637
x = F \ b
@@ -635,23 +647,29 @@ end
635
647
# For row vectors: transpose(v) / A solves x * A = transpose(v)
636
648
637
649
"""
638
-
Base.:/(vt::Transpose{T,HPCVector{T}}, A::HPCSparseMatrix{T}) where T
650
+
Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}}, A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
639
651
640
652
Solve x * A = transpose(v), returning x as a transposed HPCVector.
641
653
Equivalent to transpose(transpose(A) \\ v).
654
+
655
+
Note: This method is specific to MUMPS backends.
642
656
"""
643
-
function Base.:/(vt::Transpose{T,HPCVector{T}}, A::HPCSparseMatrix{T}) where T
657
+
function Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}},
658
+
A::HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}) where {T,Ti,D,C}
644
659
v = vt.parent
645
660
x =transpose(A) \ v
646
661
returntranspose(x)
647
662
end
648
663
649
664
"""
650
-
Base.:/(vt::Transpose{T,HPCVector{T}}, At::Transpose{T,<:HPCSparseMatrix{T}}) where T
665
+
Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}}, At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}) where {T,Ti,D,C}
651
666
652
667
Solve x * transpose(A) = transpose(v), returning x as a transposed HPCVector.
668
+
669
+
Note: This method is specific to MUMPS backends.
653
670
"""
654
-
function Base.:/(vt::Transpose{T,HPCVector{T}}, At::Transpose{T,<:HPCSparseMatrix{T}}) where T
671
+
function Base.:/(vt::Transpose{T,HPCVector{T,HPCBackend{D,C,SolverMUMPS}}},
672
+
At::Transpose{T,<:HPCSparseMatrix{T,Ti,HPCBackend{D,C,SolverMUMPS}}}) where {T,Ti,D,C}
0 commit comments